Prep for custom models: support reading custom providers (#131)
@@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
 import { createOllama } from "ollama-ai-provider";
 import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
 import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
-import {
-  PROVIDER_TO_ENV_VAR,
-  AUTO_MODELS,
-  PROVIDERS,
-  MODEL_OPTIONS,
-} from "../../constants/models";
+import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
 import { getEnvVar } from "./read_env";
 import log from "electron-log";
+import { getLanguageModelProviders } from "../shared/language_model_helpers";
 
 const logger = log.scope("getModelClient");
 
-export function getModelClient(
+export async function getModelClient(
   model: LargeLanguageModel,
   settings: UserSettings,
 ) {
+  const allProviders = await getLanguageModelProviders();
+
   const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
   // Handle 'auto' provider by trying each model in AUTO_MODELS until one works
   if (model.provider === "auto") {
     // Try each model in AUTO_MODELS in order until finding one with an API key
     for (const autoModel of AUTO_MODELS) {
+      const providerInfo = allProviders.find(
+        (p) => p.id === autoModel.provider,
+      );
+      const envVarName = providerInfo?.envVarName;
+
       const apiKey =
         dyadApiKey ||
-        settings.providerSettings?.[autoModel.provider]?.apiKey ||
-        getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]);
+        settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
+        (envVarName ? getEnvVar(envVarName) : undefined);
 
       if (apiKey) {
         logger.log(
           `Using provider: ${autoModel.provider} model: ${autoModel.name}`,
         );
-        // Use the first model that has an API key
-        return getModelClient(
+        // Recursively call with the specific model found
+        return await getModelClient(
           {
             provider: autoModel.provider,
             name: autoModel.name,
@@ -43,27 +45,48 @@ export function getModelClient(
         );
       }
     }
 
     // If no models have API keys, throw an error
-    throw new Error("No API keys available for any model in AUTO_MODELS");
+    throw new Error(
+      "No API keys available for any model supported by the 'auto' provider.",
+    );
   }
 
+  // --- Handle specific provider ---
+  const providerConfig = allProviders.find((p) => p.id === model.provider);
+
+  if (!providerConfig) {
+    throw new Error(`Configuration not found for provider: ${model.provider}`);
+  }
+
   // Handle Dyad Pro override
   if (dyadApiKey && settings.enableDyadPro) {
-    const provider = createOpenAI({
-      apiKey: dyadApiKey,
-      baseURL: "https://llm-gateway.dyad.sh/v1",
-    });
-    const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS];
-    logger.info("Using Dyad Pro API key");
-    // Do not use free variant (for openrouter).
-    const modelName = model.name.split(":free")[0];
-    return provider(`${providerInfo.gatewayPrefix}${modelName}`);
+    // Check if the selected provider supports Dyad Pro (has a gateway prefix)
+    if (providerConfig.gatewayPrefix) {
+      const provider = createOpenAI({
+        apiKey: dyadApiKey,
+        baseURL: "https://llm-gateway.dyad.sh/v1",
+      });
+      logger.info("Using Dyad Pro API key via Gateway");
+      // Do not use free variant (for openrouter).
+      const modelName = model.name.split(":free")[0];
+      return provider(`${providerConfig.gatewayPrefix}${modelName}`);
+    } else {
+      logger.warn(
+        `Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
+      );
+      // Fall through to regular provider logic if gateway prefix is missing
+    }
   }
 
+  // Get API key for the specific provider
   const apiKey =
     settings.providerSettings?.[model.provider]?.apiKey?.value ||
-    getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]);
-  switch (model.provider) {
+    (providerConfig.envVarName
+      ? getEnvVar(providerConfig.envVarName)
+      : undefined);
+
+  // Create client based on provider ID or type
+  switch (providerConfig.id) {
     case "openai": {
       const provider = createOpenAI({ apiKey });
       return provider(model.name);
@@ -81,18 +104,38 @@ export function getModelClient(
       return provider(model.name);
     }
     case "ollama": {
-      const provider = createOllama();
+      // Ollama typically runs locally and doesn't require an API key in the same way
+      const provider = createOllama({
+        baseURL: providerConfig.apiBaseUrl,
+      });
       return provider(model.name);
     }
     case "lmstudio": {
-      // Using LM Studio's OpenAI compatible API
-      const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL
-      const provider = createOpenAICompatible({ name: "lmstudio", baseURL });
+      // LM Studio uses OpenAI compatible API
+      const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1";
+      const provider = createOpenAICompatible({
+        name: "lmstudio",
+        baseURL,
+      });
       return provider(model.name);
     }
     default: {
-      // Ensure exhaustive check if more providers are added
-      const _exhaustiveCheck: never = model.provider;
+      // Handle custom providers
+      if (providerConfig.type === "custom") {
+        if (!providerConfig.apiBaseUrl) {
+          throw new Error(
+            `Custom provider ${model.provider} is missing the API Base URL.`,
+          );
+        }
+        // Assume custom providers are OpenAI compatible for now
+        const provider = createOpenAICompatible({
+          name: providerConfig.id,
+          baseURL: providerConfig.apiBaseUrl,
+          apiKey: apiKey,
+        });
+        return provider(model.name);
+      }
+      // If it's not a known ID and not type 'custom', it's unsupported
      throw new Error(`Unsupported model provider: ${model.provider}`);
    }
  }
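
As a reading aid, here is a minimal sketch of how the new default branch could be exercised once a custom provider is read from getLanguageModelProviders(). The field names (id, type, apiBaseUrl, envVarName) mirror what the diff reads off providerConfig; the concrete values, the provider id, the module path, and the settings layout below are hypothetical and only for illustration.

import type { UserSettings } from "../../lib/schemas";
import { getModelClient } from "./get_model_client"; // module path assumed; not shown in the diff

// Hypothetical custom provider entry, shaped like the objects the diff reads
// from getLanguageModelProviders() (id / type / apiBaseUrl / envVarName).
const customProvider = {
  id: "my-local-gateway", // hypothetical provider id
  type: "custom" as const, // routes it into the new default-branch handling
  apiBaseUrl: "http://localhost:8080/v1", // required, otherwise getModelClient throws
  envVarName: "MY_GATEWAY_API_KEY", // optional env-var fallback for the API key
};

// Hypothetical settings carrying an API key for that provider; the real
// UserSettings type has more fields, hence the cast.
const settings = {
  providerSettings: {
    [customProvider.id]: { apiKey: { value: "sk-example" } },
  },
} as unknown as UserSettings;

// The custom provider falls through to the default case and comes back as an
// OpenAI-compatible client pointed at apiBaseUrl.
const client = await getModelClient(
  { provider: customProvider.id, name: "my-model" },
  settings,
);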
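The Dyad Pro branch also maps model names onto the gateway namespace: it strips OpenRouter's ":free" variant suffix and prepends the provider's gatewayPrefix. A small standalone sketch of that mapping follows; the prefix and model name are illustrative, not values taken from the repository.

// Mirrors the two lines in the Dyad Pro branch:
//   const modelName = model.name.split(":free")[0];
//   return provider(`${providerConfig.gatewayPrefix}${modelName}`);
function toGatewayModelName(gatewayPrefix: string, modelName: string): string {
  const withoutFreeVariant = modelName.split(":free")[0]; // drop the ":free" suffix if present
  return `${gatewayPrefix}${withoutFreeVariant}`;
}

// Illustrative example:
toGatewayModelName("openrouter/", "deepseek/deepseek-chat:free");
// => "openrouter/deepseek/deepseek-chat"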