Fix DB schemas (#138)

This commit is contained in:
Will Chen
2025-05-12 21:51:08 -07:00
committed by GitHub
parent f5a6a1abca
commit 993c5417e3
10 changed files with 273 additions and 128 deletions

View File

@@ -11,11 +11,12 @@ CREATE TABLE `language_models` (
`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `id` integer PRIMARY KEY AUTOINCREMENT NOT NULL,
`display_name` text NOT NULL, `display_name` text NOT NULL,
`api_name` text NOT NULL, `api_name` text NOT NULL,
`provider_id` text NOT NULL, `builtin_provider_id` text,
`custom_provider_id` text,
`description` text, `description` text,
`max_output_tokens` integer, `max_output_tokens` integer,
`context_window` integer, `context_window` integer,
`created_at` integer DEFAULT (unixepoch()) NOT NULL, `created_at` integer DEFAULT (unixepoch()) NOT NULL,
`updated_at` integer DEFAULT (unixepoch()) NOT NULL, `updated_at` integer DEFAULT (unixepoch()) NOT NULL,
FOREIGN KEY (`provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade FOREIGN KEY (`custom_provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade
); );

View File

@@ -1,7 +1,7 @@
{ {
"version": "6", "version": "6",
"dialect": "sqlite", "dialect": "sqlite",
"id": "424973e5-a102-4b71-8d5d-6160f1609b8c", "id": "0a47ec41-9477-4457-b3e8-e5ecb3e3a855",
"prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df", "prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df",
"tables": { "tables": {
"apps": { "apps": {
@@ -210,11 +210,18 @@
"notNull": true, "notNull": true,
"autoincrement": false "autoincrement": false
}, },
"provider_id": { "builtin_provider_id": {
"name": "provider_id", "name": "builtin_provider_id",
"type": "text", "type": "text",
"primaryKey": false, "primaryKey": false,
"notNull": true, "notNull": false,
"autoincrement": false
},
"custom_provider_id": {
"name": "custom_provider_id",
"type": "text",
"primaryKey": false,
"notNull": false,
"autoincrement": false "autoincrement": false
}, },
"description": { "description": {
@@ -257,12 +264,12 @@
}, },
"indexes": {}, "indexes": {},
"foreignKeys": { "foreignKeys": {
"language_models_provider_id_language_model_providers_id_fk": { "language_models_custom_provider_id_language_model_providers_id_fk": {
"name": "language_models_provider_id_language_model_providers_id_fk", "name": "language_models_custom_provider_id_language_model_providers_id_fk",
"tableFrom": "language_models", "tableFrom": "language_models",
"tableTo": "language_model_providers", "tableTo": "language_model_providers",
"columnsFrom": [ "columnsFrom": [
"provider_id" "custom_provider_id"
], ],
"columnsTo": [ "columnsTo": [
"id" "id"

View File

@@ -40,8 +40,8 @@
{ {
"idx": 5, "idx": 5,
"version": "6", "version": "6",
"when": 1747091036229, "when": 1747095436506,
"tag": "0005_superb_lady_mastermind", "tag": "0005_clumsy_namor",
"breakpoints": true "breakpoints": true
} }
] ]

View File

@@ -17,9 +17,9 @@ import {
DropdownMenuSubContent, DropdownMenuSubContent,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { MODEL_OPTIONS } from "@/constants/models";
import { useLocalModels } from "@/hooks/useLocalModels"; import { useLocalModels } from "@/hooks/useLocalModels";
import { useLocalLMSModels } from "@/hooks/useLMStudioModels"; import { useLocalLMSModels } from "@/hooks/useLMStudioModels";
import { useLanguageModelsByProviders } from "@/hooks/useLanguageModelsByProviders";
import { ChevronDown } from "lucide-react"; import { ChevronDown } from "lucide-react";
import { LocalModel } from "@/ipc/ipc_types"; import { LocalModel } from "@/ipc/ipc_types";
interface ModelPickerProps { interface ModelPickerProps {
@@ -33,6 +33,10 @@ export function ModelPicker({
}: ModelPickerProps) { }: ModelPickerProps) {
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
// Cloud models from providers
const { data: modelsByProviders, isLoading: providersLoading } =
useLanguageModelsByProviders();
// Ollama Models Hook // Ollama Models Hook
const { const {
models: ollamaModels, models: ollamaModels,
@@ -74,24 +78,35 @@ export function ModelPicker({
); );
} }
// Fallback for cloud models // For cloud models, look up in the modelsByProviders data
return ( if (modelsByProviders && modelsByProviders[selectedModel.provider]) {
MODEL_OPTIONS[selectedModel.provider]?.find( const foundModel = modelsByProviders[selectedModel.provider].find(
(model) => model.name === selectedModel.name, (model) => model.apiName === selectedModel.name,
)?.displayName || selectedModel.name
); );
if (foundModel) {
return foundModel.displayName;
}
}
// Fallback if not found
return selectedModel.name;
}; };
const modelDisplayName = getModelDisplayName(); const modelDisplayName = getModelDisplayName();
// Flatten the cloud model options // Flatten the cloud models from all providers
const cloudModels = Object.entries(MODEL_OPTIONS).flatMap( const cloudModels =
([provider, models]) => !providersLoading && modelsByProviders
? Object.entries(modelsByProviders).flatMap(([providerId, models]) =>
models.map((model) => ({ models.map((model) => ({
...model, name: model.apiName,
provider: provider as ModelProvider, displayName: model.displayName,
description: model.description || "",
tag: model.tag,
provider: providerId as ModelProvider,
})), })),
); )
: [];
// Determine availability of local models // Determine availability of local models
const hasOllamaModels = const hasOllamaModels =
@@ -119,8 +134,18 @@ export function ModelPicker({
{/* Increased width slightly */} {/* Increased width slightly */}
<DropdownMenuLabel>Cloud Models</DropdownMenuLabel> <DropdownMenuLabel>Cloud Models</DropdownMenuLabel>
<DropdownMenuSeparator /> <DropdownMenuSeparator />
{/* Cloud models */} {/* Cloud models - loading state */}
{cloudModels.map((model) => ( {providersLoading ? (
<div className="text-xs text-center py-2 text-muted-foreground">
Loading models...
</div>
) : cloudModels.length === 0 ? (
<div className="text-xs text-center py-2 text-muted-foreground">
No cloud models available
</div>
) : (
/* Cloud models loaded */
cloudModels.map((model) => (
<Tooltip key={`${model.provider}-${model.name}`}> <Tooltip key={`${model.provider}-${model.name}`}>
<TooltipTrigger asChild> <TooltipTrigger asChild>
<DropdownMenuItem <DropdownMenuItem
@@ -155,7 +180,8 @@ export function ModelPicker({
</TooltipTrigger> </TooltipTrigger>
<TooltipContent side="right">{model.description}</TooltipContent> <TooltipContent side="right">{model.description}</TooltipContent>
</Tooltip> </Tooltip>
))} ))
)}
<DropdownMenuSeparator /> <DropdownMenuSeparator />
{/* Local Models Parent SubMenu */} {/* Local Models Parent SubMenu */}
<DropdownMenuSub> <DropdownMenuSub>

View File

@@ -85,9 +85,11 @@ export const language_models = sqliteTable("language_models", {
id: integer("id").primaryKey({ autoIncrement: true }), id: integer("id").primaryKey({ autoIncrement: true }),
displayName: text("display_name").notNull(), displayName: text("display_name").notNull(),
apiName: text("api_name").notNull(), apiName: text("api_name").notNull(),
provider_id: text("provider_id") builtinProviderId: text("builtin_provider_id"),
.notNull() customProviderId: text("custom_provider_id").references(
.references(() => language_model_providers.id, { onDelete: "cascade" }), () => language_model_providers.id,
{ onDelete: "cascade" },
),
description: text("description"), description: text("description"),
max_output_tokens: integer("max_output_tokens"), max_output_tokens: integer("max_output_tokens"),
context_window: integer("context_window"), context_window: integer("context_window"),
@@ -111,7 +113,7 @@ export const languageModelsRelations = relations(
language_models, language_models,
({ one }) => ({ ({ one }) => ({
provider: one(language_model_providers, { provider: one(language_model_providers, {
fields: [language_models.provider_id], fields: [language_models.customProviderId],
references: [language_model_providers.id], references: [language_model_providers.id],
}), }),
}), }),

View File

@@ -0,0 +1,19 @@
import { useQuery } from "@tanstack/react-query";
import { IpcClient } from "@/ipc/ipc_client";
import type { LanguageModel } from "@/ipc/ipc_types";
/**
* Fetches all available language models grouped by their provider IDs.
*
* @returns TanStack Query result object for the language models organized by provider.
*/
export function useLanguageModelsByProviders() {
const ipcClient = IpcClient.getInstance();
return useQuery<Record<string, LanguageModel[]>, Error>({
queryKey: ["language-models-by-providers"],
queryFn: async () => {
return ipcClient.getLanguageModelsByProviders();
},
});
}

View File

@@ -9,6 +9,7 @@ import log from "electron-log";
import { import {
getLanguageModelProviders, getLanguageModelProviders,
getLanguageModels, getLanguageModels,
getLanguageModelsByProviders,
} from "../shared/language_model_helpers"; } from "../shared/language_model_helpers";
import { db } from "@/db"; import { db } from "@/db";
import { import {
@@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() {
// Insert the new provider // Insert the new provider
await db.insert(languageModelProvidersSchema).values({ await db.insert(languageModelProvidersSchema).values({
id, // Make sure we will never have accidental collisions with builtin providers
id: "custom::" + id,
name, name,
api_base_url: apiBaseUrl, api_base_url: apiBaseUrl,
env_var_name: envVarName || null, env_var_name: envVarName || null,
@@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() {
} }
// Check if provider exists // Check if provider exists
const provider = db const providers = await getLanguageModelProviders();
.select() const provider = providers.find((p) => p.id === providerId);
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, providerId))
.get();
if (!provider) { if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`); throw new Error(`Provider with ID "${providerId}" not found`);
} }
@@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() {
await db.insert(languageModelsSchema).values({ await db.insert(languageModelsSchema).values({
displayName, displayName,
apiName, apiName,
provider_id: providerId, builtinProviderId: provider.type === "cloud" ? providerId : undefined,
customProviderId: provider.type === "custom" ? providerId : undefined,
description: description || null, description: description || null,
max_output_tokens: maxOutputTokens || null, max_output_tokens: maxOutputTokens || null,
context_window: contextWindow || null, context_window: contextWindow || null,
@@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() {
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`, `Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
); );
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === providerId);
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be deleted");
}
const result = db const result = db
.delete(language_models) .delete(language_models)
.where( .where(
and( and(
eq(language_models.provider_id, providerId), provider.type === "cloud"
? eq(language_models.builtinProviderId, providerId)
: eq(language_models.customProviderId, providerId),
eq(language_models.apiName, modelApiName), eq(language_models.apiName, modelApiName),
), ),
) )
@@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() {
// 1. Delete associated models // 1. Delete associated models
const deleteModelsResult = await tx const deleteModelsResult = await tx
.delete(languageModelsSchema) .delete(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId)) .where(eq(languageModelsSchema.customProviderId, providerId))
.run(); .run();
logger.info( logger.info(
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`, `Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
@@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() {
if (!params || typeof params.providerId !== "string") { if (!params || typeof params.providerId !== "string") {
throw new Error("Invalid parameters: providerId (string) is required."); throw new Error("Invalid parameters: providerId (string) is required.");
} }
return getLanguageModels({ providerId: params.providerId }); const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === params.providerId);
if (!provider) {
throw new Error(`Provider with ID "${params.providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be fetched");
}
return getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: params.providerId }
: { customProviderId: params.providerId },
);
},
);
handle(
"get-language-models-by-providers",
async (): Promise<Record<string, LanguageModel[]>> => {
return getLanguageModelsByProviders();
}, },
); );
} }

View File

@@ -746,6 +746,12 @@ export class IpcClient {
return this.ipcRenderer.invoke("get-language-models", params); return this.ipcRenderer.invoke("get-language-models", params);
} }
public async getLanguageModelsByProviders(): Promise<
Record<string, LanguageModel[]>
> {
return this.ipcRenderer.invoke("get-language-models-by-providers");
}
public async createCustomLanguageModelProvider({ public async createCustomLanguageModelProvider({
id, id,
name, name,

View File

@@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise<
* @param obj An object containing the providerId. * @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects. * @returns A promise that resolves to an array of LanguageModel objects.
*/ */
export async function getLanguageModels(obj: { export async function getLanguageModels(
providerId: string; obj:
}): Promise<LanguageModel[]> { | {
const { providerId } = obj; customProviderId: string;
// builtinProviderId?: undefined;
}
| {
builtinProviderId: string;
// customProviderId?: undefined;
},
): Promise<LanguageModel[]> {
const allProviders = await getLanguageModelProviders(); const allProviders = await getLanguageModelProviders();
const provider = allProviders.find((p) => p.id === providerId); const provider = allProviders.find(
(p) =>
p.id === (obj as { customProviderId: string }).customProviderId ||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
);
if (!provider) { if (!provider) {
console.warn(`Provider with ID "${providerId}" not found.`); console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
return []; return [];
} }
// Get custom models from DB for all provider types
let customModels: LanguageModel[] = [];
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
displayName: languageModelsSchema.displayName,
apiName: languageModelsSchema.apiName,
description: languageModelsSchema.description,
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(
"customProviderId" in obj
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
);
customModels = customModelsDb.map((model) => ({
...model,
description: model.description ?? "",
tag: undefined,
maxOutputTokens: model.maxOutputTokens ?? undefined,
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
error,
);
// Continue with empty custom models array
}
// If it's a cloud provider, also get the hardcoded models
let hardcodedModels: LanguageModel[] = [];
const providerId = provider.id;
if (provider.type === "cloud") { if (provider.type === "cloud") {
// Check if providerId is a valid key for MODEL_OPTIONS
if (providerId in MODEL_OPTIONS) { if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || []; const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
return models.map((model) => ({ hardcodedModels = models.map((model) => ({
...model, ...model,
apiName: model.name, apiName: model.name,
type: "cloud", type: "cloud",
@@ -164,49 +213,54 @@ export async function getLanguageModels(obj: {
console.warn( console.warn(
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`, `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
); );
return [];
} }
} else if (provider.type === "custom") { }
// Fetch models from the database for this custom provider
// Assuming a language_models table with necessary columns and provider_id foreign key
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
// Map DB columns to LanguageModel fields
displayName: languageModelsSchema.displayName,
apiName: languageModelsSchema.apiName,
// No display_name in DB, use name instead
description: languageModelsSchema.description,
// No tag in DB
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
return customModelsDb.map((model) => ({ // Merge the models, with custom models taking precedence over hardcoded ones
...model, const mergedModelsMap = new Map<string, LanguageModel>();
// Ensure possibly null fields are handled, provide defaults or undefined if needed
description: model.description ?? "", // Add hardcoded models first
tag: undefined, // No tag for custom models from DB for (const model of hardcodedModels) {
maxOutputTokens: model.maxOutputTokens ?? undefined, mergedModelsMap.set(model.apiName, model);
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${providerId}" from DB:`,
error,
);
// Depending on desired behavior, could throw, return empty, or return a specific error state
return [];
} }
} else {
// Handle other types like "local" if necessary, currently ignored // Then override with custom models
console.warn( for (const model of customModels) {
`Provider type "${provider.type}" not handled for model fetching.`, mergedModelsMap.set(model.apiName, model);
);
return [];
} }
return Array.from(mergedModelsMap.values());
}
/**
* Fetches all language models grouped by their provider IDs.
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
*/
export async function getLanguageModelsByProviders(): Promise<
Record<string, LanguageModel[]>
> {
const providers = await getLanguageModelProviders();
// Fetch all models concurrently
const modelPromises = providers
.filter((p) => p.type !== "local")
.map(async (provider) => {
const models = await getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: provider.id }
: { customProviderId: provider.id },
);
return { providerId: provider.id, models };
});
// Wait for all requests to complete
const results = await Promise.all(modelPromises);
// Convert the array of results to a record
const record: Record<string, LanguageModel[]> = {};
for (const result of results) {
record[result.providerId] = result.models;
}
return record;
} }

View File

@@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron";
// Whitelist of valid channels // Whitelist of valid channels
const validInvokeChannels = [ const validInvokeChannels = [
"get-language-models", "get-language-models",
"get-language-models-by-providers",
"create-custom-language-model", "create-custom-language-model",
"get-language-model-providers", "get-language-model-providers",
"delete-custom-language-model-provider", "delete-custom-language-model-provider",