Add structured podcast research cost_est across backend/frontend

This commit is contained in:
ي
2026-04-19 16:13:46 +05:30
parent bcf62017aa
commit 196ea65af9
6 changed files with 167 additions and 15 deletions

View File

@@ -9,12 +9,16 @@ from typing import Dict, Any, List
from types import SimpleNamespace from types import SimpleNamespace
import json import json
import re import re
from datetime import datetime, timezone
from middleware.auth_middleware import get_current_user from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user from api.story_writer.utils.auth import require_authenticated_user
from services.blog_writer.research.exa_provider import ExaResearchProvider from services.blog_writer.research.exa_provider import ExaResearchProvider
from services.llm_providers.main_text_generation import llm_text_gen from services.llm_providers.main_text_generation import llm_text_gen
from services.podcast_bible_service import PodcastBibleService from services.podcast_bible_service import PodcastBibleService
from services.database import get_db
from services.subscription import PricingService
from models.subscription_models import APIProvider
from loguru import logger from loguru import logger
from ..models import ( from ..models import (
PodcastExaResearchRequest, PodcastExaResearchRequest,
@@ -23,11 +27,101 @@ from ..models import (
PodcastExaConfig, PodcastExaConfig,
PodcastResearchInsight, PodcastResearchInsight,
PodcastResearchOutput, PodcastResearchOutput,
PodcastCostEst,
PodcastCostBreakdownItem,
) )
router = APIRouter() router = APIRouter()
def _estimate_tokens(text: str) -> int:
if not text:
return 0
return max(1, len(text) // 4)
def _get_price_from_catalog(
pricing_service: PricingService,
provider: APIProvider,
model_name: str,
key: str,
fallback: float = 0.0,
) -> float:
try:
pricing = pricing_service.get_pricing_for_provider_model(provider, model_name) or {}
value = pricing.get(key)
return float(value or fallback)
except Exception:
return fallback
def _provider_reported_total(provider_result: Dict[str, Any]) -> float:
    """Return the provider-reported total cost, or 0.0 when absent or malformed.

    Reads ``provider_result["cost"]["total"]``. Anything that is not a finite,
    non-negative number (wrong container type, non-numeric value, NaN, inf)
    is treated as "no measured cost" instead of raising.
    """
    import math

    if not isinstance(provider_result, dict):
        return 0.0
    cost_section = provider_result.get("cost")
    if not isinstance(cost_section, dict):
        return 0.0
    try:
        reported = float(cost_section.get("total") or 0.0)
    except (TypeError, ValueError):
        return 0.0
    if not math.isfinite(reported) or reported < 0:
        return 0.0
    return reported


def _build_research_cost_estimate(
    request: PodcastExaResearchRequest,
    raw_content: str,
    sources_count: int,
    provider_result: Dict[str, Any],
) -> PodcastCostEst:
    """Build a per-phase cost estimate for one podcast research request.

    Unit prices are read from the pricing catalog when available; if loading
    the catalog fails, a warning is logged and the hard-coded defaults are
    used. Token counts are heuristic (see ``_estimate_tokens``).

    Args:
        request: The research request; ``topic`` and ``queries`` are read.
        raw_content: Aggregated research text used to size the "Write" phase.
        sources_count: Number of sources returned; values below 1 count as 1.
        provider_result: Provider response; ``cost.total`` is used when present.

    Returns:
        A ``PodcastCostEst`` with four breakdown items (Analyze, Gather,
        Write, Produce), a total equal to the sum of the rounded items,
        currency "USD" and a UTC timestamp.
    """
    # Fallback defaults mirror current catalog defaults.
    exa_per_request = 0.005
    gemini_in_token = 0.00000015
    gemini_out_token = 0.0000006

    try:
        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            exa_per_request = _get_price_from_catalog(
                pricing_service, APIProvider.EXA, "exa-search", "cost_per_request", exa_per_request
            )
            gemini_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.GEMINI, "gemini-2.5-flash") or {}
            gemini_in_token = float(gemini_pricing.get("cost_per_input_token") or gemini_in_token)
            gemini_out_token = float(gemini_pricing.get("cost_per_output_token") or gemini_out_token)
        finally:
            db.close()
    except Exception as pricing_err:
        logger.warning(f"[Podcast Research] Failed loading pricing catalog; using defaults: {pricing_err}")

    queries = request.queries or []
    # At least one query / source is always assumed so no phase is priced at zero usage.
    query_count = max(1, len(queries))
    source_count = max(1, sources_count)

    # "Analyze": input tokens for the topic plus every query.
    analyze_tokens = _estimate_tokens(request.topic) + sum(_estimate_tokens(q) for q in queries)
    analyze_cost = analyze_tokens * gemini_in_token

    # "Gather": one search request per query.
    gather_cost = query_count * exa_per_request

    # "Write": content + topic + 40 tokens of overhead per query as input;
    # output is estimated at 22% of input with a floor of 500 tokens.
    write_input_tokens = _estimate_tokens(raw_content) + _estimate_tokens(request.topic) + (query_count * 40)
    write_output_tokens = max(500, int(write_input_tokens * 0.22))
    write_cost = (write_input_tokens * gemini_in_token) + (write_output_tokens * gemini_out_token)

    # "Produce" is shaping the final API payload and mapped artifacts.
    produce_tokens = max(120, source_count * 30)
    produce_cost = (produce_tokens * gemini_in_token) + (produce_tokens * 0.5 * gemini_out_token)

    # A malformed provider cost payload must not fail the whole research response.
    provider_total = _provider_reported_total(provider_result)

    # Prefer transparent estimate built from catalog + usage. If provider reports a higher measured value, keep it.
    estimated_total = analyze_cost + gather_cost + write_cost + produce_cost
    scale = (provider_total / estimated_total) if estimated_total > 0 and provider_total > estimated_total else 1.0

    breakdown = [
        PodcastCostBreakdownItem(phase="Analyze", cost=round(analyze_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Gather", cost=round(gather_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Write", cost=round(write_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Produce", cost=round(produce_cost * scale, 6)),
    ]
    # Sum the rounded items so the displayed total always matches the displayed breakdown.
    total = round(sum(item.cost for item in breakdown), 6)

    return PodcastCostEst(
        total=total,
        breakdown=breakdown,
        currency="USD",
        last_updated=datetime.now(timezone.utc),
    )
@router.post("/research/exa", response_model=PodcastExaResearchResponse) @router.post("/research/exa", response_model=PodcastExaResearchResponse)
async def podcast_research_exa( async def podcast_research_exa(
request: PodcastExaResearchRequest, request: PodcastExaResearchRequest,
@@ -302,9 +396,13 @@ QUALITY STANDARDS:
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries, search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
summary=summary, summary=summary,
key_insights=key_insights, key_insights=key_insights,
cost=result.get("cost") if isinstance(result, dict) else None, cost_est=_build_research_cost_estimate(
request=request,
raw_content=raw_content,
sources_count=len(sources_payload),
provider_result=result if isinstance(result, dict) else {},
),
search_type=result.get("search_type") if isinstance(result, dict) else None, search_type=result.get("search_type") if isinstance(result, dict) else None,
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa", provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
content=raw_content, content=raw_content,
) )

View File

@@ -5,7 +5,7 @@ All Pydantic request/response models for podcast endpoints.
""" """
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any, Literal
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -181,12 +181,24 @@ class PodcastResearchOutput(BaseModel):
mapped_angles: List[Dict[str, Any]] = [] # [{"title": str, "why": str, "mapped_fact_ids": []}] mapped_angles: List[Dict[str, Any]] = [] # [{"title": str, "why": str, "mapped_fact_ids": []}]
class PodcastCostBreakdownItem(BaseModel):
    """One phase's share of a podcast research cost estimate."""

    # Phase label; only the four literal values below are accepted.
    phase: Literal["Analyze", "Gather", "Write", "Produce"]
    # Cost attributed to this phase.
    cost: float
class PodcastCostEst(BaseModel):
    """Structured cost estimate: a total plus a per-phase breakdown."""

    # Overall estimated cost.
    total: float
    # Per-phase cost items.
    breakdown: List[PodcastCostBreakdownItem]
    # Currency code; "USD" is the only accepted value and the default.
    currency: Literal["USD"] = "USD"
    # When the estimate was produced.
    last_updated: datetime
class PodcastExaResearchResponse(BaseModel): class PodcastExaResearchResponse(BaseModel):
sources: List[PodcastExaSource] sources: List[PodcastExaSource]
search_queries: List[str] = [] search_queries: List[str] = []
summary: str = "" summary: str = ""
key_insights: List[PodcastResearchInsight] = [] key_insights: List[PodcastResearchInsight] = []
cost: Optional[Dict[str, Any]] = None cost_est: PodcastCostEst
search_type: Optional[str] = None search_type: Optional[str] = None
provider: str = "exa" provider: str = "exa"
content: Optional[str] = None # Raw aggregated content (deprecated) content: Optional[str] = None # Raw aggregated content (deprecated)
@@ -450,4 +462,3 @@ class VoiceCloneResult(BaseModel):
file_size: int file_size: int
task_id: str task_id: str
status: str = "completed" status: str = "completed"

View File

@@ -130,10 +130,10 @@ export const ResearchSummary: React.FC<ResearchSummaryProps> = ({
}} }}
/> />
)} )}
{research.cost !== undefined && ( {research.costEst?.total !== undefined && (
<Chip <Chip
icon={<AttachMoneyIcon sx={{ fontSize: "0.875rem !important" }} />} icon={<AttachMoneyIcon sx={{ fontSize: "0.875rem !important" }} />}
label={`$${research.cost.toFixed(3)}`} label={`$${research.costEst.total.toFixed(3)}`}
size="small" size="small"
sx={{ sx={{
background: alpha("#f59e0b", 0.1), background: alpha("#f59e0b", 0.1),
@@ -356,4 +356,3 @@ export const ResearchSummary: React.FC<ResearchSummaryProps> = ({
</GlassyCard> </GlassyCard>
); );
}; };

View File

@@ -33,6 +33,16 @@ export type ResearchInsight = {
source_indices: number[]; source_indices: number[];
}; };
/**
 * Structured cost estimate attached to a research result:
 * an overall total plus a per-phase breakdown.
 */
export type PodcastCostEst = {
  /** Overall estimated cost. */
  total: number;
  /** Cost attributed to each phase. */
  breakdown: Array<{
    phase: "Analyze" | "Gather" | "Write" | "Produce";
    cost: number;
  }>;
  /** Currency code; only "USD" is allowed. */
  currency: "USD";
  /** Timestamp of the estimate, carried as a string. */
  last_updated: string;
};
export type Research = { export type Research = {
summary: string; summary: string;
keyInsights: ResearchInsight[]; keyInsights: ResearchInsight[];
@@ -45,7 +55,7 @@ export type Research = {
searchQueries?: string[]; searchQueries?: string[];
searchType?: string; searchType?: string;
provider?: string; provider?: string;
cost?: number; costEst?: PodcastCostEst;
sourceCount?: number; sourceCount?: number;
expertQuotes?: { quote: string; source_index: number }[]; expertQuotes?: { quote: string; source_index: number }[];
listenerCta?: string[]; listenerCta?: string[];
@@ -222,4 +232,3 @@ export type TaskStatus = {
created_at?: string; created_at?: string;
updated_at?: string; updated_at?: string;
}; };

View File

@@ -95,6 +95,30 @@ const DEFAULT_STATE: PodcastProjectState = {
const STORAGE_KEY = 'podcast_project_state'; const STORAGE_KEY = 'podcast_project_state';
/**
 * Normalizes the cost information on a stored research payload into `costEst`.
 *
 * Sources are tried in order: an existing `costEst` object, a snake_case
 * `cost_est` object, then a legacy numeric `cost`. When none is present,
 * `costEst` is `undefined`. All other properties are copied through unchanged.
 *
 * @param research - Untrusted research payload (e.g. parsed JSON).
 * @returns The research with a normalized `costEst`, or `null` when the input
 *   is not an object (including `null` and `undefined`).
 */
const normalizeResearchCostEst = (research: unknown): Research | null => {
  // Always return null (never undefined) for a missing payload so the result
  // matches the declared type.
  if (typeof research !== "object" || research === null) return null;
  const record = research as Record<string, unknown>;

  // Non-numeric or non-finite values would otherwise surface as "NaN" in the UI.
  const toFiniteNumber = (value: unknown): number => {
    const parsed = Number(value ?? 0);
    return Number.isFinite(parsed) ? parsed : 0;
  };

  const toCostEst = (raw: unknown): NonNullable<Research["costEst"]> | undefined => {
    if (typeof raw !== "object" || raw === null) return undefined;
    const candidate = raw as Record<string, unknown>;
    const currency =
      typeof candidate.currency === "string" && candidate.currency !== "" ? candidate.currency : "USD";
    const lastUpdated =
      typeof candidate.last_updated === "string" && candidate.last_updated !== ""
        ? candidate.last_updated
        : new Date().toISOString();
    // Shape is validated field by field above; the assertion only bridges
    // the untyped payload to the declared type.
    return {
      total: toFiniteNumber(candidate.total),
      breakdown: Array.isArray(candidate.breakdown) ? candidate.breakdown : [],
      currency,
      last_updated: lastUpdated,
    } as NonNullable<Research["costEst"]>;
  };

  const legacyCost =
    typeof record.cost === "number" && Number.isFinite(record.cost) ? record.cost : undefined;
  const legacyCostEst =
    legacyCost !== undefined
      ? ({
          total: legacyCost,
          breakdown: [],
          currency: "USD",
          last_updated: new Date().toISOString(),
        } as NonNullable<Research["costEst"]>)
      : undefined;

  const costEst = toCostEst(record.costEst) ?? toCostEst(record.cost_est) ?? legacyCostEst;

  return { ...record, costEst } as Research;
};
export const usePodcastProjectState = () => { export const usePodcastProjectState = () => {
const [state, setState] = useState<PodcastProjectState>(() => { const [state, setState] = useState<PodcastProjectState>(() => {
// Initialize from localStorage if available // Initialize from localStorage if available
@@ -107,6 +131,7 @@ export const usePodcastProjectState = () => {
const restoredState: PodcastProjectState = { const restoredState: PodcastProjectState = {
...DEFAULT_STATE, ...DEFAULT_STATE,
...parsed, ...parsed,
research: normalizeResearchCostEst(parsed.research),
selectedQueries: parsed.selectedQueries ? new Set(parsed.selectedQueries) : new Set(), selectedQueries: parsed.selectedQueries ? new Set(parsed.selectedQueries) : new Set(),
renderJobs: parsed.renderJobs || [], renderJobs: parsed.renderJobs || [],
}; };
@@ -401,7 +426,7 @@ export const usePodcastProjectState = () => {
analysis: dbProject.analysis, analysis: dbProject.analysis,
queries: dbProject.queries || [], queries: dbProject.queries || [],
selectedQueries: new Set(dbProject.selected_queries || []), selectedQueries: new Set(dbProject.selected_queries || []),
research: dbProject.research, research: normalizeResearchCostEst(dbProject.research),
rawResearch: dbProject.raw_research, rawResearch: dbProject.raw_research,
estimate: dbProject.estimate, estimate: dbProject.estimate,
scriptData: dbProject.script_data, scriptData: dbProject.script_data,
@@ -454,4 +479,3 @@ export const usePodcastProjectState = () => {
loadProjectFromDb, loadProjectFromDb,
}; };
}; };

View File

@@ -173,7 +173,12 @@ const mapSourcesToFacts = (sources: ExaSource[]): Fact[] => {
type ExaResearchResult = { type ExaResearchResult = {
sources: ExaSource[]; sources: ExaSource[];
search_queries?: string[]; search_queries?: string[];
cost?: { total?: number }; cost_est?: {
total?: number;
breakdown?: { phase: "Analyze" | "Gather" | "Write" | "Produce"; cost: number }[];
currency?: "USD";
last_updated?: string;
};
search_type?: string; search_type?: string;
provider?: string; provider?: string;
content?: string; content?: string;
@@ -212,7 +217,14 @@ const mapExaResearchResponse = (response: any): Research => {
searchQueries: response.search_queries, searchQueries: response.search_queries,
searchType: response.search_type, searchType: response.search_type,
provider: response.provider || "exa", provider: response.provider || "exa",
cost: response.cost?.total, costEst: response.cost_est
? {
total: Number(response.cost_est.total || 0),
breakdown: Array.isArray(response.cost_est.breakdown) ? response.cost_est.breakdown : [],
currency: response.cost_est.currency || "USD",
last_updated: response.cost_est.last_updated || new Date().toISOString(),
}
: undefined,
sourceCount: response.sources?.length || 0, sourceCount: response.sources?.length || 0,
}; };
}; };
@@ -953,4 +965,3 @@ export const podcastApi = {
}; };
export type PodcastApi = typeof podcastApi; export type PodcastApi = typeof podcastApi;