Save local changes (GSC/Bing integrations) before merging PR #354
This commit is contained in:
@@ -297,7 +297,7 @@ def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
|
||||
return _convert(schema)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None, user_id: str = None):
|
||||
"""
|
||||
Generate structured JSON response using Google's Gemini Pro model.
|
||||
|
||||
@@ -312,6 +312,7 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
top_k (int): Top-k sampling parameter
|
||||
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
user_id (str, optional): User ID for usage tracking.
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON response matching the provided schema
|
||||
@@ -468,6 +469,25 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
|
||||
if response.parsed is not None:
|
||||
logger.info("Using response.parsed for structured output")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
import json
|
||||
|
||||
response_str = json.dumps(response.parsed)
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=response_str,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return response.parsed
|
||||
else:
|
||||
logger.warning("Response.parsed is None, falling back to text parsing")
|
||||
@@ -500,6 +520,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
|
||||
parsed_text = json.loads(cleaned_text)
|
||||
logger.info("Successfully parsed text as JSON")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=cleaned_text,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse text as JSON: {e}")
|
||||
@@ -521,6 +557,26 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
fixed_json = re.sub(r',\s*]', ']', fixed_json)
|
||||
|
||||
parsed_text = json.loads(fixed_json)
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
import json
|
||||
|
||||
response_str = json.dumps(parsed_text) if parsed_text else ""
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=response_str,
|
||||
duration=0.5 # Approximation
|
||||
)
|
||||
logger.info(f"✅ Tracked structured JSON usage for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
logger.info("Successfully parsed cleaned JSON")
|
||||
return parsed_text
|
||||
except Exception as fix_error:
|
||||
@@ -537,6 +593,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
import json
|
||||
parsed_text = json.loads(part.text)
|
||||
logger.info("Successfully parsed candidate text as JSON")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=part.text,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse candidate text as JSON: {e}")
|
||||
|
||||
@@ -4,6 +4,7 @@ import io
|
||||
import os
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
|
||||
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
@@ -14,7 +15,10 @@ logger = get_service_logger("wavespeed.image_provider")
|
||||
|
||||
|
||||
class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.
|
||||
|
||||
Implements robust error handling and retries for production stability.
|
||||
"""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"ideogram-v3-turbo": {
|
||||
@@ -54,6 +58,28 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
|
||||
list(self.SUPPORTED_MODELS.keys()))
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((RuntimeError, IOError)),
    reraise=True
)
def _call_api_with_retry(self, method, **kwargs):
    """Execute an API call under the retry policy declared by the decorator.

    The decorator allows up to three attempts with exponential back-off and
    only retries ``RuntimeError`` / ``IOError``; with ``reraise=True`` the
    last exception is surfaced to the caller.

    Args:
        method: Callable API method
        **kwargs: Keyword arguments forwarded to ``method`` unchanged

    Returns:
        Whatever ``method(**kwargs)`` returns.

    Raises:
        Exception: Any exception raised by ``method`` is logged and re-raised.
    """
    try:
        return method(**kwargs)
    except (RuntimeError, IOError) as e:
        # These are the only types the retry policy acts on. This handler
        # cannot tell whether another attempt is left, so the message says
        # "retryable" rather than promising a retry.
        logger.warning(f"WaveSpeed API call failed (retryable error): {str(e)}")
        raise
    except Exception as e:
        # Any other exception type is not covered by the retry policy; the
        # old message claimed "(retrying)" here, which was misleading.
        logger.warning(f"WaveSpeed API call failed (non-retryable error): {str(e)}")
        raise
|
||||
|
||||
def _validate_options(self, options: ImageGenerationOptions) -> None:
|
||||
"""Validate generation options.
|
||||
|
||||
@@ -117,7 +143,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
|
||||
# Call WaveSpeed API (using generic image generation method)
|
||||
# This will need to be adjusted based on actual WaveSpeed client implementation
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
# Adjust based on actual WaveSpeed API response format
|
||||
@@ -167,7 +193,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
@@ -216,7 +242,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
|
||||
@@ -107,11 +107,13 @@ def generate_audio(
|
||||
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import UsageSummary, APIProvider
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
@@ -194,7 +196,11 @@ def generate_audio(
|
||||
if audio_bytes:
|
||||
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[audio_gen] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -383,12 +389,14 @@ def clone_voice(
|
||||
|
||||
voice_clone_cost = 0.5
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
@@ -432,7 +440,11 @@ def clone_voice(
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[clone_voice] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -570,12 +582,14 @@ def qwen3_voice_clone(
|
||||
char_count = len(text)
|
||||
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
@@ -615,7 +629,11 @@ def qwen3_voice_clone(
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[qwen3_voice_clone] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -691,6 +709,7 @@ def qwen3_voice_clone(
|
||||
├─ Provider: wavespeed
|
||||
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
|
||||
├─ Calls: {current_calls_before} → {new_calls}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Text chars: {char_count}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
@@ -724,3 +743,373 @@ def qwen3_voice_clone(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def qwen3_voice_design(
    text: str,
    voice_description: str,
    *,
    language: str = "auto",
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Generate preview audio for a described voice via WaveSpeed Qwen3 TTS voice design.

    Flow: validate inputs, run the subscription pre-flight check, call
    ``WaveSpeedClient.voice_design``, then record usage on a best-effort basis
    (tracking failures are logged and never fail the request).

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        voice_description: Description of the desired voice; stripped as well.
        language: Passed through to the client unchanged (default ``"auto"``).
        user_id: User the usage is checked and tracked against. Required.

    Returns:
        VoiceCloneResult carrying the preview audio bytes, provider/model
        identifiers, an empty ``custom_voice_id`` and the audio size.

    Raises:
        RuntimeError: ``user_id`` is missing, the subscription check failed
            for a reason other than a limit rejection, or any ``RuntimeError``
            raised further down (propagated unchanged).
        HTTPException: 429 when the usage limit check rejects the request;
            500 for every other failure.
    """
    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        # NOTE(review): these ValueErrors are converted to HTTP 500 by the
        # generic handler at the bottom; confirm whether callers expect 4xx.
        if not text or not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()

        if not voice_description or not isinstance(voice_description, str) or len(voice_description.strip()) == 0:
            raise ValueError("Voice description is required")
        voice_description = voice_description.strip()

        char_count = len(text)
        # Pricing logic similar to TTS/Clone: 0.005 per 100 characters, floor of 0.005
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check (blocking) ---
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain explicitly so the root cause is reported as the direct cause.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call ---
        import time
        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.voice_design(
            text=text,
            voice_description=voice_description,
            language=language
        )
        response_time = time.time() - start_time

        # --- Usage tracking (non-blocking) ---
        # Only record usage when audio was actually produced, matching the
        # sibling voice functions; an empty result must not be billed.
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error("[qwen3_voice_design] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")

                try:
                    from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
                    from services.subscription import PricingService
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        # First usage in this billing period: create the row so
                        # the UPDATE below has something to match.
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name="wavespeed-ai/qwen3-tts/voice-design",
                        endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                    )

                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                        method="POST",
                        model_used="wavespeed-ai/qwen3-tts/voice-design",
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        # Sizes are byte counts, as in cosyvoice_voice_clone.
                        request_size=len(text.encode("utf-8")) + len(voice_description.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # print(..., flush=True) already flushes stdout.
                    print(
                        "\n"
                        "[SUBSCRIPTION] Qwen3 Voice Design\n"
                        f"├─ User: {user_id}\n"
                        "├─ Provider: wavespeed\n"
                        "├─ Model: wavespeed-ai/qwen3-tts/voice-design\n"
                        f"├─ Calls: {current_calls_before} → {new_calls}\n"
                        f"├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}\n"
                        f"├─ Text chars: {char_count}\n"
                        "└─ Status: ✅ Allowed & Tracked\n",
                        flush=True,
                    )
                except Exception as track_error:
                    logger.error(f"[qwen3_voice_design] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[qwen3_voice_design] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/qwen3-tts/voice-design",
            custom_voice_id="",  # No persistent ID for design usually, unless we save it
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[qwen3_voice_design] Error designing voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Qwen3 voice design failed",
                "message": str(e),
            },
        )
|
||||
|
||||
|
||||
def cosyvoice_voice_clone(
    audio_bytes: bytes,
    text: str,
    *,
    reference_text: Optional[str] = None,
    audio_mime_type: Optional[str] = None,
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Clone a voice from reference audio via WaveSpeed CosyVoice and speak ``text``.

    Flow: validate inputs, run the subscription pre-flight check, call
    ``WaveSpeedClient.cosyvoice_voice_clone``, then record usage on a
    best-effort basis when audio was returned (tracking failures are logged
    and never fail the request).

    Args:
        audio_bytes: Reference audio; must be non-empty and at most 15MB.
        text: Text to synthesize; stripped, at most 4000 characters.
        reference_text: Passed through to the client unchanged.
        audio_mime_type: MIME type of ``audio_bytes``; ``"audio/wav"`` is
            sent when not provided.
        user_id: User the usage is checked and tracked against. Required.

    Returns:
        VoiceCloneResult carrying the preview audio bytes, provider/model
        identifiers, an empty ``custom_voice_id`` and the audio size.

    Raises:
        RuntimeError: ``user_id`` is missing, the subscription check failed
            for a reason other than a limit rejection, or any ``RuntimeError``
            raised further down (propagated unchanged).
        HTTPException: 429 when the usage limit check rejects the request;
            500 for every other failure.
    """
    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        # NOTE(review): these ValueErrors are converted to HTTP 500 by the
        # generic handler at the bottom; confirm whether callers expect 4xx.
        if not audio_bytes or not isinstance(audio_bytes, (bytes, bytearray)) or len(audio_bytes) == 0:
            raise ValueError("Audio is required and cannot be empty")

        if len(audio_bytes) > 15 * 1024 * 1024:
            raise ValueError("Audio file too large. Maximum is 15MB.")

        if not text or not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()
        if len(text) > 4000:
            raise ValueError("Text too long. Please keep it under 4000 characters.")

        char_count = len(text)
        # 0.005 per 100 characters, with a floor of 0.005
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check (blocking) ---
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain explicitly so the root cause is reported as the direct cause.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call ---
        import time
        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.cosyvoice_voice_clone(
            audio_bytes=bytes(audio_bytes),
            text=text,
            audio_mime_type=audio_mime_type or "audio/wav",
            reference_text=reference_text,
        )
        response_time = time.time() - start_time

        # --- Usage tracking (non-blocking), only when audio was produced ---
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error("[cosyvoice_voice_clone] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")

                try:
                    from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
                    from services.subscription import PricingService
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        # First usage in this billing period: create the row so
                        # the UPDATE below has something to match.
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name="wavespeed-ai/cosyvoice-tts/voice-clone",
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                    )

                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                        method="POST",
                        model_used="wavespeed-ai/cosyvoice-tts/voice-clone",
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        request_size=len(audio_bytes) + len(text.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # print(..., flush=True) already flushes stdout. The Cost
                    # line is included so this summary matches the other voice
                    # functions' audit output.
                    print(
                        "\n"
                        "[SUBSCRIPTION] CosyVoice Voice Clone\n"
                        f"├─ User: {user_id}\n"
                        "├─ Provider: wavespeed\n"
                        "├─ Model: wavespeed-ai/cosyvoice-tts/voice-clone\n"
                        f"├─ Calls: {current_calls_before} → {new_calls}\n"
                        f"├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}\n"
                        f"├─ Text chars: {char_count}\n"
                        "└─ Status: ✅ Allowed & Tracked\n",
                        flush=True,
                    )
                except Exception as track_error:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[cosyvoice_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/cosyvoice-tts/voice-clone",
            custom_voice_id="",
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[cosyvoice_voice_clone] Error cloning voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "CosyVoice voice cloning failed",
                "message": str(e),
            },
        )
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
@@ -9,6 +11,9 @@ from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from .image_generation.base import ImageEditOptions
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
"WAVESPEED_IMAGE_EDIT_MODEL",
|
||||
"qwen-edit-plus",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
    """
    Select the appropriate image editing provider.

    Priority:
    1. Explicitly requested provider (trimmed and lower-cased)
    2. WaveSpeed (if the WAVESPEED_API_KEY environment variable is set)
    3. Hugging Face (fallback)

    Args:
        explicit: Provider name requested by the caller, or None. A blank or
            whitespace-only value is treated as "not specified".

    Returns:
        The normalized provider name.
    """
    # Normalize before deciding: stray whitespace around the name would
    # otherwise produce a provider string that no client lookup matches.
    requested = (explicit or "").strip().lower()
    if requested:
        return requested

    # Check for WaveSpeed API key first (Preferred provider)
    if os.getenv("WAVESPEED_API_KEY"):
        return "wavespeed"

    # Default to huggingface if WaveSpeed not available
    return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
"""Get the client for the specified provider."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider(api_key=api_key)
|
||||
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
# Use fal-ai provider for fast inference via HF Inference API
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
@@ -86,6 +106,8 @@ def edit_image(
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
# Note: get_db() is a generator, so we need to use next() to get the session
|
||||
# and ensure we close it in the finally block
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
@@ -99,6 +121,9 @@ def edit_image(
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -119,6 +144,69 @@ def edit_image(
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
if provider_name == "wavespeed":
|
||||
# Handle WaveSpeed provider
|
||||
try:
|
||||
# Convert inputs to base64 for WaveSpeed
|
||||
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
|
||||
mask_b64 = None
|
||||
if mask_bytes:
|
||||
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
|
||||
|
||||
# Determine operation type based on prompt/mask
|
||||
operation = "general_edit" # Default
|
||||
if not prompt and mask_b64:
|
||||
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
|
||||
elif prompt and not mask_b64:
|
||||
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
|
||||
elif opts.get("operation"):
|
||||
operation = opts.get("operation")
|
||||
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_b64,
|
||||
prompt=prompt.strip(),
|
||||
operation=operation,
|
||||
mask_base64=mask_b64,
|
||||
model=model,
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts
|
||||
)
|
||||
|
||||
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
|
||||
result = client.edit(edit_options)
|
||||
|
||||
# TRACK USAGE after successful WaveSpeed call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (WaveSpeed default: $0.02)
|
||||
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model=result.model or model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
|
||||
|
||||
# Hugging Face (Fallback)
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
@@ -170,6 +258,29 @@ def edit_image(
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
# TRACK USAGE after successful HF call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (HF/Fal-ai default: $0.05)
|
||||
estimated_cost = 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=edited_image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata={"provider": "fal-ai"},
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from .image_generation import (
|
||||
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# User requested WaveSpeed as default provider
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
|
||||
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_with_provider: {e}")
|
||||
# Propagate specific error message if available
|
||||
error_detail = str(e)
|
||||
if "402" in error_detail or "Payment Required" in error_detail:
|
||||
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
"error": error_detail
|
||||
}
|
||||
|
||||
|
||||
import time
|
||||
from services.database import get_session_for_user
|
||||
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
|
||||
|
||||
async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
    """
    Enhance an image prompt using WaveSpeed's specialized prompt optimizer.

    Restructures and enriches prompts for visual clarity and cinematic detail.
    Uses Step 2 (Website Analysis) and Step 3 (Competitor Analysis) context
    from the user's onboarding data if available.

    Args:
        prompt (str): The original image prompt supplied by the caller.
        user_id (str, optional): When given, the operation is validated
            before the call, onboarding context is looked up for this user,
            and usage is tracked afterwards.

    Returns:
        str: The optimized prompt, or the original ``prompt`` unchanged if
        validation, the optimizer import, or the optimizer call fails.
    """
    start_time = time.time()
    try:
        from services.wavespeed.client import WaveSpeedClient

        # 1. Pre-flight Validation
        if user_id:
            _validate_image_operation(
                user_id=user_id,
                operation_type="prompt-enhancement",
                num_operations=1,
                log_prefix="[Prompt Enhancement]"
            )

        # 2. Fetch Context from Step 2 & 3
        # Context is best-effort: any lookup failure is logged and the
        # enhancement continues with whatever context was gathered so far.
        context_instruction = ""
        if user_id:
            try:
                db_session = get_session_for_user(user_id)
                try:
                    # Get Onboarding Session
                    session = db_session.query(OnboardingSession).filter(
                        OnboardingSession.user_id == user_id
                    ).first()

                    if session:
                        # Step 2: Website Analysis
                        website_analysis = db_session.query(WebsiteAnalysis).filter(
                            WebsiteAnalysis.session_id == session.id
                        ).first()

                        if website_analysis:
                            # Handle potential JSON or dict types
                            brand_voice = website_analysis.brand_analysis
                            style = website_analysis.style_guidelines
                            target_audience = website_analysis.target_audience

                            context_instruction += "\n\nCONTEXT FROM WEBSITE ANALYSIS:\n"
                            if target_audience:
                                context_instruction += f"Target Audience: {target_audience}\n"

                            if brand_voice and isinstance(brand_voice, dict):
                                context_instruction += f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"

                            if style and isinstance(style, dict):
                                context_instruction += f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"

                        # Step 3: Competitor Analysis (Limit to top 3)
                        competitors = db_session.query(CompetitorAnalysis).filter(
                            CompetitorAnalysis.session_id == session.id
                        ).limit(3).all()

                        if competitors:
                            context_instruction += "\nCOMPETITOR VISUAL INSIGHTS:\n"
                            for comp in competitors:
                                if comp.analysis_data and isinstance(comp.analysis_data, dict):
                                    comp_title = comp.analysis_data.get('title', 'Competitor')
                                    # Try to extract visual/content insights if available
                                    highlights = comp.analysis_data.get('highlights', [])
                                    if highlights:
                                        # Coerce to str so a non-string highlight cannot raise
                                        # in join() and discard the context gathered so far.
                                        joined = ', '.join(str(item) for item in highlights[:2])
                                        context_instruction += f"- {comp_title}: {joined}\n"
                finally:
                    db_session.close()
            except Exception as db_ex:
                logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")

        # Combine prompt with context
        full_input_text = prompt
        if context_instruction:
            logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
            # We append context as instruction for the optimizer
            full_input_text = f"Original Request: {prompt}\n\n{context_instruction}\n\nTask: Generate a hyper-personalized, detailed image generation prompt based on the Original Request and the provided Context. Ensure the visual style aligns with the Brand Voice and Visual Style."
        else:
            logger.info(f"Enhancing prompt for user {user_id} (no context found)")

        # 3. Call WaveSpeed
        # NOTE(review): this call is made directly inside an async function;
        # if optimize_prompt blocks, it will block the event loop - confirm.
        client = WaveSpeedClient()
        # Use 'image' mode for avatar/image generation workflows
        # Use 'photographic' style as requested for avatars
        optimized_prompt = client.optimize_prompt(
            text=full_input_text,
            mode="image",
            style="photographic",
            enable_sync_mode=True,
            timeout=30
        )

        # 4. Track Usage
        # Tracking is isolated in its own try/except so that a bookkeeping
        # failure cannot discard an optimized prompt that was already produced.
        if user_id:
            try:
                duration = time.time() - start_time
                # Track as 0 cost for now unless we have specific pricing for prompt opt
                # But we track it as an operation
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="wavespeed",
                    model="wavespeed-prompt-opt",
                    operation_type="prompt-enhancement",
                    result_bytes=b"",  # No image
                    cost=0.0,
                    prompt=prompt,
                    endpoint="/enhance-prompt",
                    metadata={"duration": duration, "context_added": bool(context_instruction)},
                    log_prefix="[Prompt Enhancement]",
                    response_time=duration
                )
            except Exception as track_error:
                logger.warning(f"[Prompt Enhancement] ⚠️ Failed to track usage: {track_error}")

        return optimized_prompt

    except Exception as e:
        logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
        # Fallback to original prompt on failure
        return prompt
|
||||
|
||||
|
||||
async def generate_image_variation(
|
||||
@@ -760,13 +886,123 @@ async def generate_image_variation(
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate variation of an existing image.
|
||||
Placeholder implementation.
|
||||
Generate variation of an existing image using image-to-image editing.
|
||||
Wrapper for step4_asset_routes.
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Not implemented yet"
|
||||
}
|
||||
try:
|
||||
# Handle image input (bytes, file, or base64)
|
||||
image_bytes = None
|
||||
if isinstance(image, bytes):
|
||||
image_bytes = image
|
||||
elif hasattr(image, "read"):
|
||||
image_bytes = await image.read()
|
||||
elif isinstance(image, str):
|
||||
# Assume base64 or path
|
||||
if os.path.exists(image):
|
||||
with open(image, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
else:
|
||||
# Try base64 decode
|
||||
try:
|
||||
if "base64," in image:
|
||||
image = image.split("base64,")[1]
|
||||
image_bytes = base64.b64decode(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not image_bytes:
|
||||
return {"success": False, "error": "Invalid image input"}
|
||||
|
||||
# Convert to base64 for internal function
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
# Use generate_image_edit with "variation" intent
|
||||
# For variation, we typically use general_edit with specific prompt
|
||||
result = await run_in_threadpool(
|
||||
generate_image_edit,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation="general_edit",
|
||||
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
|
||||
options=kwargs,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_base64": result_base64,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_variation: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def generate_image_enhance(
    image: Any,
    user_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Enhance/Upscale an existing image.

    Wrapper for step4_asset_routes.

    Args:
        image: The source image. Accepted forms are raw ``bytes``, an object
            with a ``read()`` method (sync or async), or a ``str`` that is
            either an existing file path or base64 data (optionally with a
            ``...base64,`` data-URL prefix).
        user_id (str, optional): Passed through to ``generate_image_edit``.
        **kwargs: Extra options forwarded to ``generate_image_edit`` as
            ``options``; ``resolution`` is always overridden to ``"4k"``.

    Returns:
        dict: ``{"success": True, "image_base64": ..., "metadata": ...}`` on
        success, otherwise ``{"success": False, "error": <message>}``. Errors
        are reported in the dict rather than raised.
    """
    try:
        # Handle image input
        image_bytes = None
        if isinstance(image, bytes):
            image_bytes = image
        elif hasattr(image, "read"):
            import inspect

            # Support both async readers (coroutine-returning read()) and
            # plain file-like objects whose read() returns the data directly.
            payload = image.read()
            if inspect.isawaitable(payload):
                payload = await payload
            image_bytes = payload
        elif isinstance(image, str):
            # NOTE(review): a string naming an existing file is read from disk;
            # confirm callers never pass untrusted strings here.
            if os.path.exists(image):
                with open(image, "rb") as f:
                    image_bytes = f.read()
            else:
                try:
                    if "base64," in image:
                        image = image.split("base64,")[1]
                    image_bytes = base64.b64decode(image)
                except ValueError:
                    # binascii.Error (bad padding/alphabet) subclasses ValueError;
                    # undecodable input is reported as invalid input just below.
                    pass

        if not image_bytes:
            return {"success": False, "error": "Invalid image input"}

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        # Use generate_image_edit with "enhance" intent
        # Use high-res model like nano-banana-pro-edit-ultra
        result = await run_in_threadpool(
            generate_image_edit,
            image_base64=image_base64,
            prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
            operation="general_edit",
            model="nano-banana-pro-edit-ultra",
            options={**kwargs, "resolution": "4k"},
            user_id=user_id
        )

        result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')

        return {
            "success": True,
            "image_base64": result_base64,
            "metadata": result.metadata
        }

    except Exception as e:
        logger.error(f"Error in generate_image_enhance: {e}")
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -260,335 +260,23 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
# This ensures we always get the absolute latest committed values, even across different sessions
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
record_count = 0 # Initialize to ensure it's always defined
|
||||
|
||||
# CRITICAL: First check if record exists using COUNT query
|
||||
try:
|
||||
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
|
||||
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
|
||||
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
|
||||
except Exception as count_error:
|
||||
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
|
||||
record_count = 0
|
||||
|
||||
if record_count and record_count > 0:
|
||||
# Record exists - read current values with raw SQL
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection (whitelist approach)
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
raw_calls = result[0] if result[0] is not None else 0
|
||||
raw_tokens = result[1] if result[1] is not None else 0
|
||||
current_calls_before = raw_calls
|
||||
current_tokens_before = raw_tokens
|
||||
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ℹ️ No record exists yet (will create new) - user={user_id}, period={current_period}")
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
# New record - values are already 0, no need to set
|
||||
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
|
||||
# Using raw SQL UPDATE ensures the change is persisted
|
||||
from sqlalchemy import text
|
||||
update_calls_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_calls = :new_calls
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_calls_query, {
|
||||
'new_calls': new_calls,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
|
||||
# The ORM object may have stale values, but raw SQL always has the latest committed values
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
update_tokens_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_tokens = :new_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_tokens_query, {
|
||||
'new_tokens': new_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Determine tracked tokens (after any safety capping)
|
||||
tracked_tokens_input = min(tokens_input, tokens_total)
|
||||
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
|
||||
|
||||
# Calculate and persist cost for this call
|
||||
try:
|
||||
cost_info = pricing.calculate_api_cost(
|
||||
provider=provider_enum,
|
||||
model_name=model,
|
||||
tokens_input=tracked_tokens_input,
|
||||
tokens_output=tracked_tokens_output,
|
||||
request_count=1
|
||||
)
|
||||
cost_total = cost_info.get('cost_total', 0.0) or 0.0
|
||||
except Exception as cost_error:
|
||||
cost_total = 0.0
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
|
||||
|
||||
if cost_total > 0:
|
||||
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
|
||||
update_costs_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_costs_query, {
|
||||
'cost': cost_total,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Keep ORM object in sync for logging/debugging
|
||||
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
|
||||
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost_total
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
|
||||
|
||||
# Update totals using SQL UPDATE
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
new_total_calls = old_total_calls + 1
|
||||
new_total_tokens = old_total_tokens + tokens_total
|
||||
|
||||
update_totals_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET total_calls = :total_calls, total_tokens = :total_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_totals_query, {
|
||||
'total_calls': new_total_calls,
|
||||
'total_tokens': new_total_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
print(debug_msg, flush=True)
|
||||
sys.stdout.flush()
|
||||
logger.debug(f"[llm_text_gen] {debug_msg}")
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
|
||||
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
|
||||
|
||||
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
|
||||
try:
|
||||
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result:
|
||||
verified_calls = verify_result[0] if verify_result[0] is not None else 0
|
||||
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
|
||||
if verified_calls != new_calls or verified_tokens != new_tokens:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
|
||||
# Force another commit attempt
|
||||
db_track.commit()
|
||||
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result2:
|
||||
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
|
||||
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
|
||||
except Exception as verify_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
# DEBUG: Log the actual values being used
|
||||
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
|
||||
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
# Estimate tokens
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
|
||||
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
|
||||
# Ideally we should track start_time at beginning of function
|
||||
duration = 0.5
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
response_text=response_text,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
@@ -661,208 +349,18 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3)
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
current_calls_before = result[0] if result[0] is not None else 0
|
||||
current_tokens_before = result[1] if result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
|
||||
# Fallback to ORM query if raw SQL fails
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary:
|
||||
db_track.refresh(summary)
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
|
||||
# Get "before" state for unified log (from raw SQL query)
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# Use current_tokens_before from raw SQL query (most reliable)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Determine tracked tokens after any safety capping
|
||||
tracked_tokens_input = min(tokens_input, tokens_total)
|
||||
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
|
||||
|
||||
# Calculate and persist cost for this fallback call
|
||||
cost_total = 0.0
|
||||
try:
|
||||
cost_info = pricing.calculate_api_cost(
|
||||
provider=provider_enum,
|
||||
model_name=fallback_model,
|
||||
tokens_input=tracked_tokens_input,
|
||||
tokens_output=tracked_tokens_output,
|
||||
request_count=1
|
||||
)
|
||||
cost_total = cost_info.get('cost_total', 0.0) or 0.0
|
||||
except Exception as cost_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
|
||||
|
||||
if cost_total > 0:
|
||||
update_costs_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_costs_query, {
|
||||
'cost': cost_total,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost_total
|
||||
|
||||
# Update totals (using potentially capped tokens_total from safety check)
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
|
||||
|
||||
# Get plan details for unified log (limits already retrieved above)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG for fallback
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation (Fallback)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {fallback_model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
# Estimate tokens
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name=fallback_model,
|
||||
prompt=prompt,
|
||||
response_text=response_text,
|
||||
duration=0.5 # Approximate duration
|
||||
)
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
|
||||
|
||||
|
||||
@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _track_video_operation_usage(
    user_id: str,
    provider: str,
    model: str,
    operation_type: str,
    result_bytes: bytes,
    cost: float,
    prompt: Optional[str] = None,
    endpoint: str = "/video-generation",
    metadata: Optional[Dict[str, Any]] = None,
    log_prefix: str = "[Video Generation]",
    response_time: float = 0.0
) -> Dict[str, Any]:
    """
    Reusable usage tracking helper for all video operations.

    Increments the user's video counters in ``usage_summaries`` for the current
    billing period, adds an ``APIUsageLog`` row, commits, and prints the unified
    subscription log. Tracking is best-effort: every failure is logged and an
    empty dict is returned instead of raising.

    Args:
        user_id: User ID for tracking
        provider: Provider name (stored as ``actual_provider_name`` on the log row)
        model: Model name used ("unknown" is stored when empty)
        operation_type: Type of operation (for logging), e.g. "text-to-video"
        result_bytes: Generated video bytes; only their length is recorded
        cost: Cost of the operation; ``None`` is treated as 0.0
        prompt: Optional prompt text; only its UTF-8 byte length is recorded
        endpoint: API endpoint path
        metadata: Optional additional metadata. Accepted for interface
            compatibility; this helper does not currently persist it.
        log_prefix: Logging prefix
        response_time: API response time

    Returns:
        Dictionary with ``current_calls``, ``cost`` and ``total_cost`` on success,
        or ``{}`` when tracking failed. ``current_calls``/``total_cost`` are
        derived from the values read before the increment, so under concurrent
        requests they may lag behind what is stored in the database.
    """
    try:
        # A missing cost must not break the arithmetic or the ":.4f" formatting below.
        cost = float(cost or 0.0)

        from services.database import get_session_for_user
        db_track = get_session_for_user(user_id)
        try:
            from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
            from services.subscription import PricingService
            from sqlalchemy import text as sql_text

            pricing = PricingService(db_track)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            # Get or create usage summary
            summary = db_track.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()

            if not summary:
                summary = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                db_track.add(summary)
                # Flush so the row exists before the raw UPDATE below targets it.
                db_track.flush()

            # "Before" values are used for the log and the return value only.
            current_calls_before = getattr(summary, "video_calls", 0) or 0
            current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0
            new_calls = current_calls_before + 1
            new_cost = current_cost_before + cost

            # Increment inside the database instead of writing back values
            # computed in Python: an absolute "SET video_calls = :new_calls"
            # loses updates when two requests track usage at the same time, and
            # would pin the counter at 1 if the attribute read above fell back
            # to its default.
            update_query = sql_text("""
                UPDATE usage_summaries
                SET video_calls = COALESCE(video_calls, 0) + 1,
                    video_cost = COALESCE(video_cost, 0) + :cost
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db_track.execute(update_query, {
                'cost': cost,
                'user_id': user_id,
                'period': current_period
            })

            # Update totals through the ORM object
            summary.total_cost = (summary.total_cost or 0.0) + cost
            summary.total_calls = (summary.total_calls or 0) + 1
            summary.updated_at = datetime.utcnow()

            # Create usage log
            request_size = len(prompt.encode("utf-8")) if prompt else 0
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=APIProvider.WAVESPEED,  # Default for video; the real provider goes in actual_provider_name
                endpoint=endpoint,
                method="POST",
                model_used=model or "unknown",
                actual_provider_name=provider,
                tokens_input=0,
                tokens_output=0,
                tokens_total=0,
                cost_input=0.0,
                cost_output=0.0,
                cost_total=cost,
                response_time=response_time,
                status_code=200,
                request_size=request_size,
                response_size=len(result_bytes) if result_bytes else 0,
                billing_period=current_period,
            )
            db_track.add(usage_log)

            # Read related counters for the log before commit, while the
            # summary object is still loaded.
            current_audio_calls = getattr(summary, "audio_calls", 0) or 0
            current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0

            db_track.commit()
            logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")

            tracking_info = {
                "current_calls": new_calls,
                "cost": cost,
                "total_cost": new_cost,
            }

            # The plan lookup and console output are informational only. They run
            # after the commit and in their own try block so that a failure here
            # can neither roll back nor hide the usage that was just recorded.
            try:
                limits = pricing.get_user_limits(user_id) or {}
                plan_name = limits.get('plan_name', 'unknown')
                tier = limits.get('tier', 'unknown')
                plan_limits = limits.get('limits') or {}

                def _limit_display(limit_key: str) -> Any:
                    # Only show ∞ for the enterprise tier when the limit is 0 (unlimited).
                    limit_value = plan_limits.get(limit_key, 0) or 0
                    return limit_value if (limit_value > 0 or tier != 'enterprise') else '∞'

                video_limit_display = _limit_display("video_calls")
                audio_limit_display = _limit_display("audio_calls")
                image_edit_limit_display = _limit_display("image_edit_calls")

                # UNIFIED SUBSCRIPTION LOG
                operation_name = operation_type.replace("-", " ").title()
                print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before} → {new_calls} / {video_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit_display}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
            except Exception as log_error:
                logger.warning(f"{log_prefix} Usage tracked, but unified subscription log failed: {log_error}")

            return tracking_info

        except Exception as track_error:
            logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
            import traceback
            logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
            db_track.rollback()
            return {}
        finally:
            db_track.close()
    except Exception as usage_error:
        logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
        import traceback
        logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
        return {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
@@ -500,156 +666,74 @@ async def ai_video_generate(
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
|
||||
|
||||
# Progress callback: Initial submission
|
||||
if progress_callback:
|
||||
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
|
||||
|
||||
# Generate video based on operation type
|
||||
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
|
||||
|
||||
# Track response time for video generation
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
from datetime import datetime
|
||||
start_time = time.time()
|
||||
|
||||
# Execute operation based on type
|
||||
result = {}
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
|
||||
result_dict = {
|
||||
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
|
||||
result = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10, # Default cost, will be calculated in track_video_usage
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280, # Default, actual may vary
|
||||
"height": 720, # Default, actual may vary
|
||||
"metadata": {},
|
||||
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
||||
"provider": "huggingface",
|
||||
"cost": 0.0, # HuggingFace inference is free/low cost
|
||||
}
|
||||
elif provider == "wavespeed":
|
||||
# WaveSpeed text-to-video - use unified service
|
||||
result_dict = await _generate_text_to_video_wavespeed(
|
||||
result = await _generate_text_to_video_wavespeed(
|
||||
prompt=prompt,
|
||||
progress_callback=progress_callback,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
)
|
||||
elif provider == "gemini":
|
||||
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
|
||||
elif provider == "openai":
|
||||
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
raise ValueError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
elif operation_type == "image-to-video":
|
||||
if provider == "wavespeed":
|
||||
# Progress callback: Starting generation
|
||||
if progress_callback:
|
||||
progress_callback(20.0, "Video generation in progress...")
|
||||
|
||||
# Handle async call from sync context
|
||||
# Since ai_video_generate is sync, we need to run async function
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# We're in an async context - use ThreadPoolExecutor to run in new event loop
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run,
|
||||
_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
result_dict = future.result()
|
||||
else:
|
||||
# Event loop exists but not running - use it
|
||||
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
video_bytes = result_dict["video_bytes"]
|
||||
model_name = result_dict.get("model_name", model_name)
|
||||
|
||||
# Progress callback: Processing result
|
||||
if progress_callback:
|
||||
progress_callback(90.0, "Processing video result...")
|
||||
result = await _generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or "",
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
|
||||
raise ValueError(f"Unknown provider for image-to-video: {provider}")
|
||||
|
||||
# Track usage (same pattern as text generation)
|
||||
# Use cost from result_dict if available, otherwise calculate
|
||||
response_time = time.time() - start_time
|
||||
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt=result_dict.get("prompt", prompt or ""),
|
||||
video_bytes=video_bytes,
|
||||
cost_override=cost_override,
|
||||
response_time=response_time,
|
||||
)
|
||||
|
||||
# Progress callback: Complete
|
||||
if progress_callback:
|
||||
progress_callback(100.0, "Video generation complete!")
|
||||
|
||||
return result_dict
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., from validation or API errors)
|
||||
raise
|
||||
# TRACK USAGE after successful API call
|
||||
video_bytes = result.get("video_bytes")
|
||||
if user_id and video_bytes:
|
||||
_track_video_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=result.get("provider", provider),
|
||||
model=result.get("model_name", kwargs.get("model", "unknown")),
|
||||
operation_type=operation_type,
|
||||
result_bytes=video_bytes,
|
||||
cost=result.get("cost", 0.0),
|
||||
prompt=prompt,
|
||||
endpoint="/video-generation",
|
||||
metadata=result.get("metadata"),
|
||||
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
|
||||
response_time=response_time
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
# Log failure but don't track usage (no cost incurred)
|
||||
logger.error(f"[video_gen] Generation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_default_model(operation_type: str, provider: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user