"Extract_text_generation_utilities_into_modular_structure"

This commit is contained in:
ajaysi
2026-03-12 16:59:45 +05:30
parent 54396b8268
commit 1829f47893
6 changed files with 716 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
"""Main Text Generation Service for ALwrity Backend.
This service provides the main LLM text generation functionality,
migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.py
This is a clean version that imports from modular components to avoid merge conflicts.
"""
import os
import json
from typing import Optional, Dict, Any, List
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
# Import all functionality from our modular textgen_utils package
from .textgen_utils import (
llm_text_gen,
check_gpt_provider,
get_api_key,
_normalize_provider,
_parse_csv_env,
_resolve_provider_sequence,
_map_logical_model_to_provider_model,
_resolve_model_sequence,
)
# Re-export all the main functions for backward compatibility
__all__ = [
"llm_text_gen",
"check_gpt_provider",
"get_api_key",
"_normalize_provider",
"_parse_csv_env",
"_resolve_provider_sequence",
"_map_logical_model_to_provider_model",
"_resolve_model_sequence",
]
# Maintain any additional constants or configurations that might be needed
PREMIUM_HF_MINIMAL_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
]
# Legacy compatibility - any imports that other modules might expect
from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from .tenant_provider_config import tenant_provider_config_resolver
from .routing_policy import (
PREMIUM_DEFAULT_MODEL,
SIF_LOW_COST_MODEL_DEFAULTS,
resolve_text_provider_alias,
)

View File

@@ -0,0 +1,22 @@
"""
Text Generation Utilities Package
This package contains modular components extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
from .llm_text_generator import llm_text_gen
from .provider_utils import check_gpt_provider, _normalize_provider, _parse_csv_env, _resolve_provider_sequence
from .model_utils import _map_logical_model_to_provider_model, _resolve_model_sequence
from .api_key_utils import get_api_key
__all__ = [
"llm_text_gen",
"check_gpt_provider",
"get_api_key",
"_normalize_provider",
"_parse_csv_env",
"_resolve_provider_sequence",
"_map_logical_model_to_provider_model",
"_resolve_model_sequence",
]

View File

@@ -0,0 +1,26 @@
"""
API Key Utilities Module
This module contains API key-related utility functions extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
from typing import Optional
from loguru import logger
from ..tenant_provider_config import tenant_provider_config_resolver
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider."""
try:
provider_mapping = {
"google": "gemini",
"huggingface": "huggingface"
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id)
return key
except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None

View File

@@ -0,0 +1,464 @@
"""LLM Text Generator Module
This module contains the main text generation logic extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
import os
import json
from typing import Optional, Dict, Any, List
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from ..gemini_provider import gemini_text_response, gemini_structured_json_response
from ..huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from ..tenant_provider_config import tenant_provider_config_resolver
from ..routing_policy import (
PREMIUM_DEFAULT_MODEL,
SIF_LOW_COST_MODEL_DEFAULTS,
resolve_text_provider_alias,
)
PREMIUM_HF_MINIMAL_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
]
def llm_text_gen(
prompt: str,
system_prompt: Optional[str] = None,
json_struct: Optional[Dict[str, Any]] = None,
user_id: str = None,
preferred_hf_models: Optional[List[str]] = None,
preferred_provider: Optional[str] = None,
flow_type: str = "default",
) -> str:
"""
Generate text using Language Model (LLM) based on the provided prompt.
Args:
prompt (str): The prompt to generate text from.
system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses.
user_id (str): Clerk user ID for subscription checking (required).
preferred_hf_models (list, optional): Preferred HuggingFace models to use.
preferred_provider (str, optional): Preferred provider to use.
flow_type (str): Type of flow for logging and routing.
Returns:
str: Generated text based on the prompt.
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
HTTPException: For subscription limit errors (429 status).
"""
try:
resolved_flow_type = flow_type or ("sif_agent" if preferred_hf_models else "premium_tool")
flow_tag = f"flow_type={resolved_flow_type}"
subscription_preflight_completed = False
logger.info(f"[llm_text_gen][{flow_tag}] Starting text generation")
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
# Set default values for LLM parameters
gpt_provider = "google"
model = "gemini-2.0-flash-001"
hf_low_cost_default_model = SIF_LOW_COST_MODEL_DEFAULTS[0]
temperature = 0.7
max_tokens = 4000
top_p = 0.9
n = 1
# Resolve provider configuration using tenant-aware resolver
try:
provider_cfg = tenant_provider_config_resolver.resolve(
modality="text",
user_id=user_id,
explicit_provider=preferred_provider
)
if provider_cfg.selected_providers:
gpt_provider = provider_cfg.selected_providers[0]
if provider_cfg.model_policy.get("default_model"):
model = provider_cfg.model_policy["default_model"]
logger.info(f"[llm_text_gen] Resolved provider: {gpt_provider}, model: {model}")
except Exception as config_error:
logger.warning(f"[llm_text_gen] Provider config resolution failed: {config_error}")
# Continue with defaults
# Handle preferred HF models for SIF flows
hf_fallback_models: Optional[List[str]] = None
hf_allow_model_variant_fallback = True
if gpt_provider == "huggingface":
if preferred_hf_models is not None:
if preferred_hf_models:
model = preferred_hf_models[0]
hf_fallback_models = preferred_hf_models[1:]
logger.info(f"[llm_text_gen] Using caller-provided HF policy starting model: {model}")
else:
# Explicit empty policy: only requested model (plus optional variant handling).
hf_fallback_models = []
logger.info("[llm_text_gen] Using caller-provided HF policy with no fallback models")
else:
# Premium/default path: minimal safe fallback chain to avoid excessive model hopping.
hf_fallback_models = PREMIUM_HF_MINIMAL_FALLBACK_MODELS
# Default blog characteristics
blog_tone = "Professional"
blog_demographic = "Professional"
blog_type = "Informational"
blog_language = "English"
blog_output_format = "markdown"
blog_length = 2000
# Check available providers
available_providers = []
for provider in ("google", "huggingface"):
if get_api_key(provider, user_id=user_id):
available_providers.append(provider)
if gpt_provider not in available_providers:
logger.warning(f"[llm_text_gen] Provider {gpt_provider} unavailable for user {user_id}, falling back.")
if available_providers:
gpt_provider = available_providers[0]
else:
logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured for tenant or environment defaults.")
# Ensure downstream provider clients receive resolved key
resolved_key = get_api_key(gpt_provider, user_id=user_id)
if gpt_provider == "google" and resolved_key:
os.environ["GEMINI_API_KEY"] = resolved_key
os.environ.setdefault("GOOGLE_API_KEY", resolved_key)
elif gpt_provider == "huggingface" and resolved_key:
os.environ["HF_TOKEN"] = resolved_key
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
# Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider
provider_enum = None
actual_provider_name = None
if gpt_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
elif gpt_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
if not provider_enum:
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
# SUBSCRIPTION CHECK - Required and strict enforcement
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
from services.database import get_session_for_user
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import UsageSummary
logger.info(
f"[llm_text_gen][{flow_tag}] Starting subscription preflight for user={user_id}, "
f"provider={actual_provider_name}, model={model}"
)
db = get_session_for_user(user_id)
if not db:
logger.error(f"[llm_text_gen] Could not get database session for user {user_id}")
raise RuntimeError("Database connection failed")
try:
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens)
input_tokens = int(len(prompt.split()) * 1.3)
# Worst-case estimate: assume maximum possible output tokens
if max_tokens:
estimated_output_tokens = max_tokens
else:
# If max_tokens not specified, use conservative estimate (input * 1.5)
estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens
logger.info(
"[llm_text_gen][subscription_preflight] start | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider_enum,
tokens_requested=estimated_total_tokens,
actual_provider_name=actual_provider_name
)
subscription_preflight_completed = True
logger.info(
f"[llm_text_gen][{flow_tag}] Subscription preflight complete: can_proceed={can_proceed}, "
f"estimated_tokens={estimated_total_tokens}, provider={actual_provider_name}"
)
if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
# Raise HTTPException(429) with usage info so frontend can display subscription modal
error_detail = {
'error': message,
'message': message,
'provider': actual_provider_name or provider_enum.value,
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
logger.info(
"[llm_text_gen][subscription_preflight] pass | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
finally:
db.close()
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
raise
except RuntimeError:
# Re-raise subscription limit errors
raise
except Exception as sub_error:
# STRICT: Fail on subscription check errors
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Construct the system prompt if not provided
if system_prompt is None:
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
Your expertise spans various writing styles and formats.
Writing Style Guidelines:
- Tone: {blog_tone}
- Target Audience: {blog_demographic}
- Content Type: {blog_type}
- Language: {blog_language}
- Output Format: {blog_output_format}
- Target Length: {blog_length} words
Please provide responses that are:
- Well-structured and easy to read
- Engaging and informative
- Tailored to the specified tone and audience
- Professional yet accessible
- Optimized for the target content type
"""
else:
system_instructions = system_prompt
# Generate response based on provider
response_text = None
actual_provider_used = gpt_provider
try:
if gpt_provider == "google":
if json_struct:
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
top_p=top_p,
top_k=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
n=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
elif gpt_provider == "huggingface":
if json_struct:
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model=model,
fallback_models=hf_fallback_models,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions,
allow_model_variant_fallback=hf_allow_model_variant_fallback,
)
else:
response_text = huggingface_text_response(
prompt=prompt,
model=model,
fallback_models=hf_fallback_models,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
# TRACK USAGE after successful API call
if response_text:
logger.info(
f"[llm_text_gen][{flow_tag}] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
)
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
duration = 0.5
track_agent_usage_sync(
user_id=user_id,
model_name=model,
prompt=prompt,
response_text=response_text,
duration=duration
)
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
logger.error(
f"[llm_text_gen][{flow_tag}] Provider {gpt_provider} failed: {str(provider_error)} | "
f"subscription_preflight_completed={subscription_preflight_completed} | model={model}"
)
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
fallback_providers = ["google", "huggingface"]
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
if fallback_providers:
fallback_provider = fallback_providers[0] # Only try the first available
try:
logger.info(f"[llm_text_gen][{flow_tag}] Trying SINGLE fallback provider: {fallback_provider}")
actual_provider_used = fallback_provider
# Update provider enum for fallback
if fallback_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
fallback_model = "gemini-2.0-flash-lite"
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = preferred_hf_models[0] if preferred_hf_models else PREMIUM_DEFAULT_MODEL
if fallback_provider == "google":
if json_struct:
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
top_p=top_p,
top_k=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
n=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
elif fallback_provider == "huggingface":
if json_struct:
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model=fallback_model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
)
else:
response_text = huggingface_text_response(
prompt=prompt,
model=fallback_model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
)
# TRACK USAGE after successful fallback call
if response_text:
logger.info(
f"[llm_text_gen][{flow_tag}] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
)
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync(
user_id=user_id,
model_name=fallback_model,
prompt=prompt,
response_text=response_text,
duration=0.5 # Approximate duration
)
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
return response_text
except Exception as fallback_error:
logger.error(f"[llm_text_gen][{flow_tag}] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
# CIRCUIT BREAKER: Stop immediately to prevent expensive API calls
logger.error(f"[llm_text_gen][{flow_tag}] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
raise RuntimeError("All LLM providers failed to generate a response.")
except Exception as e:
logger.error(f"[llm_text_gen][{flow_tag}] Error during text generation: {str(e)}")
raise
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider."""
try:
provider_mapping = {
"google": "gemini",
"huggingface": "huggingface"
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id)
return key
except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None

View File

@@ -0,0 +1,77 @@
"""
Model Utilities Module
This module contains model-related utility functions extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
import os
from typing import Optional, List
from ..routing_policy import PREMIUM_DEFAULT_MODEL
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
"""Map logical model aliases/full names to provider-specific model IDs."""
raw = (model_name or "").strip()
if not raw:
return raw
# Full provider path supplied explicitly; use as-is.
if "/" in raw:
return raw
key = raw.lower()
hf_map = {
"gpt-oss": "openai/gpt-oss-120b:cerebras",
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
}
wavespeed_map = {
"gpt-oss": "openai/gpt-oss-120b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"gpt-oss-20b": "openai/gpt-oss-20b",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
"llama": "meta-llama/Llama-3.1-8B-Instruct",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
}
if provider in {"huggingface", "hf", "hf_response_api"}:
return hf_map.get(key, raw)
if provider == "wavespeed":
return wavespeed_map.get(key, raw)
return raw
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
"""Resolve model sequence for a given provider."""
models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
if provider == "google":
return ["gemini-2.0-flash-001"]
if preferred_hf_models:
return [_map_logical_model_to_provider_model(provider, m) for m in preferred_hf_models if m]
if not models_env:
return [PREMIUM_DEFAULT_MODEL]
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
return resolved or [PREMIUM_DEFAULT_MODEL]
def _parse_csv_env(value: Optional[str]) -> List[str]:
"""Parse CSV environment variable into list of values."""
if not value:
return []
return [v.strip() for v in str(value).split(",") if v.strip()]

View File

@@ -0,0 +1,74 @@
"""
Provider Utilities Module
This module contains provider-related utility functions extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
import os
from typing import Optional, List
from ..routing_policy import resolve_text_provider_alias
def _normalize_provider(provider: Optional[str]) -> Optional[str]:
"""Normalize provider name to canonical form."""
if not provider:
return None
provider_aliases = {
"gemini": "google",
"google": "google",
"hf": "huggingface",
"hf_response_api": "huggingface",
"huggingface": "huggingface",
"wavespeed": "huggingface",
}
value = str(provider).strip().lower()
return provider_aliases.get(value, value)
def _parse_csv_env(value: Optional[str]) -> List[str]:
"""Parse CSV environment variable into list of values."""
if not value:
return []
return [v.strip() for v in str(value).split(",") if v.strip()]
def _resolve_provider_sequence(
preferred_provider: Optional[str],
env_provider_raw: str,
available_providers: List[str],
) -> List[str]:
"""Resolve provider sequence based on preferences and availability."""
configured = _parse_csv_env(preferred_provider) if preferred_provider else _parse_csv_env(env_provider_raw)
normalized = [_normalize_provider(p) for p in configured if _normalize_provider(p)]
if not normalized:
if "google" in available_providers:
return ["google"]
if "huggingface" in available_providers:
return ["huggingface"]
return []
# preserve order and keep only available providers
sequence = []
for provider in normalized:
if provider in available_providers:
sequence.append(provider)
# strict mode for single configured provider: no silent remap
if len(normalized) == 1:
return sequence
# multi-provider mode: append any other available providers as tail only if none configured are available
if not sequence:
return [p for p in ["huggingface", "google"] if p in available_providers]
return sequence
def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported."""
supported_providers = ["google", "huggingface"]
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
return resolved_provider in supported_providers