Files
moreminimore-vibe/src/ipc/shared/language_model_helpers.ts
Tanner-Maasen 2ffbbbca8f Add Azure OpenAI Custom Model Integration (#1001)
Fixes #710 

This PR implements comprehensive Azure OpenAI integration for Dyad,
enabling users to leverage Azure
OpenAI models through proper environment variable configuration. The
implementation adds Azure as a
supported provider with full integration into the existing language
model architecture, including support
  for GPT-5 models. Key features include environment-based
configuration using `AZURE_API_KEY` and `AZURE_RESOURCE_NAME`,
specialized UI components that provide clear
setup instructions and status indicators, and seamless integration with
Dyad's existing provider system.
The Azure provider leverages the @ai-sdk/azure package (v1.3.25) for
compatibility with the current
  TypeScript language model interfaces.

The implementation includes robust error handling for missing
configuration, comprehensive test coverage
with 9 new unit tests covering critical functionality like model client
creation and error scenarios, and
  an E2E test for the Azure-specific settings UI. 

<img width="1510" height="908" alt="Screenshot 2025-08-18 at 9 14 32 PM"
src="https://github.com/user-attachments/assets/04aa99e1-1590-4bb0-86c9-a67b97bc7500"
/>

---------

Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
Co-authored-by: Will Chen <willchen90@gmail.com>
2025-08-30 20:47:25 -07:00

494 lines
15 KiB
TypeScript

import { db } from "@/db";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
import { eq } from "drizzle-orm";
export const PROVIDERS_THAT_SUPPORT_THINKING: (keyof typeof MODEL_OPTIONS)[] = [
"google",
"auto",
];
export interface ModelOption {
name: string;
displayName: string;
description: string;
temperature?: number;
tag?: string;
maxOutputTokens?: number;
contextWindow?: number;
}
export const MODEL_OPTIONS: Record<string, ModelOption[]> = {
openai: [
// https://platform.openai.com/docs/models/gpt-5
{
name: "gpt-5",
displayName: "GPT 5",
description: "OpenAI's flagship model",
// Technically it's 128k but OpenAI errors if you set max_tokens instead of max_completion_tokens
maxOutputTokens: undefined,
contextWindow: 400_000,
// Requires temperature to be default value (1)
temperature: 1,
},
// https://platform.openai.com/docs/models/gpt-5-mini
{
name: "gpt-5-mini",
displayName: "GPT 5 Mini",
description: "OpenAI's lightweight, but intelligent model",
// Technically it's 128k but OpenAI errors if you set max_tokens instead of max_completion_tokens
maxOutputTokens: undefined,
contextWindow: 400_000,
// Requires temperature to be default value (1)
temperature: 1,
},
// https://platform.openai.com/docs/models/gpt-5-nano
{
name: "gpt-5-nano",
displayName: "GPT 5 Nano",
description: "Fastest, most cost-efficient version of GPT-5",
// Technically it's 128k but OpenAI errors if you set max_tokens instead of max_completion_tokens
maxOutputTokens: undefined,
contextWindow: 400_000,
// Requires temperature to be default value (1)
temperature: 1,
},
// https://platform.openai.com/docs/models/gpt-4.1
{
name: "gpt-4.1",
displayName: "GPT 4.1",
description: "OpenAI's flagship model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
temperature: 0,
},
// https://platform.openai.com/docs/models/gpt-4.1-mini
{
name: "gpt-4.1-mini",
displayName: "GPT 4.1 Mini",
description: "OpenAI's lightweight, but intelligent model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
temperature: 0,
},
// https://platform.openai.com/docs/models/o3-mini
{
name: "o3-mini",
displayName: "o3 mini",
description: "Reasoning model",
// See o4-mini comment below for why we set this to 32k
maxOutputTokens: 32_000,
contextWindow: 200_000,
temperature: 0,
},
// https://platform.openai.com/docs/models/o4-mini
{
name: "o4-mini",
displayName: "o4 mini",
description: "Reasoning model",
// Technically the max output tokens is 100k, *however* if the user has a lot of input tokens,
// then setting a high max output token will cause the request to fail because
// the max output tokens is *included* in the context window limit.
maxOutputTokens: 32_000,
contextWindow: 200_000,
temperature: 0,
},
],
// https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
anthropic: [
{
name: "claude-sonnet-4-20250514",
displayName: "Claude 4 Sonnet",
description: "Excellent coder",
// See comment below for Claude 3.7 Sonnet for why we set this to 16k
maxOutputTokens: 16_000,
contextWindow: 200_000,
temperature: 0,
},
{
name: "claude-3-7-sonnet-latest",
displayName: "Claude 3.7 Sonnet",
description: "Excellent coder",
// Technically the max output tokens is 64k, *however* if the user has a lot of input tokens,
// then setting a high max output token will cause the request to fail because
// the max output tokens is *included* in the context window limit, see:
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#max-tokens-and-context-window-size-with-extended-thinking
maxOutputTokens: 16_000,
contextWindow: 200_000,
temperature: 0,
},
{
name: "claude-3-5-sonnet-20241022",
displayName: "Claude 3.5 Sonnet",
description: "Good coder, excellent at following instructions",
maxOutputTokens: 8_000,
contextWindow: 200_000,
temperature: 0,
},
{
name: "claude-3-5-haiku-20241022",
displayName: "Claude 3.5 Haiku",
description: "Lightweight coder",
maxOutputTokens: 8_000,
contextWindow: 200_000,
temperature: 0,
},
],
google: [
// https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25
{
name: "gemini-2.5-pro",
displayName: "Gemini 2.5 Pro",
description: "Google's Gemini 2.5 Pro model",
// See Flash 2.5 comment below (go 1 below just to be safe, even though it seems OK now).
maxOutputTokens: 65_536 - 1,
// Gemini context window = input token + output token
contextWindow: 1_048_576,
temperature: 0,
},
// https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview
{
name: "gemini-2.5-flash",
displayName: "Gemini 2.5 Flash",
description: "Google's Gemini 2.5 Flash model (free tier available)",
// Weirdly for Vertex AI, the output token limit is *exclusive* of the stated limit.
maxOutputTokens: 65_536 - 1,
// Gemini context window = input token + output token
contextWindow: 1_048_576,
temperature: 0,
},
],
openrouter: [
{
name: "qwen/qwen3-coder",
displayName: "Qwen3 Coder",
description: "Qwen's best coding model",
maxOutputTokens: 32_000,
contextWindow: 262_000,
temperature: 0,
},
// https://openrouter.ai/deepseek/deepseek-chat-v3-0324:free
{
name: "deepseek/deepseek-chat-v3-0324:free",
displayName: "DeepSeek v3 (free)",
description: "Use for free (data may be used for training)",
maxOutputTokens: 32_000,
contextWindow: 128_000,
temperature: 0,
},
// https://openrouter.ai/moonshotai/kimi-k2
{
name: "moonshotai/kimi-k2",
displayName: "Kimi K2",
description: "Powerful cost-effective model",
maxOutputTokens: 32_000,
contextWindow: 131_000,
temperature: 0,
},
{
name: "deepseek/deepseek-r1-0528",
displayName: "DeepSeek R1",
description: "Good reasoning model with excellent price for performance",
maxOutputTokens: 32_000,
contextWindow: 128_000,
temperature: 0,
},
],
auto: [
{
name: "auto",
displayName: "Auto",
description: "Automatically selects the best model",
tag: "Default",
// These are below Gemini 2.5 Pro & Flash limits
// which are the ones defaulted to for both regular auto
// and smart auto.
maxOutputTokens: 32_000,
contextWindow: 1_000_000,
temperature: 0,
},
],
azure: [
{
name: "gpt-5",
displayName: "GPT-5",
description: "Azure OpenAI GPT-5 model with reasoning capabilities",
maxOutputTokens: 128_000,
contextWindow: 400_000,
temperature: 0,
},
{
name: "gpt-5-mini",
displayName: "GPT-5 Mini",
description: "Azure OpenAI GPT-5 Mini model",
maxOutputTokens: 128_000,
contextWindow: 400_000,
temperature: 0,
},
{
name: "gpt-5-nano",
displayName: "GPT-5 Nano",
description: "Azure OpenAI GPT-5 Nano model",
maxOutputTokens: 128_000,
contextWindow: 400_000,
temperature: 0,
},
{
name: "gpt-5-chat",
displayName: "GPT-5 Chat",
description: "Azure OpenAI GPT-5 Chat model",
maxOutputTokens: 16_384,
contextWindow: 128_000,
temperature: 0,
},
],
};
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
google: "GEMINI_API_KEY",
openrouter: "OPENROUTER_API_KEY",
azure: "AZURE_API_KEY",
};
export const CLOUD_PROVIDERS: Record<
string,
{
displayName: string;
hasFreeTier?: boolean;
websiteUrl?: string;
gatewayPrefix: string;
}
> = {
openai: {
displayName: "OpenAI",
hasFreeTier: false,
websiteUrl: "https://platform.openai.com/api-keys",
gatewayPrefix: "",
},
anthropic: {
displayName: "Anthropic",
hasFreeTier: false,
websiteUrl: "https://console.anthropic.com/settings/keys",
gatewayPrefix: "anthropic/",
},
google: {
displayName: "Google",
hasFreeTier: true,
websiteUrl: "https://aistudio.google.com/app/apikey",
gatewayPrefix: "gemini/",
},
openrouter: {
displayName: "OpenRouter",
hasFreeTier: true,
websiteUrl: "https://openrouter.ai/settings/keys",
gatewayPrefix: "openrouter/",
},
auto: {
displayName: "Dyad",
websiteUrl: "https://academy.dyad.sh/settings",
gatewayPrefix: "dyad/",
},
azure: {
displayName: "Azure OpenAI",
hasFreeTier: false,
websiteUrl: "https://portal.azure.com/",
gatewayPrefix: "",
},
};
const LOCAL_PROVIDERS: Record<
string,
{
displayName: string;
hasFreeTier: boolean;
}
> = {
ollama: {
displayName: "Ollama",
hasFreeTier: true,
},
lmstudio: {
displayName: "LM Studio",
hasFreeTier: true,
},
};
/**
* Fetches language model providers from both the database (custom) and hardcoded constants (cloud),
* merging them with custom providers taking precedence.
* @returns A promise that resolves to an array of LanguageModelProvider objects.
*/
export async function getLanguageModelProviders(): Promise<
LanguageModelProvider[]
> {
// Fetch custom providers from the database
const customProvidersDb = await db
.select()
.from(languageModelProvidersSchema);
const customProvidersMap = new Map<string, LanguageModelProvider>();
for (const cp of customProvidersDb) {
customProvidersMap.set(cp.id, {
id: cp.id,
name: cp.name,
apiBaseUrl: cp.api_base_url,
envVarName: cp.env_var_name ?? undefined,
type: "custom",
// hasFreeTier, websiteUrl, gatewayPrefix are not in the custom DB schema
// They will be undefined unless overridden by hardcoded values if IDs match
});
}
// Get hardcoded cloud providers
const hardcodedProviders: LanguageModelProvider[] = [];
for (const providerKey in CLOUD_PROVIDERS) {
if (Object.prototype.hasOwnProperty.call(CLOUD_PROVIDERS, providerKey)) {
// Ensure providerKey is a key of PROVIDERS
const key = providerKey as keyof typeof CLOUD_PROVIDERS;
const providerDetails = CLOUD_PROVIDERS[key];
if (providerDetails) {
// Ensure providerDetails is not undefined
hardcodedProviders.push({
id: key,
name: providerDetails.displayName,
hasFreeTier: providerDetails.hasFreeTier,
websiteUrl: providerDetails.websiteUrl,
gatewayPrefix: providerDetails.gatewayPrefix,
envVarName: PROVIDER_TO_ENV_VAR[key] ?? undefined,
type: "cloud",
// apiBaseUrl is not directly in PROVIDERS
});
}
}
}
for (const providerKey in LOCAL_PROVIDERS) {
if (Object.prototype.hasOwnProperty.call(LOCAL_PROVIDERS, providerKey)) {
const key = providerKey as keyof typeof LOCAL_PROVIDERS;
const providerDetails = LOCAL_PROVIDERS[key];
hardcodedProviders.push({
id: key,
name: providerDetails.displayName,
hasFreeTier: providerDetails.hasFreeTier,
type: "local",
});
}
}
return [...hardcodedProviders, ...customProvidersMap.values()];
}
/**
* Fetches language models for a specific provider.
* @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects.
*/
export async function getLanguageModels({
providerId,
}: {
providerId: string;
}): Promise<LanguageModel[]> {
const allProviders = await getLanguageModelProviders();
const provider = allProviders.find((p) => p.id === providerId);
if (!provider) {
console.warn(`Provider with ID "${providerId}" not found.`);
return [];
}
// Get custom models from DB for all provider types
let customModels: LanguageModel[] = [];
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
displayName: languageModelsSchema.displayName,
apiName: languageModelsSchema.apiName,
description: languageModelsSchema.description,
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(
isCustomProvider({ providerId })
? eq(languageModelsSchema.customProviderId, providerId)
: eq(languageModelsSchema.builtinProviderId, providerId),
);
customModels = customModelsDb.map((model) => ({
...model,
description: model.description ?? "",
tag: undefined,
maxOutputTokens: model.maxOutputTokens ?? undefined,
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${providerId}" from DB:`,
error,
);
// Continue with empty custom models array
}
// If it's a cloud provider, also get the hardcoded models
let hardcodedModels: LanguageModel[] = [];
if (provider.type === "cloud") {
if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId] || [];
hardcodedModels = models.map((model) => ({
...model,
apiName: model.name,
type: "cloud",
}));
} else {
console.warn(
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
);
}
}
return [...hardcodedModels, ...customModels];
}
/**
* Fetches all language models grouped by their provider IDs.
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
*/
export async function getLanguageModelsByProviders(): Promise<
Record<string, LanguageModel[]>
> {
const providers = await getLanguageModelProviders();
// Fetch all models concurrently
const modelPromises = providers
.filter((p) => p.type !== "local")
.map(async (provider) => {
const models = await getLanguageModels({ providerId: provider.id });
return { providerId: provider.id, models };
});
// Wait for all requests to complete
const results = await Promise.all(modelPromises);
// Convert the array of results to a record
const record: Record<string, LanguageModel[]> = {};
for (const result of results) {
record[result.providerId] = result.models;
}
return record;
}
export function isCustomProvider({ providerId }: { providerId: string }) {
return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}
export const CUSTOM_PROVIDER_PREFIX = "custom::";