Implement saver mode (#154)
@@ -1,5 +1,5 @@
 import { ipcMain } from "electron";
-import { CoreMessage, TextPart, ImagePart, streamText } from "ai";
+import { CoreMessage, TextPart, ImagePart } from "ai";
 import { db } from "../../db";
 import { chats, messages } from "../../db/schema";
 import { and, eq, isNull } from "drizzle-orm";
@@ -29,6 +29,7 @@ import * as crypto from "crypto";
 import { readFile, writeFile, unlink } from "fs/promises";
 import { getMaxTokens } from "../utils/token_utils";
 import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
+import { streamTextWithBackup } from "../utils/stream_utils";

 const logger = log.scope("chat_stream_handlers");

@@ -214,7 +215,7 @@ export function registerChatStreamHandlers() {
     } else {
       // Normal AI processing for non-test prompts
       const settings = readSettings();
-      const modelClient = await getModelClient(
+      const { modelClient, backupModelClients } = await getModelClient(
         settings.selectedModel,
         settings,
       );
@@ -372,13 +373,14 @@ This conversation includes one or more image attachments. When the user uploads
       }

       // When calling streamText, the messages need to be properly formatted for mixed content
-      const { textStream } = streamText({
+      const { textStream } = streamTextWithBackup({
         maxTokens: await getMaxTokens(settings.selectedModel),
         temperature: 0,
         model: modelClient,
+        backupModelClients: backupModelClients,
         system: systemPrompt,
         messages: chatMessages.filter((m) => m.content),
-        onError: (error) => {
+        onError: (error: any) => {
           logger.error("Error streaming text:", error);
           const message =
             (error as any)?.error?.message || JSON.stringify(error);
@@ -12,7 +12,9 @@ export function createLoggedHandler(logger: log.LogFunctions) {
     logger.log(`IPC: ${channel} called with args: ${JSON.stringify(args)}`);
     try {
       const result = await fn(event, ...args);
-      logger.log(`IPC: ${channel} returned: ${JSON.stringify(result)}`);
+      logger.log(
+        `IPC: ${channel} returned: ${JSON.stringify(result).slice(0, 100)}...`,
+      );
       return result;
     } catch (error) {
       logger.error(
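The only change to the shared IPC logging wrapper is that return values are now logged truncated to the first 100 characters of their JSON form, so large results no longer flood the log. A tiny sketch of the effect; the channel name and payload below are made up for illustration:

// Hypothetical payload: a long array of file paths.
const result = { files: new Array(200).fill("src/components/Button.tsx") };

// Before: the entire JSON string was written to the log.
// After: only a 100-character prefix, suffixed with "...".
console.log(`IPC: list-files returned: ${JSON.stringify(result).slice(0, 100)}...`);
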
@@ -1,3 +1,4 @@
+import { LanguageModelV1 } from "ai";
 import { createOpenAI } from "@ai-sdk/openai";
 import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
 import { createAnthropic } from "@ai-sdk/anthropic";
@@ -8,6 +9,8 @@ import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
 import { getEnvVar } from "./read_env";
 import log from "electron-log";
 import { getLanguageModelProviders } from "../shared/language_model_helpers";
+import { LanguageModelProvider } from "../ipc_types";
+import { llmErrorStore } from "@/main/llm_error_store";

 const AUTO_MODELS = [
   {
@@ -24,11 +27,19 @@ const AUTO_MODELS = [
   },
 ];

+export interface ModelClient {
+  model: LanguageModelV1;
+  builtinProviderId?: string;
+}
+
 const logger = log.scope("getModelClient");
 export async function getModelClient(
   model: LargeLanguageModel,
   settings: UserSettings,
-) {
+): Promise<{
+  modelClient: ModelClient;
+  backupModelClients: ModelClient[];
+}> {
   const allProviders = await getLanguageModelProviders();

   const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
@@ -83,7 +94,44 @@ export async function getModelClient(
       logger.info("Using Dyad Pro API key via Gateway");
       // Do not use free variant (for openrouter).
       const modelName = model.name.split(":free")[0];
-      return provider(`${providerConfig.gatewayPrefix}${modelName}`);
+      const autoModelClient = {
+        model: provider(`${providerConfig.gatewayPrefix}${modelName}`),
+        builtinProviderId: "auto",
+      };
+      const googleSettings = settings.providerSettings?.google;
+
+      // Budget saver mode logic (all must be true):
+      // 1. Pro Saver Mode is enabled
+      // 2. Provider is Google
+      // 3. API Key is set
+      // 4. Has no recent errors
+      if (
+        settings.enableProSaverMode &&
+        providerConfig.id === "google" &&
+        googleSettings &&
+        googleSettings.apiKey?.value &&
+        llmErrorStore.modelHasNoRecentError({
+          model: model.name,
+          provider: providerConfig.id,
+        })
+      ) {
+        return {
+          modelClient: getRegularModelClient(
+            {
+              provider: providerConfig.id,
+              name: model.name,
+            },
+            settings,
+            providerConfig,
+          ).modelClient,
+          backupModelClients: [autoModelClient],
+        };
+      } else {
+        return {
+          modelClient: autoModelClient,
+          backupModelClients: [],
+        };
+      }
     } else {
       logger.warn(
         `Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
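Read as a whole, the gateway branch now makes a single decision: if Pro Saver Mode is on, the selected provider is Google, the user has their own Google API key, and that model has no error recorded in the last hour, the user's own Google client becomes the primary and the Dyad Pro gateway client is kept as the backup; otherwise the gateway client is used directly with no backups. A condensed, self-contained sketch of that decision (the pickClients helper and the string labels are illustrative, not code from this commit):

// Illustrative summary of the saver-mode selection above; not part of the commit.
function pickClients(opts: {
  enableProSaverMode: boolean;
  providerId: string;
  hasGoogleApiKey: boolean;
  hasRecentGoogleError: boolean;
}): { primary: string; backups: string[] } {
  const saverEligible =
    opts.enableProSaverMode &&
    opts.providerId === "google" &&
    opts.hasGoogleApiKey &&
    !opts.hasRecentGoogleError;
  return saverEligible
    ? { primary: "google (user's own key)", backups: ["auto (Dyad gateway)"] }
    : { primary: "auto (Dyad gateway)", backups: [] };
}

// Saver mode spends the user's own Google quota first and only falls back to the
// gateway when the direct call fails or has failed within the last hour.
console.log(
  pickClients({
    enableProSaverMode: true,
    providerId: "google",
    hasGoogleApiKey: true,
    hasRecentGoogleError: false,
  }),
); // -> { primary: "google (user's own key)", backups: ["auto (Dyad gateway)"] }
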
@@ -91,7 +139,14 @@ export async function getModelClient(
       // Fall through to regular provider logic if gateway prefix is missing
     }
   }
+  return getRegularModelClient(model, settings, providerConfig);
+}
+
+function getRegularModelClient(
+  model: LargeLanguageModel,
+  settings: UserSettings,
+  providerConfig: LanguageModelProvider,
+) {
   // Get API key for the specific provider
   const apiKey =
     settings.providerSettings?.[model.provider]?.apiKey?.value ||
@@ -99,30 +154,60 @@ export async function getModelClient(
       ? getEnvVar(providerConfig.envVarName)
       : undefined);

+  const providerId = providerConfig.id;
   // Create client based on provider ID or type
-  switch (providerConfig.id) {
+  switch (providerId) {
     case "openai": {
       const provider = createOpenAI({ apiKey });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+          builtinProviderId: providerId,
+        },
+        backupModelClients: [],
+      };
     }
     case "anthropic": {
       const provider = createAnthropic({ apiKey });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+          builtinProviderId: providerId,
+        },
+        backupModelClients: [],
+      };
     }
     case "google": {
       const provider = createGoogle({ apiKey });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+          builtinProviderId: providerId,
+        },
+        backupModelClients: [],
+      };
     }
     case "openrouter": {
       const provider = createOpenRouter({ apiKey });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+          builtinProviderId: providerId,
+        },
+        backupModelClients: [],
+      };
     }
     case "ollama": {
       // Ollama typically runs locally and doesn't require an API key in the same way
       const provider = createOllama({
         baseURL: providerConfig.apiBaseUrl,
       });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+        },
+        backupModelClients: [],
+      };
     }
     case "lmstudio": {
       // LM Studio uses OpenAI compatible API
@@ -131,7 +216,12 @@ export async function getModelClient(
         name: "lmstudio",
         baseURL,
       });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+        },
+        backupModelClients: [],
+      };
     }
     default: {
       // Handle custom providers
@@ -147,7 +237,12 @@ export async function getModelClient(
         baseURL: providerConfig.apiBaseUrl,
         apiKey: apiKey,
       });
-      return provider(model.name);
+      return {
+        modelClient: {
+          model: provider(model.name),
+        },
+        backupModelClients: [],
+      };
     }
     // If it's not a known ID and not type 'custom', it's unsupported
     throw new Error(`Unsupported model provider: ${model.provider}`);
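Every case in the switch now returns the same { modelClient, backupModelClients } shape, with builtinProviderId set only for the built-in cloud providers (so the error store can track them) and left undefined for local or custom providers. A possible follow-up refactor, sketched here as an assumption rather than anything in this commit, would factor that repeated shape into a small helper:

import { LanguageModelV1 } from "ai";

// Mirrors the ModelClient interface exported from get_model_client.ts.
interface ModelClient {
  model: LanguageModelV1;
  builtinProviderId?: string;
}

// Hypothetical helper each case could call,
// e.g. `return singleClient(provider(model.name), providerId);`
export function singleClient(
  model: LanguageModelV1,
  builtinProviderId?: string,
): { modelClient: ModelClient; backupModelClients: ModelClient[] } {
  return {
    modelClient: { model, builtinProviderId },
    backupModelClients: [],
  };
}
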
src/ipc/utils/stream_utils.ts (new file, 123 lines)
@@ -0,0 +1,123 @@
import { streamText } from "ai";
import log from "electron-log";
import { ModelClient } from "./get_model_client";
import { llmErrorStore } from "@/main/llm_error_store";
const logger = log.scope("stream_utils");

export interface StreamTextWithBackupParams
  extends Omit<Parameters<typeof streamText>[0], "model"> {
  model: ModelClient; // primary client
  backupModelClients?: ModelClient[]; // ordered fall-backs
}

export function streamTextWithBackup(params: StreamTextWithBackupParams): {
  textStream: AsyncIterable<string>;
} {
  const {
    model: primaryModel,
    backupModelClients = [],
    onError: callerOnError,
    abortSignal: callerAbort,
    ...rest
  } = params;

  const modelClients: ModelClient[] = [primaryModel, ...backupModelClients];

  async function* combinedGenerator(): AsyncIterable<string> {
    let lastErr: { error: unknown } | undefined = undefined;

    for (let i = 0; i < modelClients.length; i++) {
      const currentModelClient = modelClients[i];

      /* Local abort controller for this single attempt */
      const attemptAbort = new AbortController();
      if (callerAbort) {
        if (callerAbort.aborted) {
          // Already aborted, trigger immediately
          attemptAbort.abort();
        } else {
          callerAbort.addEventListener("abort", () => attemptAbort.abort(), {
            once: true,
          });
        }
      }

      let errorFromCurrent: { error: unknown } | undefined = undefined; // set when onError fires
      const providerId = currentModelClient.builtinProviderId;
      if (providerId) {
        llmErrorStore.clearModelError({
          model: currentModelClient.model.modelId,
          provider: providerId,
        });
      }
      logger.info(
        "Streaming text with model",
        currentModelClient.model.modelId,
        "provider",
        currentModelClient.model.provider,
        "builtinProviderId",
        currentModelClient.builtinProviderId,
      );
      const { textStream } = streamText({
        ...rest,
        maxRetries: 0,
        model: currentModelClient.model,
        abortSignal: attemptAbort.signal,
        onError: (error) => {
          const providerId = currentModelClient.builtinProviderId;
          if (providerId) {
            llmErrorStore.recordModelError({
              model: currentModelClient.model.modelId,
              provider: providerId,
            });
          }
          logger.error(
            `Error streaming text with ${providerId} and model ${currentModelClient.model.modelId}: ${error}`,
            error,
          );
          errorFromCurrent = error;
          attemptAbort.abort(); // kill fetch / SSE
        },
      });

      try {
        for await (const chunk of textStream) {
          /* If onError fired during streaming, bail out immediately. */
          if (errorFromCurrent) throw errorFromCurrent;
          yield chunk;
        }

        /* Stream ended – check if it actually failed */
        if (errorFromCurrent) throw errorFromCurrent;

        /* Completed successfully – stop trying more models. */
        return;
      } catch (err) {
        if (typeof err === "object" && err !== null && "error" in err) {
          lastErr = err as { error: unknown };
        } else {
          lastErr = { error: err };
        }
        logger.warn(
          `[streamTextWithBackup] model #${i} failed – ${
            i < modelClients.length - 1
              ? "switching to backup"
              : "no backups left"
          }`,
          err,
        );
        /* loop continues to next model (if any) */
      }
    }

    /* Every model failed */
    if (!lastErr) {
      throw new Error("Invariant in streamTextWithBackup failed!");
    }
    callerOnError?.(lastErr);
    logger.error("All model invocations failed", lastErr);
    // throw lastErr ?? new Error("All model invocations failed");
  }

  return { textStream: combinedGenerator() };
}
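Callers use streamTextWithBackup as a drop-in replacement for streamText: model takes a ModelClient rather than a bare LanguageModelV1, backupModelClients is an ordered list of fallbacks, and the caller's onError only fires once every client has failed. A rough usage sketch, assuming modelClient and backupModelClients were obtained from getModelClient as in the handler change above (prompt contents are illustrative):

import { streamTextWithBackup } from "./stream_utils";
import type { ModelClient } from "./get_model_client";

async function demo(modelClient: ModelClient, backupModelClients: ModelClient[]) {
  const { textStream } = streamTextWithBackup({
    model: modelClient, // primary attempt, e.g. the user's own Google client
    backupModelClients, // tried in order if the primary errors, e.g. [Dyad gateway client]
    system: "You are a helpful assistant.",
    messages: [{ role: "user", content: "Hello" }],
    onError: (error) => {
      // Only reached after every model, primary and backups, has failed.
      console.error("All model invocations failed", error);
    },
  });

  for await (const chunk of textStream) {
    process.stdout.write(chunk); // chunks come from whichever client is currently succeeding
  }
}
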
src/main/llm_error_store.ts (new file, 35 lines)
@@ -0,0 +1,35 @@
class LlmErrorStore {
  private modelErrorToTimestamp: Record<string, number> = {};

  constructor() {}

  recordModelError({ model, provider }: { model: string; provider: string }) {
    this.modelErrorToTimestamp[this.getKey({ model, provider })] = Date.now();
  }

  clearModelError({ model, provider }: { model: string; provider: string }) {
    delete this.modelErrorToTimestamp[this.getKey({ model, provider })];
  }

  modelHasNoRecentError({
    model,
    provider,
  }: {
    model: string;
    provider: string;
  }): boolean {
    const key = this.getKey({ model, provider });
    const timestamp = this.modelErrorToTimestamp[key];
    if (!timestamp) {
      return true;
    }
    const oneHourAgo = Date.now() - 1000 * 60 * 60;
    return timestamp < oneHourAgo;
  }

  private getKey({ model, provider }: { model: string; provider: string }) {
    return `${provider}::${model}`;
  }
}

export const llmErrorStore = new LlmErrorStore();
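The store is a plain in-memory map keyed by provider::model, so a recorded error keeps saver mode from choosing that model for an hour; each new attempt through streamTextWithBackup clears the record before streaming and re-records it if the attempt fails. A minimal usage sketch; the model name below is illustrative:

import { llmErrorStore } from "@/main/llm_error_store";

// After a failed direct call, the model is marked unhealthy.
llmErrorStore.recordModelError({ model: "gemini-2.5-pro", provider: "google" });

// For the next hour, saver mode skips the direct Google client.
llmErrorStore.modelHasNoRecentError({ model: "gemini-2.5-pro", provider: "google" }); // false

// Clearing the record restores eligibility immediately.
llmErrorStore.clearModelError({ model: "gemini-2.5-pro", provider: "google" });
llmErrorStore.modelHasNoRecentError({ model: "gemini-2.5-pro", provider: "google" }); // true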