Fix DB schemas (#138)
This commit is contained in:
@@ -17,9 +17,9 @@ import {
|
||||
DropdownMenuSubContent,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { useEffect, useState } from "react";
|
||||
import { MODEL_OPTIONS } from "@/constants/models";
|
||||
import { useLocalModels } from "@/hooks/useLocalModels";
|
||||
import { useLocalLMSModels } from "@/hooks/useLMStudioModels";
|
||||
import { useLanguageModelsByProviders } from "@/hooks/useLanguageModelsByProviders";
|
||||
import { ChevronDown } from "lucide-react";
|
||||
import { LocalModel } from "@/ipc/ipc_types";
|
||||
interface ModelPickerProps {
|
||||
@@ -33,6 +33,10 @@ export function ModelPicker({
|
||||
}: ModelPickerProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
// Cloud models from providers
|
||||
const { data: modelsByProviders, isLoading: providersLoading } =
|
||||
useLanguageModelsByProviders();
|
||||
|
||||
// Ollama Models Hook
|
||||
const {
|
||||
models: ollamaModels,
|
||||
@@ -74,24 +78,35 @@ export function ModelPicker({
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback for cloud models
|
||||
return (
|
||||
MODEL_OPTIONS[selectedModel.provider]?.find(
|
||||
(model) => model.name === selectedModel.name,
|
||||
)?.displayName || selectedModel.name
|
||||
);
|
||||
// For cloud models, look up in the modelsByProviders data
|
||||
if (modelsByProviders && modelsByProviders[selectedModel.provider]) {
|
||||
const foundModel = modelsByProviders[selectedModel.provider].find(
|
||||
(model) => model.apiName === selectedModel.name,
|
||||
);
|
||||
if (foundModel) {
|
||||
return foundModel.displayName;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback if not found
|
||||
return selectedModel.name;
|
||||
};
|
||||
|
||||
const modelDisplayName = getModelDisplayName();
|
||||
|
||||
// Flatten the cloud model options
|
||||
const cloudModels = Object.entries(MODEL_OPTIONS).flatMap(
|
||||
([provider, models]) =>
|
||||
models.map((model) => ({
|
||||
...model,
|
||||
provider: provider as ModelProvider,
|
||||
})),
|
||||
);
|
||||
// Flatten the cloud models from all providers
|
||||
const cloudModels =
|
||||
!providersLoading && modelsByProviders
|
||||
? Object.entries(modelsByProviders).flatMap(([providerId, models]) =>
|
||||
models.map((model) => ({
|
||||
name: model.apiName,
|
||||
displayName: model.displayName,
|
||||
description: model.description || "",
|
||||
tag: model.tag,
|
||||
provider: providerId as ModelProvider,
|
||||
})),
|
||||
)
|
||||
: [];
|
||||
|
||||
// Determine availability of local models
|
||||
const hasOllamaModels =
|
||||
@@ -119,43 +134,54 @@ export function ModelPicker({
|
||||
{/* Increased width slightly */}
|
||||
<DropdownMenuLabel>Cloud Models</DropdownMenuLabel>
|
||||
<DropdownMenuSeparator />
|
||||
{/* Cloud models */}
|
||||
{cloudModels.map((model) => (
|
||||
<Tooltip key={`${model.provider}-${model.name}`}>
|
||||
<TooltipTrigger asChild>
|
||||
<DropdownMenuItem
|
||||
className={
|
||||
selectedModel.provider === model.provider &&
|
||||
selectedModel.name === model.name
|
||||
? "bg-secondary"
|
||||
: ""
|
||||
}
|
||||
onClick={() => {
|
||||
onModelSelect({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
});
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
<div className="flex justify-between items-start w-full">
|
||||
<span className="flex flex-col items-start">
|
||||
<span>{model.displayName}</span>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{model.provider}
|
||||
{/* Cloud models - loading state */}
|
||||
{providersLoading ? (
|
||||
<div className="text-xs text-center py-2 text-muted-foreground">
|
||||
Loading models...
|
||||
</div>
|
||||
) : cloudModels.length === 0 ? (
|
||||
<div className="text-xs text-center py-2 text-muted-foreground">
|
||||
No cloud models available
|
||||
</div>
|
||||
) : (
|
||||
/* Cloud models loaded */
|
||||
cloudModels.map((model) => (
|
||||
<Tooltip key={`${model.provider}-${model.name}`}>
|
||||
<TooltipTrigger asChild>
|
||||
<DropdownMenuItem
|
||||
className={
|
||||
selectedModel.provider === model.provider &&
|
||||
selectedModel.name === model.name
|
||||
? "bg-secondary"
|
||||
: ""
|
||||
}
|
||||
onClick={() => {
|
||||
onModelSelect({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
});
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
<div className="flex justify-between items-start w-full">
|
||||
<span className="flex flex-col items-start">
|
||||
<span>{model.displayName}</span>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{model.provider}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
{model.tag && (
|
||||
<span className="text-[10px] bg-primary/10 text-primary px-1.5 py-0.5 rounded-full font-medium">
|
||||
{model.tag}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">{model.description}</TooltipContent>
|
||||
</Tooltip>
|
||||
))}
|
||||
{model.tag && (
|
||||
<span className="text-[10px] bg-primary/10 text-primary px-1.5 py-0.5 rounded-full font-medium">
|
||||
{model.tag}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">{model.description}</TooltipContent>
|
||||
</Tooltip>
|
||||
))
|
||||
)}
|
||||
<DropdownMenuSeparator />
|
||||
{/* Local Models Parent SubMenu */}
|
||||
<DropdownMenuSub>
|
||||
|
||||
@@ -85,9 +85,11 @@ export const language_models = sqliteTable("language_models", {
|
||||
id: integer("id").primaryKey({ autoIncrement: true }),
|
||||
displayName: text("display_name").notNull(),
|
||||
apiName: text("api_name").notNull(),
|
||||
provider_id: text("provider_id")
|
||||
.notNull()
|
||||
.references(() => language_model_providers.id, { onDelete: "cascade" }),
|
||||
builtinProviderId: text("builtin_provider_id"),
|
||||
customProviderId: text("custom_provider_id").references(
|
||||
() => language_model_providers.id,
|
||||
{ onDelete: "cascade" },
|
||||
),
|
||||
description: text("description"),
|
||||
max_output_tokens: integer("max_output_tokens"),
|
||||
context_window: integer("context_window"),
|
||||
@@ -111,7 +113,7 @@ export const languageModelsRelations = relations(
|
||||
language_models,
|
||||
({ one }) => ({
|
||||
provider: one(language_model_providers, {
|
||||
fields: [language_models.provider_id],
|
||||
fields: [language_models.customProviderId],
|
||||
references: [language_model_providers.id],
|
||||
}),
|
||||
}),
|
||||
|
||||
19
src/hooks/useLanguageModelsByProviders.ts
Normal file
19
src/hooks/useLanguageModelsByProviders.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import type { LanguageModel } from "@/ipc/ipc_types";
|
||||
|
||||
/**
|
||||
* Fetches all available language models grouped by their provider IDs.
|
||||
*
|
||||
* @returns TanStack Query result object for the language models organized by provider.
|
||||
*/
|
||||
export function useLanguageModelsByProviders() {
|
||||
const ipcClient = IpcClient.getInstance();
|
||||
|
||||
return useQuery<Record<string, LanguageModel[]>, Error>({
|
||||
queryKey: ["language-models-by-providers"],
|
||||
queryFn: async () => {
|
||||
return ipcClient.getLanguageModelsByProviders();
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import log from "electron-log";
|
||||
import {
|
||||
getLanguageModelProviders,
|
||||
getLanguageModels,
|
||||
getLanguageModelsByProviders,
|
||||
} from "../shared/language_model_helpers";
|
||||
import { db } from "@/db";
|
||||
import {
|
||||
@@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() {
|
||||
|
||||
// Insert the new provider
|
||||
await db.insert(languageModelProvidersSchema).values({
|
||||
id,
|
||||
// Make sure we will never have accidental collisions with builtin providers
|
||||
id: "custom::" + id,
|
||||
name,
|
||||
api_base_url: apiBaseUrl,
|
||||
env_var_name: envVarName || null,
|
||||
@@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() {
|
||||
}
|
||||
|
||||
// Check if provider exists
|
||||
const provider = db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema)
|
||||
.where(eq(languageModelProvidersSchema.id, providerId))
|
||||
.get();
|
||||
|
||||
const providers = await getLanguageModelProviders();
|
||||
const provider = providers.find((p) => p.id === providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Provider with ID "${providerId}" not found`);
|
||||
}
|
||||
@@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() {
|
||||
await db.insert(languageModelsSchema).values({
|
||||
displayName,
|
||||
apiName,
|
||||
provider_id: providerId,
|
||||
builtinProviderId: provider.type === "cloud" ? providerId : undefined,
|
||||
customProviderId: provider.type === "custom" ? providerId : undefined,
|
||||
description: description || null,
|
||||
max_output_tokens: maxOutputTokens || null,
|
||||
context_window: contextWindow || null,
|
||||
@@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() {
|
||||
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
|
||||
);
|
||||
|
||||
const providers = await getLanguageModelProviders();
|
||||
const provider = providers.find((p) => p.id === providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Provider with ID "${providerId}" not found`);
|
||||
}
|
||||
if (provider.type === "local") {
|
||||
throw new Error("Local models cannot be deleted");
|
||||
}
|
||||
const result = db
|
||||
.delete(language_models)
|
||||
.where(
|
||||
and(
|
||||
eq(language_models.provider_id, providerId),
|
||||
provider.type === "cloud"
|
||||
? eq(language_models.builtinProviderId, providerId)
|
||||
: eq(language_models.customProviderId, providerId),
|
||||
|
||||
eq(language_models.apiName, modelApiName),
|
||||
),
|
||||
)
|
||||
@@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() {
|
||||
// 1. Delete associated models
|
||||
const deleteModelsResult = await tx
|
||||
.delete(languageModelsSchema)
|
||||
.where(eq(languageModelsSchema.provider_id, providerId))
|
||||
.where(eq(languageModelsSchema.customProviderId, providerId))
|
||||
.run();
|
||||
logger.info(
|
||||
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
|
||||
@@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() {
|
||||
if (!params || typeof params.providerId !== "string") {
|
||||
throw new Error("Invalid parameters: providerId (string) is required.");
|
||||
}
|
||||
return getLanguageModels({ providerId: params.providerId });
|
||||
const providers = await getLanguageModelProviders();
|
||||
const provider = providers.find((p) => p.id === params.providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Provider with ID "${params.providerId}" not found`);
|
||||
}
|
||||
if (provider.type === "local") {
|
||||
throw new Error("Local models cannot be fetched");
|
||||
}
|
||||
return getLanguageModels(
|
||||
provider.type === "cloud"
|
||||
? { builtinProviderId: params.providerId }
|
||||
: { customProviderId: params.providerId },
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"get-language-models-by-providers",
|
||||
async (): Promise<Record<string, LanguageModel[]>> => {
|
||||
return getLanguageModelsByProviders();
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -746,6 +746,12 @@ export class IpcClient {
|
||||
return this.ipcRenderer.invoke("get-language-models", params);
|
||||
}
|
||||
|
||||
public async getLanguageModelsByProviders(): Promise<
|
||||
Record<string, LanguageModel[]>
|
||||
> {
|
||||
return this.ipcRenderer.invoke("get-language-models-by-providers");
|
||||
}
|
||||
|
||||
public async createCustomLanguageModelProvider({
|
||||
id,
|
||||
name,
|
||||
|
||||
@@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise<
|
||||
* @param obj An object containing the providerId.
|
||||
* @returns A promise that resolves to an array of LanguageModel objects.
|
||||
*/
|
||||
export async function getLanguageModels(obj: {
|
||||
providerId: string;
|
||||
}): Promise<LanguageModel[]> {
|
||||
const { providerId } = obj;
|
||||
export async function getLanguageModels(
|
||||
obj:
|
||||
| {
|
||||
customProviderId: string;
|
||||
// builtinProviderId?: undefined;
|
||||
}
|
||||
| {
|
||||
builtinProviderId: string;
|
||||
// customProviderId?: undefined;
|
||||
},
|
||||
): Promise<LanguageModel[]> {
|
||||
const allProviders = await getLanguageModelProviders();
|
||||
const provider = allProviders.find((p) => p.id === providerId);
|
||||
const provider = allProviders.find(
|
||||
(p) =>
|
||||
p.id === (obj as { customProviderId: string }).customProviderId ||
|
||||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
|
||||
);
|
||||
|
||||
if (!provider) {
|
||||
console.warn(`Provider with ID "${providerId}" not found.`);
|
||||
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
|
||||
return [];
|
||||
}
|
||||
|
||||
// Get custom models from DB for all provider types
|
||||
let customModels: LanguageModel[] = [];
|
||||
|
||||
try {
|
||||
const customModelsDb = await db
|
||||
.select({
|
||||
id: languageModelsSchema.id,
|
||||
displayName: languageModelsSchema.displayName,
|
||||
apiName: languageModelsSchema.apiName,
|
||||
description: languageModelsSchema.description,
|
||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||
contextWindow: languageModelsSchema.context_window,
|
||||
})
|
||||
.from(languageModelsSchema)
|
||||
.where(
|
||||
"customProviderId" in obj
|
||||
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
|
||||
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
|
||||
);
|
||||
|
||||
customModels = customModelsDb.map((model) => ({
|
||||
...model,
|
||||
description: model.description ?? "",
|
||||
tag: undefined,
|
||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||
contextWindow: model.contextWindow ?? undefined,
|
||||
type: "custom",
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
|
||||
error,
|
||||
);
|
||||
// Continue with empty custom models array
|
||||
}
|
||||
|
||||
// If it's a cloud provider, also get the hardcoded models
|
||||
let hardcodedModels: LanguageModel[] = [];
|
||||
const providerId = provider.id;
|
||||
if (provider.type === "cloud") {
|
||||
// Check if providerId is a valid key for MODEL_OPTIONS
|
||||
if (providerId in MODEL_OPTIONS) {
|
||||
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
||||
return models.map((model) => ({
|
||||
hardcodedModels = models.map((model) => ({
|
||||
...model,
|
||||
apiName: model.name,
|
||||
type: "cloud",
|
||||
@@ -164,49 +213,54 @@ export async function getLanguageModels(obj: {
|
||||
console.warn(
|
||||
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
} else if (provider.type === "custom") {
|
||||
// Fetch models from the database for this custom provider
|
||||
// Assuming a language_models table with necessary columns and provider_id foreign key
|
||||
try {
|
||||
const customModelsDb = await db
|
||||
.select({
|
||||
id: languageModelsSchema.id,
|
||||
// Map DB columns to LanguageModel fields
|
||||
displayName: languageModelsSchema.displayName,
|
||||
apiName: languageModelsSchema.apiName,
|
||||
// No display_name in DB, use name instead
|
||||
description: languageModelsSchema.description,
|
||||
// No tag in DB
|
||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||
contextWindow: languageModelsSchema.context_window,
|
||||
})
|
||||
.from(languageModelsSchema)
|
||||
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
|
||||
|
||||
return customModelsDb.map((model) => ({
|
||||
...model,
|
||||
// Ensure possibly null fields are handled, provide defaults or undefined if needed
|
||||
description: model.description ?? "",
|
||||
tag: undefined, // No tag for custom models from DB
|
||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||
contextWindow: model.contextWindow ?? undefined,
|
||||
type: "custom",
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error fetching custom models for provider "${providerId}" from DB:`,
|
||||
error,
|
||||
);
|
||||
// Depending on desired behavior, could throw, return empty, or return a specific error state
|
||||
return [];
|
||||
}
|
||||
} else {
|
||||
// Handle other types like "local" if necessary, currently ignored
|
||||
console.warn(
|
||||
`Provider type "${provider.type}" not handled for model fetching.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
|
||||
// Merge the models, with custom models taking precedence over hardcoded ones
|
||||
const mergedModelsMap = new Map<string, LanguageModel>();
|
||||
|
||||
// Add hardcoded models first
|
||||
for (const model of hardcodedModels) {
|
||||
mergedModelsMap.set(model.apiName, model);
|
||||
}
|
||||
|
||||
// Then override with custom models
|
||||
for (const model of customModels) {
|
||||
mergedModelsMap.set(model.apiName, model);
|
||||
}
|
||||
|
||||
return Array.from(mergedModelsMap.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches all language models grouped by their provider IDs.
|
||||
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
|
||||
*/
|
||||
export async function getLanguageModelsByProviders(): Promise<
|
||||
Record<string, LanguageModel[]>
|
||||
> {
|
||||
const providers = await getLanguageModelProviders();
|
||||
|
||||
// Fetch all models concurrently
|
||||
const modelPromises = providers
|
||||
.filter((p) => p.type !== "local")
|
||||
.map(async (provider) => {
|
||||
const models = await getLanguageModels(
|
||||
provider.type === "cloud"
|
||||
? { builtinProviderId: provider.id }
|
||||
: { customProviderId: provider.id },
|
||||
);
|
||||
return { providerId: provider.id, models };
|
||||
});
|
||||
|
||||
// Wait for all requests to complete
|
||||
const results = await Promise.all(modelPromises);
|
||||
|
||||
// Convert the array of results to a record
|
||||
const record: Record<string, LanguageModel[]> = {};
|
||||
for (const result of results) {
|
||||
record[result.providerId] = result.models;
|
||||
}
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron";
|
||||
// Whitelist of valid channels
|
||||
const validInvokeChannels = [
|
||||
"get-language-models",
|
||||
"get-language-models-by-providers",
|
||||
"create-custom-language-model",
|
||||
"get-language-model-providers",
|
||||
"delete-custom-language-model-provider",
|
||||
|
||||
Reference in New Issue
Block a user