Save local changes (GSC/Bing integrations) before merging PR #354

This commit is contained in:
ajaysi
2026-02-13 13:11:27 +05:30
parent 43e66835ac
commit 08a1f4a1d8
144 changed files with 8310 additions and 2748 deletions

View File

@@ -297,7 +297,7 @@ def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
return _convert(schema)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None, user_id: str = None):
"""
Generate structured JSON response using Google's Gemini Pro model.
@@ -312,6 +312,7 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
top_k (int): Top-k sampling parameter
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
system_prompt (str, optional): System instruction for the model
user_id (str, optional): User ID for usage tracking.
Returns:
dict: Parsed JSON response matching the provided schema
@@ -468,6 +469,25 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
if response.parsed is not None:
logger.info("Using response.parsed for structured output")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(response.parsed)
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return response.parsed
else:
logger.warning("Response.parsed is None, falling back to text parsing")
@@ -500,6 +520,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
parsed_text = json.loads(cleaned_text)
logger.info("Successfully parsed text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=cleaned_text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse text as JSON: {e}")
@@ -521,6 +557,26 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
fixed_json = re.sub(r',\s*]', ']', fixed_json)
parsed_text = json.loads(fixed_json)
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(parsed_text) if parsed_text else ""
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5 # Approximation
)
logger.info(f"✅ Tracked structured JSON usage for user {user_id}")
except Exception as e:
logger.error(f"Failed to track usage: {e}")
logger.info("Successfully parsed cleaned JSON")
return parsed_text
except Exception as fix_error:
@@ -537,6 +593,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
import json
parsed_text = json.loads(part.text)
logger.info("Successfully parsed candidate text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=part.text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse candidate text as JSON: {e}")

View File

@@ -4,6 +4,7 @@ import io
import os
from typing import Optional
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
@@ -14,7 +15,10 @@ logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.
Implements robust error handling and retries for production stability.
"""
SUPPORTED_MODELS = {
"ideogram-v3-turbo": {
@@ -54,6 +58,28 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
list(self.SUPPORTED_MODELS.keys()))
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((RuntimeError, IOError)),
    # Announce a retry only when tenacity has actually scheduled one, so the
    # log never claims "retrying" for a non-retryable error or the last attempt.
    before_sleep=lambda retry_state: logger.warning(
        f"Retrying WaveSpeed API call after failed attempt "
        f"{retry_state.attempt_number}: {retry_state.outcome.exception()}"
    ),
    reraise=True,
)
def _call_api_with_retry(self, method, **kwargs):
    """Execute an API call, retrying transient failures.

    Up to three attempts are made with exponential back-off (2-10 seconds)
    when ``method`` raises ``RuntimeError`` or ``IOError``. Any other
    exception, and the failure of the final attempt, propagates unchanged
    to the caller (``reraise=True``).

    Args:
        method: Callable API method.
        **kwargs: Keyword arguments forwarded to ``method``.

    Returns:
        Whatever ``method`` returns.
    """
    try:
        return method(**kwargs)
    except Exception as e:
        # Every failure is logged; whether it is retried is decided (and
        # logged) by the decorator above, not here.
        logger.warning(f"WaveSpeed API call failed: {e}")
        raise
def _validate_options(self, options: ImageGenerationOptions) -> None:
"""Validate generation options.
@@ -117,7 +143,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
# Call WaveSpeed API (using generic image generation method)
# This will need to be adjusted based on actual WaveSpeed client implementation
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
# Adjust based on actual WaveSpeed API response format
@@ -167,7 +193,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):
@@ -216,7 +242,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):

View File

@@ -107,11 +107,13 @@ def generate_audio(
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
try:
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import UsageSummary, APIProvider
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
@@ -194,7 +196,11 @@ def generate_audio(
if audio_bytes:
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[audio_gen] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -383,12 +389,14 @@ def clone_voice(
voice_clone_cost = 0.5
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -432,7 +440,11 @@ def clone_voice(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[clone_voice] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -570,12 +582,14 @@ def qwen3_voice_clone(
char_count = len(text)
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -615,7 +629,11 @@ def qwen3_voice_clone(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[qwen3_voice_clone] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -691,6 +709,7 @@ def qwen3_voice_clone(
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
├─ Calls: {current_calls_before}{new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
@@ -724,3 +743,373 @@ def qwen3_voice_clone(
},
)
def qwen3_voice_design(
    text: str,
    voice_description: str,
    *,
    language: str = "auto",
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Generate preview audio for a voice described in natural language.

    The call is gated by a subscription pre-flight check
    (``PricingService.check_usage_limits``) and, when audio is produced,
    followed by best-effort usage tracking (``UsageSummary`` counters plus an
    ``APIUsageLog`` row). Tracking failures are logged and never fail the
    request.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        voice_description: Description of the desired voice; stripped.
        language: Language passed through to ``WaveSpeedClient.voice_design``.
        user_id: User ID used for the subscription check and usage tracking.

    Returns:
        VoiceCloneResult carrying the preview audio bytes. ``custom_voice_id``
        is always empty because no persistent voice is created here.

    Raises:
        RuntimeError: If ``user_id`` is missing or the subscription check
            itself fails.
        HTTPException: 429 when usage limits are exceeded; 500 for any other
            failure (including invalid ``text`` / ``voice_description`` and a
            provider call that returns no audio).
    """
    model_name = "wavespeed-ai/qwen3-tts/voice-design"
    endpoint = "/audio-generation/wavespeed/qwen3-tts/voice-design"

    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        if not isinstance(text, str) or not text.strip():
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()

        if not isinstance(voice_description, str) or not voice_description.strip():
            raise ValueError("Voice description is required")
        voice_description = voice_description.strip()

        char_count = len(text)
        # Pricing logic similar to TTS/Clone: 0.005 per 100 characters, with a 0.005 floor.
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check ---------------------------------
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call -------------------------------------------------
        import time

        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.voice_design(
            text=text,
            voice_description=voice_description,
            language=language,
        )
        response_time = time.time() - start_time

        if preview_audio_bytes is None:
            # Fail with a clear message instead of a TypeError from len(None) below.
            raise ValueError("WaveSpeed voice design returned no audio data")

        # --- Usage tracking (non-blocking) ---------------------------------
        # Only track when audio was actually produced, as the sibling
        # voice-clone functions do.
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error(f"[qwen3_voice_design] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")
                try:
                    # Imported here so an import problem only affects tracking.
                    from models.subscription_models import UsageSummary, APIUsageLog
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    # NOTE: read-modify-write on the counters; concurrent
                    # requests for the same user can lose an increment.
                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0

                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name=model_name,
                        endpoint=endpoint,
                    )

                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint=endpoint,
                        method="POST",
                        model_used=model_name,
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        # Sizes are byte counts, matching cosyvoice_voice_clone.
                        request_size=len(text.encode("utf-8")) + len(voice_description.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # flush=True already flushes stdout; no separate flush call is needed.
                    print(f"""
[SUBSCRIPTION] Qwen3 Voice Design
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: {model_name}
├─ Calls: {current_calls_before} → {new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                except Exception as track_error:
                    logger.error(f"[qwen3_voice_design] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[qwen3_voice_design] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model=model_name,
            custom_voice_id="",  # No persistent ID for design usually, unless we save it
            file_size=len(preview_audio_bytes),
        )

    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[qwen3_voice_design] Error designing voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Qwen3 voice design failed",
                "message": str(e),
            },
        )
def cosyvoice_voice_clone(
    audio_bytes: bytes,
    text: str,
    *,
    reference_text: Optional[str] = None,
    audio_mime_type: Optional[str] = None,
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Clone a voice from reference audio and synthesize ``text`` with it.

    The call is gated by a subscription pre-flight check
    (``PricingService.check_usage_limits``) and, when audio is produced,
    followed by best-effort usage tracking (``UsageSummary`` counters plus an
    ``APIUsageLog`` row). Tracking failures are logged and never fail the
    request.

    Args:
        audio_bytes: Reference audio (``bytes`` or ``bytearray``), at most 15MB.
        text: Text to synthesize; stripped, at most 4000 characters.
        reference_text: Optional text passed through to the provider.
        audio_mime_type: MIME type of ``audio_bytes``; defaults to ``audio/wav``.
        user_id: User ID used for the subscription check and usage tracking.

    Returns:
        VoiceCloneResult carrying the preview audio bytes, with an empty
        ``custom_voice_id``.

    Raises:
        RuntimeError: If ``user_id`` is missing or the subscription check
            itself fails.
        HTTPException: 429 when usage limits are exceeded; 500 for any other
            failure (including invalid inputs and a provider call that returns
            no audio).
    """
    model_name = "wavespeed-ai/cosyvoice-tts/voice-clone"
    endpoint = "/audio-generation/wavespeed/cosyvoice-tts/voice-clone"

    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        if not isinstance(audio_bytes, (bytes, bytearray)) or not audio_bytes:
            raise ValueError("Audio is required and cannot be empty")

        if len(audio_bytes) > 15 * 1024 * 1024:
            raise ValueError("Audio file too large. Maximum is 15MB.")

        if not isinstance(text, str) or not text.strip():
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()
        if len(text) > 4000:
            raise ValueError("Text too long. Please keep it under 4000 characters.")

        char_count = len(text)
        # 0.005 per 100 characters, with a 0.005 floor.
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check ---------------------------------
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call -------------------------------------------------
        import time

        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.cosyvoice_voice_clone(
            audio_bytes=bytes(audio_bytes),
            text=text,
            audio_mime_type=audio_mime_type or "audio/wav",
            reference_text=reference_text,
        )
        response_time = time.time() - start_time

        if preview_audio_bytes is None:
            # Fail with a clear message instead of a TypeError from len(None) below.
            raise ValueError("WaveSpeed CosyVoice voice clone returned no audio data")

        # --- Usage tracking (non-blocking) ---------------------------------
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")
                try:
                    # Imported here so an import problem only affects tracking.
                    from models.subscription_models import UsageSummary, APIUsageLog
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    # NOTE: read-modify-write on the counters; concurrent
                    # requests for the same user can lose an increment.
                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0

                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name=model_name,
                        endpoint=endpoint,
                    )

                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint=endpoint,
                        method="POST",
                        model_used=model_name,
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        request_size=len(audio_bytes) + len(text.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # flush=True already flushes stdout; no separate flush call is needed.
                    print(f"""
[SUBSCRIPTION] CosyVoice Voice Clone
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: {model_name}
├─ Calls: {current_calls_before} → {new_calls}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                except Exception as track_error:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[cosyvoice_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model=model_name,
            custom_voice_id="",
            file_size=len(preview_audio_bytes),
        )

    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[cosyvoice_voice_clone] Error cloning voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "CosyVoice voice cloning failed",
                "message": str(e),
            },
        )

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
import os
import io
import base64
import logging
from typing import Optional, Dict, Any
from PIL import Image
@@ -9,6 +11,9 @@ from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from .image_generation.base import ImageEditOptions
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from utils.logger_utils import get_service_logger
try:
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
"HF_IMAGE_EDIT_MODEL",
"Qwen/Qwen-Image-Edit",
"WAVESPEED_IMAGE_EDIT_MODEL",
"qwen-edit-plus",
)
def _select_provider(explicit: Optional[str]) -> str:
    """
    Select the appropriate image editing provider.

    Priority:
    1. Explicitly requested provider (whitespace-trimmed and lower-cased so
       it matches the lowercase provider names used for dispatch)
    2. WaveSpeed, when the WAVESPEED_API_KEY environment variable is set
    3. Hugging Face (fallback)

    Args:
        explicit: Provider name requested by the caller, if any.

    Returns:
        Lower-case provider name.
    """
    if explicit:
        normalized = explicit.strip().lower()
        # A whitespace-only request is treated as "no explicit provider".
        if normalized:
            return normalized

    # Check for WaveSpeed API key first (preferred provider)
    if os.getenv("WAVESPEED_API_KEY"):
        return "wavespeed"

    # Default to huggingface if WaveSpeed is not available
    return "huggingface"
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
"""Get InferenceClient for the specified provider."""
"""Get the client for the specified provider."""
if provider_name == "wavespeed":
return WaveSpeedEditProvider(api_key=api_key)
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
# Use fal-ai provider for fast inference
# Use fal-ai provider for fast inference via HF Inference API
return InferenceClient(provider="fal-ai", api_key=api_key)
raise ValueError(f"Unknown image editing provider: {provider_name}")
@@ -86,6 +106,8 @@ def edit_image(
from fastapi import HTTPException
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
# Note: get_db() is a generator, so we need to use next() to get the session
# and ensure we close it in the finally block
db = next(get_db())
try:
pricing_service = PricingService(db)
@@ -99,6 +121,9 @@ def edit_image(
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
except Exception as e:
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
finally:
db.close()
else:
@@ -119,6 +144,69 @@ def edit_image(
# Get provider client
client = _get_provider_client(provider_name, opts.get("api_key"))
if provider_name == "wavespeed":
# Handle WaveSpeed provider
try:
# Convert inputs to base64 for WaveSpeed
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
mask_b64 = None
if mask_bytes:
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
# Determine operation type based on prompt/mask
operation = "general_edit" # Default
if not prompt and mask_b64:
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
elif prompt and not mask_b64:
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
elif opts.get("operation"):
operation = opts.get("operation")
edit_options = ImageEditOptions(
image_base64=image_b64,
prompt=prompt.strip(),
operation=operation,
mask_base64=mask_b64,
model=model,
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
extra=opts
)
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
result = client.edit(edit_options)
# TRACK USAGE after successful WaveSpeed call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (WaveSpeed default: $0.02)
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model=result.model or model,
operation_type="image-editing",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata=result.metadata,
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return result
except Exception as e:
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
# Hugging Face (Fallback)
# Prepare parameters for image-to-image
params: Dict[str, Any] = {}
if opts.get("guidance_scale") is not None:
@@ -170,6 +258,29 @@ def edit_image(
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
# TRACK USAGE after successful HF call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (HF/Fal-ai default: $0.05)
estimated_cost = 0.05
_track_image_operation_usage(
user_id=user_id,
provider="huggingface",
model=model,
operation_type="image-editing",
result_bytes=edited_image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata={"provider": "fal-ai"},
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return ImageGenerationResult(
image_bytes=edited_image_bytes,
width=edited_image.width,

View File

@@ -5,6 +5,7 @@ import sys
import base64
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from .image_generation import (
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str:
if explicit:
return explicit
# User requested WaveSpeed as default provider
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Fallback to huggingface to enable a path if configured
return "huggingface"
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
}
except Exception as e:
logger.error(f"Error in generate_image_with_provider: {e}")
# Propagate specific error message if available
error_detail = str(e)
if "402" in error_detail or "Payment Required" in error_detail:
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
return {
"success": False,
"error": str(e)
"error": error_detail
}
import time
from services.database import get_session_for_user
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
def _build_prompt_enhancement_context(user_id: str) -> str:
    """Collect onboarding website/competitor insights as prompt context text.

    Blocking (runs database queries). Returns an empty string when no session
    or onboarding data is available. Failures are logged as warnings and
    whatever context was gathered before the failure is still returned.
    """
    parts = []
    try:
        db_session = get_session_for_user(user_id)
        if not db_session:
            return ""
        try:
            # Get Onboarding Session
            session = db_session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if not session:
                return ""

            # Step 2: Website Analysis
            website_analysis = db_session.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if website_analysis:
                # These columns may hold dicts or other JSON-like values.
                brand_voice = website_analysis.brand_analysis
                style = website_analysis.style_guidelines
                target_audience = website_analysis.target_audience

                parts.append("\n\nCONTEXT FROM WEBSITE ANALYSIS:\n")
                if target_audience:
                    parts.append(f"Target Audience: {target_audience}\n")
                if brand_voice and isinstance(brand_voice, dict):
                    parts.append(
                        f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"
                    )
                if style and isinstance(style, dict):
                    parts.append(
                        f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"
                    )

            # Step 3: Competitor Analysis (limit to top 3)
            competitors = db_session.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).limit(3).all()
            if competitors:
                parts.append("\nCOMPETITOR VISUAL INSIGHTS:\n")
                for comp in competitors:
                    data = comp.analysis_data
                    if not (data and isinstance(data, dict)):
                        continue
                    comp_title = data.get('title', 'Competitor')
                    highlights = data.get('highlights', [])
                    # Only use list-like highlights; stringify items so a
                    # non-string entry cannot break the join.
                    if highlights and isinstance(highlights, (list, tuple)):
                        joined = ', '.join(str(item) for item in highlights[:2])
                        parts.append(f"- {comp_title}: {joined}\n")
        finally:
            db_session.close()
    except Exception as db_ex:
        logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")
    return "".join(parts)


async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
    """
    Enhance image prompt using WaveSpeed's specialized prompt optimizer.
    Restructures and enriches prompts for visual clarity and cinematic detail.
    Uses Step 2 (Website Analysis) and Step 3 (Competitor Analysis) context if available.

    All blocking work (validation, database reads, the WaveSpeed call and
    usage tracking) runs in the thread pool so the event loop is not blocked.

    Args:
        prompt: Original image prompt.
        user_id: Optional user ID; enables pre-flight validation, onboarding
            context lookup and usage tracking.

    Returns:
        The optimized prompt, or the original ``prompt`` when enhancement
        fails or the optimizer returns nothing.
    """
    start_time = time.time()

    try:
        from services.wavespeed.client import WaveSpeedClient

        # 1. Pre-flight Validation
        if user_id:
            await run_in_threadpool(
                _validate_image_operation,
                user_id=user_id,
                operation_type="prompt-enhancement",
                num_operations=1,
                log_prefix="[Prompt Enhancement]",
            )

        # 2. Fetch Context from Step 2 & 3
        context_instruction = ""
        if user_id:
            context_instruction = await run_in_threadpool(
                _build_prompt_enhancement_context, user_id
            )

        # Combine prompt with context
        full_input_text = prompt
        if context_instruction:
            logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
            # We append context as instruction for the optimizer
            full_input_text = f"Original Request: {prompt}\n\n{context_instruction}\n\nTask: Generate a hyper-personalized, detailed image generation prompt based on the Original Request and the provided Context. Ensure the visual style aligns with the Brand Voice and Visual Style."
        else:
            logger.info(f"Enhancing prompt for user {user_id} (no context found)")

        # 3. Call WaveSpeed
        client = WaveSpeedClient()

        # Use 'image' mode for avatar/image generation workflows
        # Use 'photographic' style as requested for avatars
        optimized_prompt = await run_in_threadpool(
            client.optimize_prompt,
            text=full_input_text,
            mode="image",
            style="photographic",
            enable_sync_mode=True,
            timeout=30,
        )
    except Exception as e:
        logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
        # Fallback to original prompt on failure
        return prompt

    # 4. Track Usage (best effort: a tracking failure must not discard the
    # optimized prompt that was already obtained).
    if user_id:
        try:
            duration = time.time() - start_time
            # Track as 0 cost for now unless we have specific pricing for prompt opt
            # But we track it as an operation
            await run_in_threadpool(
                _track_image_operation_usage,
                user_id=user_id,
                provider="wavespeed",
                model="wavespeed-prompt-opt",
                operation_type="prompt-enhancement",
                result_bytes=b"",  # No image
                cost=0.0,
                prompt=prompt,
                endpoint="/enhance-prompt",
                metadata={"duration": duration, "context_added": bool(context_instruction)},
                log_prefix="[Prompt Enhancement]",
                response_time=duration,
            )
        except Exception as track_error:
            logger.warning(f"[Prompt Enhancement] ⚠️ Failed to track usage: {track_error}")

    # An empty optimizer result falls back to the caller's prompt so the
    # declared ``str`` return is always usable.
    return optimized_prompt or prompt
async def generate_image_variation(
@@ -760,13 +886,123 @@ async def generate_image_variation(
**kwargs
) -> Dict[str, Any]:
"""
Generate variation of an existing image.
Placeholder implementation.
Generate variation of an existing image using image-to-image editing.
Wrapper for step4_asset_routes.
"""
return {
"success": False,
"error": "Not implemented yet"
}
try:
# Handle image input (bytes, file, or base64)
image_bytes = None
if isinstance(image, bytes):
image_bytes = image
elif hasattr(image, "read"):
image_bytes = await image.read()
elif isinstance(image, str):
# Assume base64 or path
if os.path.exists(image):
with open(image, "rb") as f:
image_bytes = f.read()
else:
# Try base64 decode
try:
if "base64," in image:
image = image.split("base64,")[1]
image_bytes = base64.b64decode(image)
except:
pass
if not image_bytes:
return {"success": False, "error": "Invalid image input"}
# Convert to base64 for internal function
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
# Use generate_image_edit with "variation" intent
# For variation, we typically use general_edit with specific prompt
result = await run_in_threadpool(
generate_image_edit,
image_base64=image_base64,
prompt=prompt,
operation="general_edit",
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
options=kwargs,
user_id=user_id
)
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
return {
"success": True,
"image_base64": result_base64,
"metadata": result.metadata
}
except Exception as e:
logger.error(f"Error in generate_image_variation: {e}")
return {
"success": False,
"error": str(e)
}
async def generate_image_enhance(
    image: Any,
    user_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Enhance/Upscale an existing image.
    Wrapper for step4_asset_routes.

    Args:
        image: The source image. Accepted forms are raw ``bytes``, any object
            exposing ``read()`` (synchronous or awaitable), a path to an
            existing file, or a base64 string (optionally prefixed with a
            ``...base64,`` data-URL header).
        user_id: Optional user ID, forwarded unchanged to ``generate_image_edit``.
        **kwargs: Extra options forwarded to ``generate_image_edit`` as
            ``options``; ``"resolution"`` is always overridden to ``"4k"``.

    Returns:
        ``{"success": True, "image_base64": ..., "metadata": ...}`` on success,
        or ``{"success": False, "error": <message>}`` when the input cannot be
        turned into non-empty bytes or the edit call raises.
    """
    import inspect

    try:
        # Normalize every supported input form to raw bytes.
        image_bytes = None
        if isinstance(image, bytes):
            image_bytes = image
        elif hasattr(image, "read"):
            # read() may be a coroutine (async upload objects) or a plain
            # method (open files, BytesIO); only await when it is awaitable.
            payload = image.read()
            if inspect.isawaitable(payload):
                payload = await payload
            image_bytes = payload
        elif isinstance(image, str):
            # NOTE(review): a string that names an existing file is read from
            # local disk; confirm callers never pass untrusted strings here.
            if os.path.exists(image):
                with open(image, "rb") as f:
                    image_bytes = f.read()
            else:
                # Otherwise treat the string as base64, dropping any data-URL header.
                encoded = image.split("base64,", 1)[1] if "base64," in image else image
                try:
                    image_bytes = base64.b64decode(encoded)
                except ValueError:
                    # binascii.Error (bad padding/characters) is a ValueError;
                    # an undecodable string is reported as invalid input below.
                    image_bytes = None

        if not image_bytes:
            return {"success": False, "error": "Invalid image input"}

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        # Use generate_image_edit with "enhance" intent
        # Use high-res model like nano-banana-pro-edit-ultra
        result = await run_in_threadpool(
            generate_image_edit,
            image_base64=image_base64,
            prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
            operation="general_edit",
            model="nano-banana-pro-edit-ultra",
            options={**kwargs, "resolution": "4k"},
            user_id=user_id
        )

        result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
        return {
            "success": True,
            "image_base64": result_base64,
            "metadata": result.metadata
        }
    except Exception as e:
        # Best-effort wrapper: failures are reported in the result dict, not raised.
        logger.error(f"Error in generate_image_enhance: {e}")
        return {
            "success": False,
            "error": str(e)
        }

View File

@@ -260,335 +260,23 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
# Using raw SQL UPDATE ensures the change is persisted
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens (after any safety capping)
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this call
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
cost_total = 0.0
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
if cost_total > 0:
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
# Keep ORM object in sync for logging/debugging
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
else:
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
# Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
new_total_calls = old_total_calls + 1
new_total_tokens = old_total_tokens + tokens_total
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
# Ideally we should track start_time at beginning of function
duration = 0.5
track_agent_usage_sync(
user_id=user_id,
model_name=model,
prompt=prompt,
response_text=response_text,
duration=duration
)
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
@@ -661,208 +349,18 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log (from raw SQL query)
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens after any safety capping
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this fallback call
cost_total = 0.0
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=fallback_model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
if cost_total > 0:
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
# Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log (limits already retrieved above)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync(
user_id=user_id,
model_name=fallback_model,
prompt=prompt,
response_text=response_text,
duration=0.5 # Approximate duration
)
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)

View File

@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
pass
def _track_video_operation_usage(
    user_id: str,
    provider: str,
    model: str,
    operation_type: str,
    result_bytes: bytes,
    cost: float,
    prompt: Optional[str] = None,
    endpoint: str = "/video-generation",
    metadata: Optional[Dict[str, Any]] = None,
    log_prefix: str = "[Video Generation]",
    response_time: float = 0.0
) -> Dict[str, Any]:
    """
    Reusable usage tracking helper for all video operations.

    Increments the user's video call/cost counters for the current billing
    period, writes an ``APIUsageLog`` row, commits, and prints the unified
    subscription log block. Tracking is best-effort: every failure is logged
    and swallowed so that a tracking problem never fails the video operation.

    Args:
        user_id: User ID for tracking
        provider: Provider name (stored as ``actual_provider_name``)
        model: Model name used ("unknown" is stored when falsy)
        operation_type: Type of operation (used for logging only)
        result_bytes: Generated video bytes (only its length is recorded)
        cost: Cost of the operation; ``None`` is treated as 0.0
        prompt: Optional prompt text (only its UTF-8 byte size is recorded)
        endpoint: API endpoint path
        metadata: Optional additional metadata. Accepted for interface
            compatibility; currently not read or persisted by this helper.
        log_prefix: Logging prefix
        response_time: API response time

    Returns:
        On success, a dictionary with ``current_calls`` (video calls after
        this operation), ``cost`` (cost of this operation) and ``total_cost``
        (accumulated *video* cost for the period). On any failure, ``{}``.
    """
    # A caller doing result.get("cost", 0.0) can still hand us an explicit
    # None; treat it as "no cost" instead of failing on the addition below.
    cost = cost or 0.0

    try:
        from services.database import get_session_for_user
        db_track = get_session_for_user(user_id)
        try:
            from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
            from services.subscription import PricingService
            pricing = PricingService(db_track)
            # Fall back to the current calendar month when the pricing service
            # reports no billing period.
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            # Get or create usage summary
            summary = db_track.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()
            if not summary:
                summary = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                db_track.add(summary)
                # Flush so the row exists for the raw UPDATE below.
                db_track.flush()

            # Get current values before update
            current_calls_before = getattr(summary, "video_calls", 0) or 0
            current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0

            # Update video calls and cost
            new_calls = current_calls_before + 1
            new_cost = current_cost_before + cost

            # Use direct SQL UPDATE for dynamic attributes
            from sqlalchemy import text as sql_text
            update_query = sql_text("""
UPDATE usage_summaries
SET video_calls = :new_calls,
video_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
            db_track.execute(update_query, {
                'new_calls': new_calls,
                'new_cost': new_cost,
                'user_id': user_id,
                'period': current_period
            })

            # Update total cost
            summary.total_cost = (summary.total_cost or 0.0) + cost
            summary.total_calls = (summary.total_calls or 0) + 1
            summary.updated_at = datetime.utcnow()

            # Create usage log
            request_size = len(prompt.encode("utf-8")) if prompt else 0
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=APIProvider.WAVESPEED,  # Default for video
                endpoint=endpoint,
                method="POST",
                model_used=model or "unknown",
                actual_provider_name=provider,
                tokens_input=0,
                tokens_output=0,
                tokens_total=0,
                cost_input=0.0,
                cost_output=0.0,
                cost_total=cost,
                response_time=response_time,
                status_code=200,
                request_size=request_size,
                response_size=len(result_bytes) if result_bytes else 0,
                billing_period=current_period,
            )
            db_track.add(usage_log)

            # Get plan details for unified log. Everything from here to the
            # commit is display-only, so read it defensively: a missing key or
            # a None limit must not roll back the usage we just recorded.
            limits = pricing.get_user_limits(user_id) or {}
            plan_name = limits.get('plan_name', 'unknown')
            tier = limits.get('tier', 'unknown')
            limit_map = limits.get('limits') or {}

            def _limit_display(limit_key: str):
                # A non-positive limit on the enterprise tier is rendered as
                # an empty string; every other case shows the number.
                limit_value = limit_map.get(limit_key, 0) or 0
                return limit_value if (limit_value > 0 or tier != 'enterprise') else ''

            # Get limits for display
            video_limit_display = _limit_display("video_calls")

            # Get related stats for unified log
            current_audio_calls = getattr(summary, "audio_calls", 0) or 0
            audio_limit_display = _limit_display("audio_calls")
            current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
            image_edit_limit_display = _limit_display("image_edit_calls")

            db_track.commit()
            logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")

            # UNIFIED SUBSCRIPTION LOG
            # flush=True already flushes stdout, so no separate flush call.
            operation_name = operation_type.replace("-", " ").title()
            print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before} → {new_calls} / {video_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit_display}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
└─ Status: ✅ Allowed & Tracked
""", flush=True)

            return {
                "current_calls": new_calls,
                "cost": cost,
                "total_cost": new_cost,
            }
        except Exception as track_error:
            logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
            import traceback
            logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
            db_track.rollback()
            return {}
        finally:
            db_track.close()
    except Exception as usage_error:
        # Reached when the session cannot be obtained, or when rollback/close
        # themselves fail; tracking stays non-blocking either way.
        logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
        import traceback
        logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
        return {}
def _get_api_key(provider: str) -> Optional[str]:
try:
manager = APIKeyManager()
@@ -500,156 +666,74 @@ async def ai_video_generate(
raise
finally:
db.close()
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
# Progress callback: Initial submission
if progress_callback:
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
# Generate video based on operation type
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
# Track response time for video generation
# Track response time
import time
from datetime import datetime
start_time = time.time()
# Execute operation based on type
result = {}
try:
if operation_type == "text-to-video":
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
**kwargs,
)
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
result_dict = {
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
result = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10, # Default cost, will be calculated in track_video_usage
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280, # Default, actual may vary
"height": 720, # Default, actual may vary
"metadata": {},
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
"provider": "huggingface",
"cost": 0.0, # HuggingFace inference is free/low cost
}
elif provider == "wavespeed":
# WaveSpeed text-to-video - use unified service
result_dict = await _generate_text_to_video_wavespeed(
result = await _generate_text_to_video_wavespeed(
prompt=prompt,
progress_callback=progress_callback,
**kwargs,
**kwargs
)
elif provider == "gemini":
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
elif provider == "openai":
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
else:
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
raise ValueError(f"Unknown provider for text-to-video: {provider}")
elif operation_type == "image-to-video":
if provider == "wavespeed":
# Progress callback: Starting generation
if progress_callback:
progress_callback(20.0, "Video generation in progress...")
# Handle async call from sync context
# Since ai_video_generate is sync, we need to run async function
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context - use ThreadPoolExecutor to run in new event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run,
_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
)
)
result_dict = future.result()
else:
# Event loop exists but not running - use it
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
except RuntimeError:
# No event loop exists, create a new one
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
video_bytes = result_dict["video_bytes"]
model_name = result_dict.get("model_name", model_name)
# Progress callback: Processing result
if progress_callback:
progress_callback(90.0, "Processing video result...")
result = await _generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or "",
progress_callback=progress_callback,
**kwargs
)
else:
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
raise ValueError(f"Unknown provider for image-to-video: {provider}")
# Track usage (same pattern as text generation)
# Use cost from result_dict if available, otherwise calculate
response_time = time.time() - start_time
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=result_dict.get("prompt", prompt or ""),
video_bytes=video_bytes,
cost_override=cost_override,
response_time=response_time,
)
# Progress callback: Complete
if progress_callback:
progress_callback(100.0, "Video generation complete!")
return result_dict
except HTTPException:
# Re-raise HTTPExceptions (e.g., from validation or API errors)
raise
# TRACK USAGE after successful API call
video_bytes = result.get("video_bytes")
if user_id and video_bytes:
_track_video_operation_usage(
user_id=user_id,
provider=result.get("provider", provider),
model=result.get("model_name", kwargs.get("model", "unknown")),
operation_type=operation_type,
result_bytes=video_bytes,
cost=result.get("cost", 0.0),
prompt=prompt,
endpoint="/video-generation",
metadata=result.get("metadata"),
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
response_time=response_time
)
return result
except Exception as e:
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": str(e)})
# Log failure but don't track usage (no cost incurred)
logger.error(f"[video_gen] Generation failed: {str(e)}")
raise
def _get_default_model(operation_type: str, provider: str) -> str: