Prep for custom models: support reading custom providers (#131)

This commit is contained in:
Will Chen
2025-05-12 14:52:48 -07:00
committed by GitHub
parent 79a2b5a906
commit cd7eaa8ece
23 changed files with 901 additions and 173 deletions

View File

@@ -22,7 +22,6 @@ import {
killProcess,
removeAppIfCurrentProcess,
} from "../utils/process_manager";
import { ALLOWED_ENV_VARS } from "../../constants/models";
import { getEnvVar } from "../utils/read_env";
import { readSettings } from "../../main/settings";
@@ -33,6 +32,7 @@ import util from "util";
import log from "electron-log";
import { getSupabaseProjectName } from "../../supabase_admin/supabase_management_client";
import { createLoggedHandler } from "./safe_handle";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
const logger = log.scope("app_handlers");
const handle = createLoggedHandler(logger);
@@ -291,8 +291,11 @@ export function registerAppHandlers() {
// Do NOT use handle for this, it contains sensitive information.
ipcMain.handle("get-env-vars", async () => {
const envVars: Record<string, string | undefined> = {};
for (const key of ALLOWED_ENV_VARS) {
envVars[key] = getEnvVar(key);
const providers = await getLanguageModelProviders();
for (const provider of providers) {
if (provider.envVarName) {
envVars[provider.envVarName] = getEnvVar(provider.envVarName);
}
}
return envVars;
});

View File

@@ -212,7 +212,10 @@ export function registerChatStreamHandlers() {
} else {
// Normal AI processing for non-test prompts
const settings = readSettings();
const modelClient = getModelClient(settings.selectedModel, settings);
const modelClient = await getModelClient(
settings.selectedModel,
settings,
);
// Extract codebase information if app is associated with the chat
let codebaseInfo = "";

View File

@@ -0,0 +1,16 @@
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import { createLoggedHandler } from "./safe_handle";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
const logger = log.scope("language_model_handlers");
const handle = createLoggedHandler(logger);
/**
 * Registers the IPC handlers for language-model-related queries.
 * Currently exposes a single channel that returns the merged list of
 * available providers (hardcoded cloud providers plus custom DB rows).
 */
export function registerLanguageModelHandlers() {
  const listProviders = async (): Promise<LanguageModelProvider[]> =>
    getLanguageModelProviders();

  handle("get-language-model-providers", listProviders);
}

View File

@@ -21,6 +21,7 @@ import type {
TokenCountResult,
ChatLogsData,
BranchResult,
LanguageModelProvider,
} from "./ipc_types";
import type { ProposalResult } from "@/lib/schemas";
import { showError } from "@/lib/toast";
@@ -724,13 +725,11 @@ export class IpcClient {
// Get system platform (win32, darwin, linux)
public async getSystemPlatform(): Promise<string> {
try {
const platform = await this.ipcRenderer.invoke("window:get-platform");
return platform;
} catch (error) {
showError(error);
throw error;
}
return this.ipcRenderer.invoke("get-system-platform");
}
public async getLanguageModelProviders(): Promise<LanguageModelProvider[]> {
return this.ipcRenderer.invoke("get-language-model-providers");
}
// --- End window control methods ---

View File

@@ -14,6 +14,8 @@ import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
import { registerWindowHandlers } from "./handlers/window_handlers";
import { registerUploadHandlers } from "./handlers/upload_handlers";
import { registerVersionHandlers } from "./handlers/version_handlers";
import { registerLanguageModelHandlers } from "./handlers/language_model_handlers";
export function registerIpcHandlers() {
// Register all IPC handlers by category
registerAppHandlers();
@@ -32,4 +34,5 @@ export function registerIpcHandlers() {
registerWindowHandlers();
registerUploadHandlers();
registerVersionHandlers();
registerLanguageModelHandlers();
}

View File

@@ -132,3 +132,14 @@ export interface ChatLogsData {
chat: Chat;
codebase: string;
}
/**
 * A language model provider as exposed to the renderer over IPC.
 * Cloud providers come from hardcoded constants; custom providers are
 * user-defined rows read from the local database.
 */
export interface LanguageModelProvider {
  /** Unique provider identifier (e.g. "openai", or a custom provider's id). */
  id: string;
  /** Human-readable display name. */
  name: string;
  /** Whether the provider offers a free tier. NOTE(review): appears to be set only for hardcoded providers — confirm. */
  hasFreeTier?: boolean;
  /** URL where the user can obtain an API key. */
  websiteUrl?: string;
  /** Prefix prepended to model names when routing through the Dyad Pro LLM gateway. */
  gatewayPrefix?: string;
  /** Name of the environment variable that may hold this provider's API key. */
  envVarName?: string;
  /** Base URL of the provider's API; required for custom providers. */
  apiBaseUrl?: string;
  /** Origin/kind of the provider definition ("local" presumably covers on-device backends like Ollama/LM Studio — verify). */
  type: "custom" | "local" | "cloud";
}

View File

@@ -0,0 +1,131 @@
import { db } from "@/db";
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
import { RegularModelProvider } from "@/constants/models";
import type { LanguageModelProvider } from "@/ipc/ipc_types";
/**
 * Maps a built-in provider id to the environment variable that may hold
 * its API key. Used as a fallback when no key is stored in settings.
 */
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
  openai: "OPENAI_API_KEY",
  anthropic: "ANTHROPIC_API_KEY",
  // Google's key is read from the Gemini-named variable.
  google: "GEMINI_API_KEY",
  openrouter: "OPENROUTER_API_KEY",
};
/**
 * Hardcoded metadata for the built-in providers, keyed by provider id.
 * `gatewayPrefix` is prepended to the model name when requests are routed
 * through the Dyad Pro LLM gateway.
 */
export const PROVIDERS: Record<
  RegularModelProvider,
  {
    displayName: string;
    hasFreeTier?: boolean;
    websiteUrl?: string;
    gatewayPrefix: string;
  }
> = {
  openai: {
    displayName: "OpenAI",
    hasFreeTier: false,
    websiteUrl: "https://platform.openai.com/api-keys",
    // OpenAI models need no prefix on the gateway.
    gatewayPrefix: "",
  },
  anthropic: {
    displayName: "Anthropic",
    hasFreeTier: false,
    websiteUrl: "https://console.anthropic.com/settings/keys",
    gatewayPrefix: "anthropic/",
  },
  google: {
    displayName: "Google",
    hasFreeTier: true,
    websiteUrl: "https://aistudio.google.com/app/apikey",
    gatewayPrefix: "gemini/",
  },
  openrouter: {
    displayName: "OpenRouter",
    hasFreeTier: true,
    websiteUrl: "https://openrouter.ai/settings/keys",
    gatewayPrefix: "openrouter/",
  },
  // "auto" is Dyad's automatic model selection (tries AUTO_MODELS in order).
  auto: {
    displayName: "Dyad",
    websiteUrl: "https://academy.dyad.sh/settings",
    gatewayPrefix: "",
  },
};
/**
* Fetches language model providers from both the database (custom) and hardcoded constants (cloud),
* merging them with custom providers taking precedence.
* @returns A promise that resolves to an array of LanguageModelProvider objects.
*/
export async function getLanguageModelProviders(): Promise<
LanguageModelProvider[]
> {
// Fetch custom providers from the database
const customProvidersDb = await db
.select()
.from(languageModelProvidersSchema);
const customProvidersMap = new Map<string, LanguageModelProvider>();
for (const cp of customProvidersDb) {
customProvidersMap.set(cp.id, {
id: cp.id,
name: cp.name,
apiBaseUrl: cp.api_base_url,
envVarName: cp.env_var_name ?? undefined,
type: "custom",
// hasFreeTier, websiteUrl, gatewayPrefix are not in the custom DB schema
// They will be undefined unless overridden by hardcoded values if IDs match
});
}
// Get hardcoded cloud providers
const hardcodedProviders: LanguageModelProvider[] = [];
for (const providerKey in PROVIDERS) {
if (Object.prototype.hasOwnProperty.call(PROVIDERS, providerKey)) {
// Ensure providerKey is a key of PROVIDERS
const key = providerKey as keyof typeof PROVIDERS;
const providerDetails = PROVIDERS[key];
if (providerDetails) {
// Ensure providerDetails is not undefined
hardcodedProviders.push({
id: key,
name: providerDetails.displayName,
hasFreeTier: providerDetails.hasFreeTier,
websiteUrl: providerDetails.websiteUrl,
gatewayPrefix: providerDetails.gatewayPrefix,
envVarName: PROVIDER_TO_ENV_VAR[key] ?? undefined,
type: "cloud",
// apiBaseUrl is not directly in PROVIDERS
});
}
}
}
// Merge lists: custom providers take precedence
const mergedProvidersMap = new Map<string, LanguageModelProvider>();
// Add all hardcoded providers first
for (const hp of hardcodedProviders) {
mergedProvidersMap.set(hp.id, hp);
}
// Add/overwrite with custom providers from DB
for (const [id, cp] of customProvidersMap) {
const existingProvider = mergedProvidersMap.get(id);
if (existingProvider) {
// If exists, merge. Custom fields take precedence.
mergedProvidersMap.set(id, {
...existingProvider, // start with hardcoded
...cp, // override with custom where defined
id: cp.id, // ensure custom id is used
name: cp.name, // ensure custom name is used
type: "custom", // explicitly set type to custom
apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl,
envVarName: cp.envVarName ?? existingProvider.envVarName,
});
} else {
// If it doesn't exist in hardcoded, just add the custom one
mergedProvidersMap.set(id, cp);
}
}
return Array.from(mergedProvidersMap.values());
}

View File

@@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { createOllama } from "ollama-ai-provider";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import {
PROVIDER_TO_ENV_VAR,
AUTO_MODELS,
PROVIDERS,
MODEL_OPTIONS,
} from "../../constants/models";
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
import { getEnvVar } from "./read_env";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
const logger = log.scope("getModelClient");
export function getModelClient(
export async function getModelClient(
model: LargeLanguageModel,
settings: UserSettings,
) {
const allProviders = await getLanguageModelProviders();
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
if (model.provider === "auto") {
// Try each model in AUTO_MODELS in order until finding one with an API key
for (const autoModel of AUTO_MODELS) {
const providerInfo = allProviders.find(
(p) => p.id === autoModel.provider,
);
const envVarName = providerInfo?.envVarName;
const apiKey =
dyadApiKey ||
settings.providerSettings?.[autoModel.provider]?.apiKey ||
getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]);
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
(envVarName ? getEnvVar(envVarName) : undefined);
if (apiKey) {
logger.log(
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
);
// Use the first model that has an API key
return getModelClient(
// Recursively call with the specific model found
return await getModelClient(
{
provider: autoModel.provider,
name: autoModel.name,
@@ -43,27 +45,48 @@ export function getModelClient(
);
}
}
// If no models have API keys, throw an error
throw new Error("No API keys available for any model in AUTO_MODELS");
throw new Error(
"No API keys available for any model supported by the 'auto' provider.",
);
}
// --- Handle specific provider ---
const providerConfig = allProviders.find((p) => p.id === model.provider);
if (!providerConfig) {
throw new Error(`Configuration not found for provider: ${model.provider}`);
}
// Handle Dyad Pro override
if (dyadApiKey && settings.enableDyadPro) {
const provider = createOpenAI({
apiKey: dyadApiKey,
baseURL: "https://llm-gateway.dyad.sh/v1",
});
const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS];
logger.info("Using Dyad Pro API key");
// Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0];
return provider(`${providerInfo.gatewayPrefix}${modelName}`);
// Check if the selected provider supports Dyad Pro (has a gateway prefix)
if (providerConfig.gatewayPrefix) {
const provider = createOpenAI({
apiKey: dyadApiKey,
baseURL: "https://llm-gateway.dyad.sh/v1",
});
logger.info("Using Dyad Pro API key via Gateway");
// Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0];
return provider(`${providerConfig.gatewayPrefix}${modelName}`);
} else {
logger.warn(
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
);
// Fall through to regular provider logic if gateway prefix is missing
}
}
// Get API key for the specific provider
const apiKey =
settings.providerSettings?.[model.provider]?.apiKey?.value ||
getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]);
switch (model.provider) {
(providerConfig.envVarName
? getEnvVar(providerConfig.envVarName)
: undefined);
// Create client based on provider ID or type
switch (providerConfig.id) {
case "openai": {
const provider = createOpenAI({ apiKey });
return provider(model.name);
@@ -81,18 +104,38 @@ export function getModelClient(
return provider(model.name);
}
case "ollama": {
const provider = createOllama();
// Ollama typically runs locally and doesn't require an API key in the same way
const provider = createOllama({
baseURL: providerConfig.apiBaseUrl,
});
return provider(model.name);
}
case "lmstudio": {
// Using LM Studio's OpenAI compatible API
const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL
const provider = createOpenAICompatible({ name: "lmstudio", baseURL });
// LM Studio uses OpenAI compatible API
const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1";
const provider = createOpenAICompatible({
name: "lmstudio",
baseURL,
});
return provider(model.name);
}
default: {
// Ensure exhaustive check if more providers are added
const _exhaustiveCheck: never = model.provider;
// Handle custom providers
if (providerConfig.type === "custom") {
if (!providerConfig.apiBaseUrl) {
throw new Error(
`Custom provider ${model.provider} is missing the API Base URL.`,
);
}
// Assume custom providers are OpenAI compatible for now
const provider = createOpenAICompatible({
name: providerConfig.id,
baseURL: providerConfig.apiBaseUrl,
apiKey: apiKey,
});
return provider(model.name);
}
// If it's not a known ID and not type 'custom', it's unsupported
throw new Error(`Unsupported model provider: ${model.provider}`);
}
}