375 lines
12 KiB
TypeScript
375 lines
12 KiB
TypeScript
import { createOpenAI } from "@ai-sdk/openai";
|
|
import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
|
|
import { createAnthropic } from "@ai-sdk/anthropic";
|
|
import { createXai } from "@ai-sdk/xai";
|
|
import { createVertex as createGoogleVertex } from "@ai-sdk/google-vertex";
|
|
import { createAzure } from "@ai-sdk/azure";
|
|
import { LanguageModelV2 } from "@ai-sdk/provider";
|
|
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
|
|
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
|
import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock";
|
|
import type {
|
|
LargeLanguageModel,
|
|
UserSettings,
|
|
VertexProviderSetting,
|
|
AzureProviderSetting,
|
|
} from "../../lib/schemas";
|
|
import { getEnvVar } from "./read_env";
|
|
import log from "electron-log";
|
|
import { FREE_OPENROUTER_MODEL_NAMES } from "../shared/language_model_constants";
|
|
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
|
import { LanguageModelProvider } from "../ipc_types";
|
|
// import { createDyadEngine } from "./llm_engine_provider"; // Removed - Dyad Engine dependency
|
|
|
|
import { LM_STUDIO_BASE_URL } from "./lm_studio_utils";
|
|
import { createOllamaProvider } from "./ollama_provider";
|
|
import { getOllamaApiUrl } from "../handlers/local_model_ollama_handler";
|
|
import { createFallback } from "./fallback_ai_model";
|
|
|
|
// const dyadEngineUrl = process.env.DYAD_ENGINE_URL; // Removed - Dyad Engine dependency
|
|
|
|
const AUTO_MODELS = [
|
|
{
|
|
provider: "google",
|
|
name: "gemini-2.5-flash",
|
|
},
|
|
{
|
|
provider: "openrouter",
|
|
name: "qwen/qwen3-coder:free",
|
|
},
|
|
{
|
|
provider: "anthropic",
|
|
name: "claude-sonnet-4-20250514",
|
|
},
|
|
{
|
|
provider: "openai",
|
|
name: "gpt-4.1",
|
|
},
|
|
];
|
|
|
|
export interface ModelClient {
|
|
model: LanguageModelV2;
|
|
builtinProviderId?: string;
|
|
}
|
|
|
|
const logger = log.scope("getModelClient");
|
|
export async function getModelClient(
|
|
model: LargeLanguageModel,
|
|
settings: UserSettings,
|
|
// files?: File[],
|
|
): Promise<{
|
|
modelClient: ModelClient;
|
|
isEngineEnabled?: boolean;
|
|
isSmartContextEnabled?: boolean;
|
|
}> {
|
|
const allProviders = await getLanguageModelProviders();
|
|
|
|
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
|
|
|
|
// --- Handle specific provider ---
|
|
const providerConfig = allProviders.find((p) => p.id === model.provider);
|
|
|
|
if (!providerConfig) {
|
|
throw new Error(`Configuration not found for provider: ${model.provider}`);
|
|
}
|
|
|
|
// Dyad Pro functionality removed - eliminated Dyad Engine dependency
|
|
// All models now use direct provider connections
|
|
if (dyadApiKey && settings.enableDyadPro) {
|
|
logger.warn(
|
|
`Dyad Pro was enabled but has been disabled to remove Dyad API dependency. Falling back to direct provider connection.`,
|
|
);
|
|
// Fall through to regular provider logic
|
|
}
|
|
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
|
|
if (model.provider === "auto") {
|
|
if (model.name === "free") {
|
|
const openRouterProvider = allProviders.find(
|
|
(p) => p.id === "openrouter",
|
|
);
|
|
if (!openRouterProvider) {
|
|
throw new Error("OpenRouter provider not found");
|
|
}
|
|
return {
|
|
modelClient: {
|
|
model: createFallback({
|
|
models: FREE_OPENROUTER_MODEL_NAMES.map(
|
|
(name: string) =>
|
|
getRegularModelClient(
|
|
{ provider: "openrouter", name },
|
|
settings,
|
|
openRouterProvider,
|
|
).modelClient.model,
|
|
),
|
|
}),
|
|
builtinProviderId: "openrouter",
|
|
},
|
|
isEngineEnabled: false,
|
|
};
|
|
}
|
|
for (const autoModel of AUTO_MODELS) {
|
|
const providerInfo = allProviders.find(
|
|
(p) => p.id === autoModel.provider,
|
|
);
|
|
const envVarName = providerInfo?.envVarName;
|
|
|
|
const apiKey =
|
|
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
|
|
(envVarName ? getEnvVar(envVarName) : undefined);
|
|
|
|
if (apiKey) {
|
|
logger.log(
|
|
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
|
|
);
|
|
// Recursively call with the specific model found
|
|
return await getModelClient(
|
|
{
|
|
provider: autoModel.provider,
|
|
name: autoModel.name,
|
|
},
|
|
settings,
|
|
);
|
|
}
|
|
}
|
|
// If no models have API keys, throw an error
|
|
throw new Error(
|
|
"No API keys available for any model supported by the 'auto' provider.",
|
|
);
|
|
}
|
|
return getRegularModelClient(model, settings, providerConfig);
|
|
}
|
|
|
|
function getRegularModelClient(
|
|
model: LargeLanguageModel,
|
|
settings: UserSettings,
|
|
providerConfig: LanguageModelProvider,
|
|
): {
|
|
modelClient: ModelClient;
|
|
backupModelClients: ModelClient[];
|
|
} {
|
|
// Get API key for the specific provider
|
|
const apiKey =
|
|
settings.providerSettings?.[model.provider]?.apiKey?.value ||
|
|
(providerConfig.envVarName
|
|
? getEnvVar(providerConfig.envVarName)
|
|
: undefined);
|
|
|
|
const providerId = providerConfig.id;
|
|
// Create client based on provider ID or type
|
|
switch (providerId) {
|
|
case "openai": {
|
|
const provider = createOpenAI({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider.responses(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "anthropic": {
|
|
const provider = createAnthropic({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "xai": {
|
|
const provider = createXai({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "google": {
|
|
const provider = createGoogle({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "vertex": {
|
|
// Vertex uses Google service account credentials with project/location
|
|
const vertexSettings = settings.providerSettings?.[
|
|
model.provider
|
|
] as VertexProviderSetting;
|
|
const project = vertexSettings?.projectId;
|
|
const location = vertexSettings?.location;
|
|
const serviceAccountKey = vertexSettings?.serviceAccountKey?.value;
|
|
|
|
// Use a baseURL that does NOT pin to publishers/google so that
|
|
// full publisher model IDs (e.g. publishers/deepseek-ai/models/...) work.
|
|
const regionHost = `${location === "global" ? "" : `${location}-`}aiplatform.googleapis.com`;
|
|
const baseURL = `https://${regionHost}/v1/projects/${project}/locations/${location}`;
|
|
const provider = createGoogleVertex({
|
|
project,
|
|
location,
|
|
baseURL,
|
|
googleAuthOptions: serviceAccountKey
|
|
? {
|
|
// Expecting the user to paste the full JSON of the service account key
|
|
credentials: JSON.parse(serviceAccountKey),
|
|
}
|
|
: undefined,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
// For built-in Google models on Vertex, the path must include
|
|
// publishers/google/models/<model>. For partner MaaS models the
|
|
// full publisher path is already included.
|
|
model: provider(
|
|
model.name.includes("/")
|
|
? model.name
|
|
: `publishers/google/models/${model.name}`,
|
|
),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "openrouter": {
|
|
const provider = createOpenRouter({ apiKey });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "azure": {
|
|
// Check if we're in e2e testing mode
|
|
const testAzureBaseUrl = getEnvVar("TEST_AZURE_BASE_URL");
|
|
|
|
if (testAzureBaseUrl) {
|
|
// Use fake server for e2e testing
|
|
logger.info(`Using test Azure base URL: ${testAzureBaseUrl}`);
|
|
const provider = createOpenAICompatible({
|
|
name: "azure-test",
|
|
baseURL: testAzureBaseUrl,
|
|
apiKey: "fake-api-key-for-testing",
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
|
|
const azureSettings = settings.providerSettings?.azure as
|
|
| AzureProviderSetting
|
|
| undefined;
|
|
const azureApiKeyFromSettings = (
|
|
azureSettings?.apiKey?.value ?? ""
|
|
).trim();
|
|
const azureResourceNameFromSettings = (
|
|
azureSettings?.resourceName ?? ""
|
|
).trim();
|
|
const envResourceName = (getEnvVar("AZURE_RESOURCE_NAME") ?? "").trim();
|
|
const envAzureApiKey = (getEnvVar("AZURE_API_KEY") ?? "").trim();
|
|
|
|
const resourceName = azureResourceNameFromSettings || envResourceName;
|
|
const azureApiKey = azureApiKeyFromSettings || envAzureApiKey;
|
|
|
|
if (!resourceName) {
|
|
throw new Error(
|
|
"Azure OpenAI resource name is required. Provide it in Settings or set the AZURE_RESOURCE_NAME environment variable.",
|
|
);
|
|
}
|
|
|
|
if (!azureApiKey) {
|
|
throw new Error(
|
|
"Azure OpenAI API key is required. Provide it in Settings or set the AZURE_API_KEY environment variable.",
|
|
);
|
|
}
|
|
|
|
const provider = createAzure({
|
|
resourceName,
|
|
apiKey: azureApiKey,
|
|
});
|
|
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "ollama": {
|
|
const provider = createOllamaProvider({ baseURL: getOllamaApiUrl() });
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "lmstudio": {
|
|
// LM Studio uses OpenAI compatible API
|
|
const baseURL = providerConfig.apiBaseUrl || LM_STUDIO_BASE_URL + "/v1";
|
|
const provider = createOpenAICompatible({
|
|
name: "lmstudio",
|
|
baseURL,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
case "bedrock": {
|
|
// AWS Bedrock supports API key authentication using AWS_BEARER_TOKEN_BEDROCK
|
|
// See: https://sdk.vercel.ai/providers/ai-sdk-providers/amazon-bedrock#api-key-authentication
|
|
const provider = createAmazonBedrock({
|
|
apiKey: apiKey,
|
|
region: getEnvVar("AWS_REGION") || "us-east-1",
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
builtinProviderId: providerId,
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
default: {
|
|
// Handle custom providers
|
|
if (providerConfig.type === "custom") {
|
|
if (!providerConfig.apiBaseUrl) {
|
|
throw new Error(
|
|
`Custom provider ${model.provider} is missing the API Base URL.`,
|
|
);
|
|
}
|
|
// Assume custom providers are OpenAI compatible for now
|
|
const provider = createOpenAICompatible({
|
|
name: providerConfig.id,
|
|
baseURL: providerConfig.apiBaseUrl,
|
|
apiKey,
|
|
});
|
|
return {
|
|
modelClient: {
|
|
model: provider(model.name),
|
|
},
|
|
backupModelClients: [],
|
|
};
|
|
}
|
|
// If it's not a known ID and not type 'custom', it's unsupported
|
|
throw new Error(`Unsupported model provider: ${model.provider}`);
|
|
}
|
|
}
|
|
}
|