Prep for custom models: support reading custom providers (#131)

This commit is contained in:
Will Chen
2025-05-12 14:52:48 -07:00
committed by GitHub
parent 79a2b5a906
commit cd7eaa8ece
23 changed files with 901 additions and 173 deletions

View File

@@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { createOllama } from "ollama-ai-provider";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import {
PROVIDER_TO_ENV_VAR,
AUTO_MODELS,
PROVIDERS,
MODEL_OPTIONS,
} from "../../constants/models";
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
import { getEnvVar } from "./read_env";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
const logger = log.scope("getModelClient");
export function getModelClient(
export async function getModelClient(
model: LargeLanguageModel,
settings: UserSettings,
) {
const allProviders = await getLanguageModelProviders();
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
if (model.provider === "auto") {
// Try each model in AUTO_MODELS in order until finding one with an API key
for (const autoModel of AUTO_MODELS) {
const providerInfo = allProviders.find(
(p) => p.id === autoModel.provider,
);
const envVarName = providerInfo?.envVarName;
const apiKey =
dyadApiKey ||
settings.providerSettings?.[autoModel.provider]?.apiKey ||
getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]);
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
(envVarName ? getEnvVar(envVarName) : undefined);
if (apiKey) {
logger.log(
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
);
// Use the first model that has an API key
return getModelClient(
// Recursively call with the specific model found
return await getModelClient(
{
provider: autoModel.provider,
name: autoModel.name,
@@ -43,27 +45,48 @@ export function getModelClient(
);
}
}
// If no models have API keys, throw an error
throw new Error("No API keys available for any model in AUTO_MODELS");
throw new Error(
"No API keys available for any model supported by the 'auto' provider.",
);
}
// --- Handle specific provider ---
const providerConfig = allProviders.find((p) => p.id === model.provider);
if (!providerConfig) {
throw new Error(`Configuration not found for provider: ${model.provider}`);
}
// Handle Dyad Pro override
if (dyadApiKey && settings.enableDyadPro) {
const provider = createOpenAI({
apiKey: dyadApiKey,
baseURL: "https://llm-gateway.dyad.sh/v1",
});
const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS];
logger.info("Using Dyad Pro API key");
// Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0];
return provider(`${providerInfo.gatewayPrefix}${modelName}`);
// Check if the selected provider supports Dyad Pro (has a gateway prefix)
if (providerConfig.gatewayPrefix) {
const provider = createOpenAI({
apiKey: dyadApiKey,
baseURL: "https://llm-gateway.dyad.sh/v1",
});
logger.info("Using Dyad Pro API key via Gateway");
// Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0];
return provider(`${providerConfig.gatewayPrefix}${modelName}`);
} else {
logger.warn(
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
);
// Fall through to regular provider logic if gateway prefix is missing
}
}
// Get API key for the specific provider
const apiKey =
settings.providerSettings?.[model.provider]?.apiKey?.value ||
getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]);
switch (model.provider) {
(providerConfig.envVarName
? getEnvVar(providerConfig.envVarName)
: undefined);
// Create client based on provider ID or type
switch (providerConfig.id) {
case "openai": {
const provider = createOpenAI({ apiKey });
return provider(model.name);
@@ -81,18 +104,38 @@ export function getModelClient(
return provider(model.name);
}
case "ollama": {
const provider = createOllama();
// Ollama typically runs locally and doesn't require an API key in the same way
const provider = createOllama({
baseURL: providerConfig.apiBaseUrl,
});
return provider(model.name);
}
case "lmstudio": {
// Using LM Studio's OpenAI compatible API
const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL
const provider = createOpenAICompatible({ name: "lmstudio", baseURL });
// LM Studio uses OpenAI compatible API
const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1";
const provider = createOpenAICompatible({
name: "lmstudio",
baseURL,
});
return provider(model.name);
}
default: {
// Ensure exhaustive check if more providers are added
const _exhaustiveCheck: never = model.provider;
// Handle custom providers
if (providerConfig.type === "custom") {
if (!providerConfig.apiBaseUrl) {
throw new Error(
`Custom provider ${model.provider} is missing the API Base URL.`,
);
}
// Assume custom providers are OpenAI compatible for now
const provider = createOpenAICompatible({
name: providerConfig.id,
baseURL: providerConfig.apiBaseUrl,
apiKey: apiKey,
});
return provider(model.name);
}
// If it's not a known ID and not type 'custom', it's unsupported
throw new Error(`Unsupported model provider: ${model.provider}`);
}
}