Implement saver mode (#154)

Will Chen
2025-05-13 15:34:41 -07:00
committed by GitHub
parent 3763423dc7
commit 069c221292
11 changed files with 1630 additions and 15 deletions

View File

@@ -1,3 +1,4 @@
import { LanguageModelV1 } from "ai";
import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic";
@@ -8,6 +9,8 @@ import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import { getEnvVar } from "./read_env";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
import { LanguageModelProvider } from "../ipc_types";
import { llmErrorStore } from "@/main/llm_error_store";
const AUTO_MODELS = [
{
@@ -24,11 +27,19 @@ const AUTO_MODELS = [
},
];
export interface ModelClient {
model: LanguageModelV1;
builtinProviderId?: string;
}
const logger = log.scope("getModelClient");
export async function getModelClient(
model: LargeLanguageModel,
settings: UserSettings,
): Promise<{
modelClient: ModelClient;
backupModelClients: ModelClient[];
}> {
const allProviders = await getLanguageModelProviders();
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
@@ -83,7 +94,44 @@ export async function getModelClient(
logger.info("Using Dyad Pro API key via Gateway");
// Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0];
const autoModelClient = {
model: provider(`${providerConfig.gatewayPrefix}${modelName}`),
builtinProviderId: "auto",
};
const googleSettings = settings.providerSettings?.google;
// Budget saver mode logic (all must be true):
// 1. Pro Saver Mode is enabled
// 2. Provider is Google
// 3. API Key is set
// 4. Has no recent errors
if (
settings.enableProSaverMode &&
providerConfig.id === "google" &&
googleSettings &&
googleSettings.apiKey?.value &&
llmErrorStore.modelHasNoRecentError({
model: model.name,
provider: providerConfig.id,
})
) {
return {
modelClient: getRegularModelClient(
{
provider: providerConfig.id,
name: model.name,
},
settings,
providerConfig,
).modelClient,
backupModelClients: [autoModelClient],
};
} else {
return {
modelClient: autoModelClient,
backupModelClients: [],
};
}
} else {
logger.warn(
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
@@ -91,7 +139,14 @@ export async function getModelClient(
// Fall through to regular provider logic if gateway prefix is missing
}
}
return getRegularModelClient(model, settings, providerConfig);
}
function getRegularModelClient(
model: LargeLanguageModel,
settings: UserSettings,
providerConfig: LanguageModelProvider,
) {
// Get API key for the specific provider
const apiKey =
settings.providerSettings?.[model.provider]?.apiKey?.value ||
@@ -99,30 +154,60 @@ export async function getModelClient(
? getEnvVar(providerConfig.envVarName)
: undefined);
const providerId = providerConfig.id;
// Create client based on provider ID or type
switch (providerId) {
case "openai": {
const provider = createOpenAI({ apiKey });
return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
}
case "anthropic": {
const provider = createAnthropic({ apiKey });
return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
}
case "google": {
const provider = createGoogle({ apiKey });
return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
}
case "openrouter": {
const provider = createOpenRouter({ apiKey });
return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
}
case "ollama": {
// Ollama typically runs locally and doesn't require an API key in the same way
const provider = createOllama({
baseURL: providerConfig.apiBaseUrl,
});
return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
}
case "lmstudio": {
// LM Studio uses OpenAI compatible API
@@ -131,7 +216,12 @@ export async function getModelClient(
name: "lmstudio",
baseURL,
});
return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
}
default: {
// Handle custom providers
@@ -147,7 +237,12 @@ export async function getModelClient(
baseURL: providerConfig.apiBaseUrl,
apiKey: apiKey,
});
return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
}
// If it's not a known ID and not type 'custom', it's unsupported
throw new Error(`Unsupported model provider: ${model.provider}`);
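For reference, getModelClient now resolves to a primary client plus an ordered list of backups instead of a bare LanguageModelV1. A minimal sketch of how a caller might consume the new shape (chatModel and userSettings are hypothetical placeholders, not names from this diff):

const { modelClient, backupModelClients } = await getModelClient(
  chatModel, // hypothetical LargeLanguageModel chosen in settings
  userSettings, // hypothetical UserSettings loaded by the caller
);
// modelClient.model is the LanguageModelV1 handed to the stream call;
// backupModelClients are tried in order if the primary provider errors.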

View File

@@ -0,0 +1,123 @@
import { streamText } from "ai";
import log from "electron-log";
import { ModelClient } from "./get_model_client";
import { llmErrorStore } from "@/main/llm_error_store";
const logger = log.scope("stream_utils");
export interface StreamTextWithBackupParams
extends Omit<Parameters<typeof streamText>[0], "model"> {
model: ModelClient; // primary client
backupModelClients?: ModelClient[]; // ordered fall-backs
}
export function streamTextWithBackup(params: StreamTextWithBackupParams): {
textStream: AsyncIterable<string>;
} {
const {
model: primaryModel,
backupModelClients = [],
onError: callerOnError,
abortSignal: callerAbort,
...rest
} = params;
const modelClients: ModelClient[] = [primaryModel, ...backupModelClients];
async function* combinedGenerator(): AsyncIterable<string> {
let lastErr: { error: unknown } | undefined = undefined;
for (let i = 0; i < modelClients.length; i++) {
const currentModelClient = modelClients[i];
/* Local abort controller for this single attempt */
const attemptAbort = new AbortController();
if (callerAbort) {
if (callerAbort.aborted) {
// Already aborted, trigger immediately
attemptAbort.abort();
} else {
callerAbort.addEventListener("abort", () => attemptAbort.abort(), {
once: true,
});
}
}
let errorFromCurrent: { error: unknown } | undefined = undefined; // set when onError fires
const providerId = currentModelClient.builtinProviderId;
if (providerId) {
llmErrorStore.clearModelError({
model: currentModelClient.model.modelId,
provider: providerId,
});
}
logger.info(
"Streaming text with model",
currentModelClient.model.modelId,
"provider",
currentModelClient.model.provider,
"builtinProviderId",
currentModelClient.builtinProviderId,
);
const { textStream } = streamText({
...rest,
maxRetries: 0,
model: currentModelClient.model,
abortSignal: attemptAbort.signal,
onError: (error) => {
const providerId = currentModelClient.builtinProviderId;
if (providerId) {
llmErrorStore.recordModelError({
model: currentModelClient.model.modelId,
provider: providerId,
});
}
logger.error(
`Error streaming text with ${providerId} and model ${currentModelClient.model.modelId}: ${error}`,
error,
);
errorFromCurrent = error;
attemptAbort.abort(); // kill fetch / SSE
},
});
try {
for await (const chunk of textStream) {
/* If onError fired during streaming, bail out immediately. */
if (errorFromCurrent) throw errorFromCurrent;
yield chunk;
}
/* Stream ended; check whether it actually failed. */
if (errorFromCurrent) throw errorFromCurrent;
/* Completed successfully; stop trying more models. */
return;
} catch (err) {
if (typeof err === "object" && err !== null && "error" in err) {
lastErr = err as { error: unknown };
} else {
lastErr = { error: err };
}
logger.warn(
`[streamTextWithBackup] model #${i} failed; ${
i < modelClients.length - 1
? "switching to backup"
: "no backups left"
}`,
err,
);
/* loop continues to next model (if any) */
}
}
/* Every model failed */
if (!lastErr) {
throw new Error("Invariant in StreamTextWithbackup failed!");
}
callerOnError?.(lastErr);
logger.error("All model invocations failed", lastErr);
// throw lastErr ?? new Error("All model invocations failed");
}
return { textStream: combinedGenerator() };
}
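A rough sketch of how the two new pieces could be wired together; the model, settings, and messages values are assumed to exist in the caller and are not part of this commit:

const { modelClient, backupModelClients } = await getModelClient(model, settings);
const { textStream } = streamTextWithBackup({
  model: modelClient,
  backupModelClients,
  messages, // assumed chat history in the ai SDK message format
  onError: (err) => console.error("all models failed", err.error),
});
for await (const chunk of textStream) {
  // placeholder consumer; the real handler forwards chunks to the renderer
  process.stdout.write(chunk);
}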