Prep for custom models: support reading custom providers (#131)
This commit is contained in:
@@ -22,7 +22,6 @@ import {
|
||||
killProcess,
|
||||
removeAppIfCurrentProcess,
|
||||
} from "../utils/process_manager";
|
||||
import { ALLOWED_ENV_VARS } from "../../constants/models";
|
||||
import { getEnvVar } from "../utils/read_env";
|
||||
import { readSettings } from "../../main/settings";
|
||||
|
||||
@@ -33,6 +32,7 @@ import util from "util";
|
||||
import log from "electron-log";
|
||||
import { getSupabaseProjectName } from "../../supabase_admin/supabase_management_client";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("app_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
@@ -291,8 +291,11 @@ export function registerAppHandlers() {
|
||||
// Do NOT use handle for this, it contains sensitive information.
|
||||
ipcMain.handle("get-env-vars", async () => {
|
||||
const envVars: Record<string, string | undefined> = {};
|
||||
for (const key of ALLOWED_ENV_VARS) {
|
||||
envVars[key] = getEnvVar(key);
|
||||
const providers = await getLanguageModelProviders();
|
||||
for (const provider of providers) {
|
||||
if (provider.envVarName) {
|
||||
envVars[provider.envVarName] = getEnvVar(provider.envVarName);
|
||||
}
|
||||
}
|
||||
return envVars;
|
||||
});
|
||||
|
||||
@@ -212,7 +212,10 @@ export function registerChatStreamHandlers() {
|
||||
} else {
|
||||
// Normal AI processing for non-test prompts
|
||||
const settings = readSettings();
|
||||
const modelClient = getModelClient(settings.selectedModel, settings);
|
||||
const modelClient = await getModelClient(
|
||||
settings.selectedModel,
|
||||
settings,
|
||||
);
|
||||
|
||||
// Extract codebase information if app is associated with the chat
|
||||
let codebaseInfo = "";
|
||||
|
||||
16
src/ipc/handlers/language_model_handlers.ts
Normal file
16
src/ipc/handlers/language_model_handlers.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("language_model_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
export function registerLanguageModelHandlers() {
|
||||
handle(
|
||||
"get-language-model-providers",
|
||||
async (): Promise<LanguageModelProvider[]> => {
|
||||
return getLanguageModelProviders();
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import type {
|
||||
TokenCountResult,
|
||||
ChatLogsData,
|
||||
BranchResult,
|
||||
LanguageModelProvider,
|
||||
} from "./ipc_types";
|
||||
import type { ProposalResult } from "@/lib/schemas";
|
||||
import { showError } from "@/lib/toast";
|
||||
@@ -724,13 +725,11 @@ export class IpcClient {
|
||||
|
||||
// Get system platform (win32, darwin, linux)
|
||||
public async getSystemPlatform(): Promise<string> {
|
||||
try {
|
||||
const platform = await this.ipcRenderer.invoke("window:get-platform");
|
||||
return platform;
|
||||
} catch (error) {
|
||||
showError(error);
|
||||
throw error;
|
||||
}
|
||||
return this.ipcRenderer.invoke("get-system-platform");
|
||||
}
|
||||
|
||||
public async getLanguageModelProviders(): Promise<LanguageModelProvider[]> {
|
||||
return this.ipcRenderer.invoke("get-language-model-providers");
|
||||
}
|
||||
|
||||
// --- End window control methods ---
|
||||
|
||||
@@ -14,6 +14,8 @@ import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
|
||||
import { registerWindowHandlers } from "./handlers/window_handlers";
|
||||
import { registerUploadHandlers } from "./handlers/upload_handlers";
|
||||
import { registerVersionHandlers } from "./handlers/version_handlers";
|
||||
import { registerLanguageModelHandlers } from "./handlers/language_model_handlers";
|
||||
|
||||
export function registerIpcHandlers() {
|
||||
// Register all IPC handlers by category
|
||||
registerAppHandlers();
|
||||
@@ -32,4 +34,5 @@ export function registerIpcHandlers() {
|
||||
registerWindowHandlers();
|
||||
registerUploadHandlers();
|
||||
registerVersionHandlers();
|
||||
registerLanguageModelHandlers();
|
||||
}
|
||||
|
||||
@@ -132,3 +132,14 @@ export interface ChatLogsData {
|
||||
chat: Chat;
|
||||
codebase: string;
|
||||
}
|
||||
|
||||
export interface LanguageModelProvider {
|
||||
id: string;
|
||||
name: string;
|
||||
hasFreeTier?: boolean;
|
||||
websiteUrl?: string;
|
||||
gatewayPrefix?: string;
|
||||
envVarName?: string;
|
||||
apiBaseUrl?: string;
|
||||
type: "custom" | "local" | "cloud";
|
||||
}
|
||||
|
||||
131
src/ipc/shared/language_model_helpers.ts
Normal file
131
src/ipc/shared/language_model_helpers.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { db } from "@/db";
|
||||
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
|
||||
import { RegularModelProvider } from "@/constants/models";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
|
||||
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
};
|
||||
|
||||
export const PROVIDERS: Record<
|
||||
RegularModelProvider,
|
||||
{
|
||||
displayName: string;
|
||||
hasFreeTier?: boolean;
|
||||
websiteUrl?: string;
|
||||
gatewayPrefix: string;
|
||||
}
|
||||
> = {
|
||||
openai: {
|
||||
displayName: "OpenAI",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://platform.openai.com/api-keys",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
anthropic: {
|
||||
displayName: "Anthropic",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://console.anthropic.com/settings/keys",
|
||||
gatewayPrefix: "anthropic/",
|
||||
},
|
||||
google: {
|
||||
displayName: "Google",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://aistudio.google.com/app/apikey",
|
||||
gatewayPrefix: "gemini/",
|
||||
},
|
||||
openrouter: {
|
||||
displayName: "OpenRouter",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://openrouter.ai/settings/keys",
|
||||
gatewayPrefix: "openrouter/",
|
||||
},
|
||||
auto: {
|
||||
displayName: "Dyad",
|
||||
websiteUrl: "https://academy.dyad.sh/settings",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches language model providers from both the database (custom) and hardcoded constants (cloud),
|
||||
* merging them with custom providers taking precedence.
|
||||
* @returns A promise that resolves to an array of LanguageModelProvider objects.
|
||||
*/
|
||||
export async function getLanguageModelProviders(): Promise<
|
||||
LanguageModelProvider[]
|
||||
> {
|
||||
// Fetch custom providers from the database
|
||||
const customProvidersDb = await db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema);
|
||||
|
||||
const customProvidersMap = new Map<string, LanguageModelProvider>();
|
||||
for (const cp of customProvidersDb) {
|
||||
customProvidersMap.set(cp.id, {
|
||||
id: cp.id,
|
||||
name: cp.name,
|
||||
apiBaseUrl: cp.api_base_url,
|
||||
envVarName: cp.env_var_name ?? undefined,
|
||||
type: "custom",
|
||||
// hasFreeTier, websiteUrl, gatewayPrefix are not in the custom DB schema
|
||||
// They will be undefined unless overridden by hardcoded values if IDs match
|
||||
});
|
||||
}
|
||||
|
||||
// Get hardcoded cloud providers
|
||||
const hardcodedProviders: LanguageModelProvider[] = [];
|
||||
for (const providerKey in PROVIDERS) {
|
||||
if (Object.prototype.hasOwnProperty.call(PROVIDERS, providerKey)) {
|
||||
// Ensure providerKey is a key of PROVIDERS
|
||||
const key = providerKey as keyof typeof PROVIDERS;
|
||||
const providerDetails = PROVIDERS[key];
|
||||
if (providerDetails) {
|
||||
// Ensure providerDetails is not undefined
|
||||
hardcodedProviders.push({
|
||||
id: key,
|
||||
name: providerDetails.displayName,
|
||||
hasFreeTier: providerDetails.hasFreeTier,
|
||||
websiteUrl: providerDetails.websiteUrl,
|
||||
gatewayPrefix: providerDetails.gatewayPrefix,
|
||||
envVarName: PROVIDER_TO_ENV_VAR[key] ?? undefined,
|
||||
type: "cloud",
|
||||
// apiBaseUrl is not directly in PROVIDERS
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge lists: custom providers take precedence
|
||||
const mergedProvidersMap = new Map<string, LanguageModelProvider>();
|
||||
|
||||
// Add all hardcoded providers first
|
||||
for (const hp of hardcodedProviders) {
|
||||
mergedProvidersMap.set(hp.id, hp);
|
||||
}
|
||||
|
||||
// Add/overwrite with custom providers from DB
|
||||
for (const [id, cp] of customProvidersMap) {
|
||||
const existingProvider = mergedProvidersMap.get(id);
|
||||
if (existingProvider) {
|
||||
// If exists, merge. Custom fields take precedence.
|
||||
mergedProvidersMap.set(id, {
|
||||
...existingProvider, // start with hardcoded
|
||||
...cp, // override with custom where defined
|
||||
id: cp.id, // ensure custom id is used
|
||||
name: cp.name, // ensure custom name is used
|
||||
type: "custom", // explicitly set type to custom
|
||||
apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl,
|
||||
envVarName: cp.envVarName ?? existingProvider.envVarName,
|
||||
});
|
||||
} else {
|
||||
// If it doesn't exist in hardcoded, just add the custom one
|
||||
mergedProvidersMap.set(id, cp);
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(mergedProvidersMap.values());
|
||||
}
|
||||
@@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
|
||||
import { createOllama } from "ollama-ai-provider";
|
||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
||||
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
|
||||
import {
|
||||
PROVIDER_TO_ENV_VAR,
|
||||
AUTO_MODELS,
|
||||
PROVIDERS,
|
||||
MODEL_OPTIONS,
|
||||
} from "../../constants/models";
|
||||
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
|
||||
import { getEnvVar } from "./read_env";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("getModelClient");
|
||||
export function getModelClient(
|
||||
export async function getModelClient(
|
||||
model: LargeLanguageModel,
|
||||
settings: UserSettings,
|
||||
) {
|
||||
const allProviders = await getLanguageModelProviders();
|
||||
|
||||
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
|
||||
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
|
||||
if (model.provider === "auto") {
|
||||
// Try each model in AUTO_MODELS in order until finding one with an API key
|
||||
for (const autoModel of AUTO_MODELS) {
|
||||
const providerInfo = allProviders.find(
|
||||
(p) => p.id === autoModel.provider,
|
||||
);
|
||||
const envVarName = providerInfo?.envVarName;
|
||||
|
||||
const apiKey =
|
||||
dyadApiKey ||
|
||||
settings.providerSettings?.[autoModel.provider]?.apiKey ||
|
||||
getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]);
|
||||
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
|
||||
(envVarName ? getEnvVar(envVarName) : undefined);
|
||||
|
||||
if (apiKey) {
|
||||
logger.log(
|
||||
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
|
||||
);
|
||||
// Use the first model that has an API key
|
||||
return getModelClient(
|
||||
// Recursively call with the specific model found
|
||||
return await getModelClient(
|
||||
{
|
||||
provider: autoModel.provider,
|
||||
name: autoModel.name,
|
||||
@@ -43,27 +45,48 @@ export function getModelClient(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If no models have API keys, throw an error
|
||||
throw new Error("No API keys available for any model in AUTO_MODELS");
|
||||
throw new Error(
|
||||
"No API keys available for any model supported by the 'auto' provider.",
|
||||
);
|
||||
}
|
||||
|
||||
// --- Handle specific provider ---
|
||||
const providerConfig = allProviders.find((p) => p.id === model.provider);
|
||||
|
||||
if (!providerConfig) {
|
||||
throw new Error(`Configuration not found for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Handle Dyad Pro override
|
||||
if (dyadApiKey && settings.enableDyadPro) {
|
||||
const provider = createOpenAI({
|
||||
apiKey: dyadApiKey,
|
||||
baseURL: "https://llm-gateway.dyad.sh/v1",
|
||||
});
|
||||
const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS];
|
||||
logger.info("Using Dyad Pro API key");
|
||||
// Do not use free variant (for openrouter).
|
||||
const modelName = model.name.split(":free")[0];
|
||||
return provider(`${providerInfo.gatewayPrefix}${modelName}`);
|
||||
// Check if the selected provider supports Dyad Pro (has a gateway prefix)
|
||||
if (providerConfig.gatewayPrefix) {
|
||||
const provider = createOpenAI({
|
||||
apiKey: dyadApiKey,
|
||||
baseURL: "https://llm-gateway.dyad.sh/v1",
|
||||
});
|
||||
logger.info("Using Dyad Pro API key via Gateway");
|
||||
// Do not use free variant (for openrouter).
|
||||
const modelName = model.name.split(":free")[0];
|
||||
return provider(`${providerConfig.gatewayPrefix}${modelName}`);
|
||||
} else {
|
||||
logger.warn(
|
||||
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
|
||||
);
|
||||
// Fall through to regular provider logic if gateway prefix is missing
|
||||
}
|
||||
}
|
||||
|
||||
// Get API key for the specific provider
|
||||
const apiKey =
|
||||
settings.providerSettings?.[model.provider]?.apiKey?.value ||
|
||||
getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]);
|
||||
switch (model.provider) {
|
||||
(providerConfig.envVarName
|
||||
? getEnvVar(providerConfig.envVarName)
|
||||
: undefined);
|
||||
|
||||
// Create client based on provider ID or type
|
||||
switch (providerConfig.id) {
|
||||
case "openai": {
|
||||
const provider = createOpenAI({ apiKey });
|
||||
return provider(model.name);
|
||||
@@ -81,18 +104,38 @@ export function getModelClient(
|
||||
return provider(model.name);
|
||||
}
|
||||
case "ollama": {
|
||||
const provider = createOllama();
|
||||
// Ollama typically runs locally and doesn't require an API key in the same way
|
||||
const provider = createOllama({
|
||||
baseURL: providerConfig.apiBaseUrl,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
case "lmstudio": {
|
||||
// Using LM Studio's OpenAI compatible API
|
||||
const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL
|
||||
const provider = createOpenAICompatible({ name: "lmstudio", baseURL });
|
||||
// LM Studio uses OpenAI compatible API
|
||||
const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1";
|
||||
const provider = createOpenAICompatible({
|
||||
name: "lmstudio",
|
||||
baseURL,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
default: {
|
||||
// Ensure exhaustive check if more providers are added
|
||||
const _exhaustiveCheck: never = model.provider;
|
||||
// Handle custom providers
|
||||
if (providerConfig.type === "custom") {
|
||||
if (!providerConfig.apiBaseUrl) {
|
||||
throw new Error(
|
||||
`Custom provider ${model.provider} is missing the API Base URL.`,
|
||||
);
|
||||
}
|
||||
// Assume custom providers are OpenAI compatible for now
|
||||
const provider = createOpenAICompatible({
|
||||
name: providerConfig.id,
|
||||
baseURL: providerConfig.apiBaseUrl,
|
||||
apiKey: apiKey,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
// If it's not a known ID and not type 'custom', it's unsupported
|
||||
throw new Error(`Unsupported model provider: ${model.provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user