299 lines
9.0 KiB
TypeScript
299 lines
9.0 KiB
TypeScript
import { LanguageModelV1 } from "ai";
|
|
import { createOpenAI } from "@ai-sdk/openai";
|
|
import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
|
|
import { createAnthropic } from "@ai-sdk/anthropic";
|
|
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
|
|
import { createOllama } from "ollama-ai-provider";
|
|
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
|
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
|
|
import { getEnvVar } from "./read_env";
|
|
import log from "electron-log";
|
|
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
|
import { LanguageModelProvider } from "../ipc_types";
|
|
import { llmErrorStore } from "@/main/llm_error_store";
|
|
import { createDyadEngine } from "./llm_engine_provider";
|
|
import { findLanguageModel } from "./findLanguageModel";
|
|
import { LM_STUDIO_BASE_URL } from "./lm_studio_utils";
|
|
|
|
const dyadLocalEngine = process.env.DYAD_LOCAL_ENGINE;
|
|
const dyadGatewayUrl = process.env.DYAD_GATEWAY_URL;
|
|
|
|
const AUTO_MODELS = [
|
|
{
|
|
provider: "google",
|
|
name: "gemini-2.5-flash-preview-05-20",
|
|
},
|
|
{
|
|
provider: "anthropic",
|
|
name: "claude-3-7-sonnet-latest",
|
|
},
|
|
{
|
|
provider: "openai",
|
|
name: "gpt-4.1",
|
|
},
|
|
];
|
|
|
|
export interface ModelClient {
|
|
model: LanguageModelV1;
|
|
builtinProviderId?: string;
|
|
}
|
|
|
|
interface File {
|
|
path: string;
|
|
content: string;
|
|
}
|
|
|
|
const logger = log.scope("getModelClient");
|
|
export async function getModelClient(
|
|
model: LargeLanguageModel,
|
|
settings: UserSettings,
|
|
files?: File[],
|
|
): Promise<{
|
|
modelClient: ModelClient;
|
|
backupModelClients: ModelClient[];
|
|
isEngineEnabled?: boolean;
|
|
}> {
|
|
const allProviders = await getLanguageModelProviders();
|
|
|
|
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
|
|
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
|
|
if (model.provider === "auto") {
|
|
for (const autoModel of AUTO_MODELS) {
|
|
const providerInfo = allProviders.find(
|
|
(p) => p.id === autoModel.provider,
|
|
);
|
|
const envVarName = providerInfo?.envVarName;
|
|
|
|
const apiKey =
|
|
dyadApiKey ||
|
|
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
|
|
(envVarName ? getEnvVar(envVarName) : undefined);
|
|
|
|
if (apiKey) {
|
|
logger.log(
|
|
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
|
|
);
|
|
// Recursively call with the specific model found
|
|
return await getModelClient(
|
|
{
|
|
provider: autoModel.provider,
|
|
name: autoModel.name,
|
|
},
|
|
settings,
|
|
files,
|
|
);
|
|
}
|
|
}
|
|
// If no models have API keys, throw an error
|
|
throw new Error(
|
|
"No API keys available for any model supported by the 'auto' provider.",
|
|
);
|
|
}
|
|
|
|
// --- Handle specific provider ---
|
|
const providerConfig = allProviders.find((p) => p.id === model.provider);
|
|
|
|
if (!providerConfig) {
|
|
throw new Error(`Configuration not found for provider: ${model.provider}`);
|
|
}
|
|
|
|
// Handle Dyad Pro override
|
|
if (dyadApiKey && settings.enableDyadPro) {
|
|
// Check if the selected provider supports Dyad Pro (has a gateway prefix) OR
|
|
// we're using local engine.
|
|
// IMPORTANT: some providers like OpenAI have an empty string gateway prefix,
|
|
// so we do a nullish and not a truthy check here.
|
|
if (providerConfig.gatewayPrefix != null || dyadLocalEngine) {
|
|
const languageModel = await findLanguageModel(model);
|
|
const engineProMode =
|
|
settings.enableProSmartFilesContextMode ||
|
|
settings.enableProLazyEditsMode;
|
|
// Currently engine is only used for turbo edits.
|
|
const isEngineEnabled = Boolean(
|
|
engineProMode &&
|
|
languageModel?.type === "cloud" &&
|
|
languageModel?.supportsTurboEdits,
|
|
);
|
|
const provider = isEngineEnabled
|
|
? createDyadEngine({
|
|
apiKey: dyadApiKey,
|
|
baseURL: dyadLocalEngine ?? "https://engine.dyad.sh/v1",
|
|
dyadOptions: {
|
|
enableLazyEdits: settings.enableProLazyEditsMode,
|
|
enableSmartFilesContext: settings.enableProSmartFilesContextMode,
|
|
},
|
|
})
|
|
: createOpenAICompatible({
|
|
name: "dyad-gateway",
|
|
apiKey: dyadApiKey,
|
|
baseURL: dyadGatewayUrl ?? "https://llm-gateway.dyad.sh/v1",
|
|
});
|
|
|
|
logger.info(`Using Dyad Pro API key. engine_enabled=${isEngineEnabled}`);
|
|
// Do not use free variant (for openrouter).
|
|
const modelName = model.name.split(":free")[0];
|
|
const autoModelClient = {
|
|
model: provider(
|
|
`${providerConfig.gatewayPrefix || ""}${modelName}`,
|
|
engineProMode
|
|
? {
|
|
files,
|
|
}
|
|
: undefined,
|
|
),
|
|
builtinProviderId: "auto",
|
|
};
|
|
const googleSettings = settings.providerSettings?.google;
|
|
|
|
// Budget saver mode logic (all must be true):
|
|
// 1. Pro Saver Mode is enabled
|
|
// 2. Provider is Google
|
|
// 3. API Key is set
|
|
// 4. Has no recent errors
|
|
if (
|
|
settings.enableProSaverMode &&
|
|
providerConfig.id === "google" &&
|
|
googleSettings &&
|
|
googleSettings.apiKey?.value &&
|
|
llmErrorStore.modelHasNoRecentError({
|
|
model: model.name,
|
|
provider: providerConfig.id,
|
|
})
|
|
) {
|
|
return {
|
|
modelClient: getRegularModelClient(
|
|
{
|
|
provider: providerConfig.id,
|
|
name: model.name,
|
|
},
|
|
settings,
|
|
providerConfig,
|
|
).modelClient,
|
|
backupModelClients: [autoModelClient],
|
|
isEngineEnabled,
|
|
};
|
|
} else {
|
|
return {
|
|
modelClient: autoModelClient,
|
|
backupModelClients: [],
|
|
isEngineEnabled,
|
|
};
|
|
}
|
|
} else {
|
|
logger.warn(
|
|
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
|
|
);
|
|
// Fall through to regular provider logic if gateway prefix is missing
|
|
}
|
|
}
|
|
return getRegularModelClient(model, settings, providerConfig);
|
|
}
|
|
|
|
function getRegularModelClient(
|
|
model: LargeLanguageModel,
|
|
settings: UserSettings,
|
|
providerConfig: LanguageModelProvider,
|
|
) {
|
|
// Get API key for the specific provider
|
|
const apiKey =
|
|
settings.providerSettings?.[model.provider]?.apiKey?.value ||
|
|
(providerConfig.envVarName
|
|
? getEnvVar(providerConfig.envVarName)
|
|
: undefined);
|
|
|
|
const providerId = providerConfig.id;
|
|
// Create client based on provider ID or type
|
|
switch (providerId) {
|
|
case "openai": {
|
|
const provider = createOpenAI({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "anthropic": {
|
|
const provider = createAnthropic({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "google": {
|
|
const provider = createGoogle({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "openrouter": {
|
|
const provider = createOpenRouter({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "ollama": {
|
|
// Ollama typically runs locally and doesn't require an API key in the same way
|
|
const provider = createOllama({
|
|
baseURL: process.env.OLLAMA_HOST,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "lmstudio": {
|
|
// LM Studio uses OpenAI compatible API
|
|
const baseURL = providerConfig.apiBaseUrl || LM_STUDIO_BASE_URL + "/v1";
|
|
const provider = createOpenAICompatible({
|
|
name: "lmstudio",
|
|
baseURL,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
default: {
|
|
// Handle custom providers
|
|
if (providerConfig.type === "custom") {
|
|
if (!providerConfig.apiBaseUrl) {
|
|
throw new Error(
|
|
`Custom provider ${model.provider} is missing the API Base URL.`,
|
|
);
|
|
}
|
|
// Assume custom providers are OpenAI compatible for now
|
|
const provider = createOpenAICompatible({
|
|
name: providerConfig.id,
|
|
baseURL: providerConfig.apiBaseUrl,
|
|
apiKey,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
// If it's not a known ID and not type 'custom', it's unsupported
|
|
throw new Error(`Unsupported model provider: ${model.provider}`);
|
|
}
|
|
}
|
|
}
|