Recovered state: integrated TrendSurferAgent, restored frontend/backend files, and cleaned up recovery scripts
This commit is contained in:
@@ -39,6 +39,22 @@ class AudioGenerationResult:
|
||||
self.file_size = file_size
|
||||
|
||||
|
||||
class VoiceCloneResult:
|
||||
def __init__(
|
||||
self,
|
||||
preview_audio_bytes: bytes,
|
||||
provider: str,
|
||||
model: str,
|
||||
custom_voice_id: str,
|
||||
file_size: int,
|
||||
):
|
||||
self.preview_audio_bytes = preview_audio_bytes
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.custom_voice_id = custom_voice_id
|
||||
self.file_size = file_size
|
||||
|
||||
|
||||
def generate_audio(
|
||||
text: str,
|
||||
voice_id: str = "Wise_Woman",
|
||||
@@ -331,3 +347,380 @@ def generate_audio(
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def clone_voice(
|
||||
audio_bytes: bytes,
|
||||
custom_voice_id: str,
|
||||
model: str = "speech-02-hd",
|
||||
*,
|
||||
audio_mime_type: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
need_noise_reduction: bool = False,
|
||||
need_volume_normalization: bool = False,
|
||||
accuracy: float = 0.7,
|
||||
language_boost: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> VoiceCloneResult:
|
||||
try:
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
if not audio_bytes or not isinstance(audio_bytes, (bytes, bytearray)) or len(audio_bytes) == 0:
|
||||
raise ValueError("Audio is required and cannot be empty")
|
||||
|
||||
if len(audio_bytes) > 15 * 1024 * 1024:
|
||||
raise ValueError("Audio file too large. Maximum is 15MB.")
|
||||
|
||||
if not custom_voice_id or not isinstance(custom_voice_id, str):
|
||||
raise ValueError("custom_voice_id is required")
|
||||
custom_voice_id = custom_voice_id.strip()
|
||||
if len(custom_voice_id) < 8:
|
||||
raise ValueError("custom_voice_id must be at least 8 characters long")
|
||||
if not custom_voice_id[0].isalpha():
|
||||
raise ValueError("custom_voice_id must start with a letter")
|
||||
if not any(c.isalpha() for c in custom_voice_id) or not any(c.isdigit() for c in custom_voice_id):
|
||||
raise ValueError("custom_voice_id must include both letters and numbers")
|
||||
|
||||
voice_clone_cost = 0.5
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
tokens_requested=1,
|
||||
actual_provider_name="wavespeed",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": message,
|
||||
"message": message,
|
||||
"provider": "wavespeed",
|
||||
"usage_info": usage_info if usage_info else {},
|
||||
},
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
preview_audio_bytes = client.voice_clone(
|
||||
audio_bytes=bytes(audio_bytes),
|
||||
custom_voice_id=custom_voice_id,
|
||||
model=model,
|
||||
audio_mime_type=audio_mime_type or "audio/wav",
|
||||
text=text,
|
||||
need_noise_reduction=need_noise_reduction,
|
||||
need_volume_normalization=need_volume_normalization,
|
||||
accuracy=accuracy,
|
||||
language_boost=language_boost,
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text as sql_text
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
current_calls_before = getattr(summary, "audio_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + voice_clone_cost
|
||||
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET audio_calls = :new_calls,
|
||||
audio_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
"new_calls": new_calls,
|
||||
"new_cost": new_cost,
|
||||
"user_id": user_id,
|
||||
"period": current_period
|
||||
})
|
||||
|
||||
summary.total_cost = (summary.total_cost or 0.0) + voice_clone_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.AUDIO,
|
||||
model_name="minimax/voice-clone",
|
||||
endpoint="/audio-generation/wavespeed/voice-clone",
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
endpoint="/audio-generation/wavespeed/voice-clone",
|
||||
method="POST",
|
||||
model_used="minimax/voice-clone",
|
||||
actual_provider_name=actual_provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=voice_clone_cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=len(audio_bytes),
|
||||
response_size=len(preview_audio_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
db_track.commit()
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Voice Clone
|
||||
├─ User: {user_id}
|
||||
├─ Provider: wavespeed
|
||||
├─ Model: minimax/voice-clone
|
||||
├─ Voice ID: {custom_voice_id}
|
||||
├─ Calls: {current_calls_before} → {new_calls}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
except Exception as track_error:
|
||||
logger.error(f"[voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return VoiceCloneResult(
|
||||
preview_audio_bytes=preview_audio_bytes,
|
||||
provider="wavespeed",
|
||||
model=f"minimax/voice-clone:{model}",
|
||||
custom_voice_id=custom_voice_id,
|
||||
file_size=len(preview_audio_bytes),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[voice_clone] Error cloning voice: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Voice cloning failed",
|
||||
"message": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def qwen3_voice_clone(
|
||||
audio_bytes: bytes,
|
||||
text: str,
|
||||
*,
|
||||
reference_text: Optional[str] = None,
|
||||
language: str = "auto",
|
||||
audio_mime_type: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> VoiceCloneResult:
|
||||
try:
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
if not audio_bytes or not isinstance(audio_bytes, (bytes, bytearray)) or len(audio_bytes) == 0:
|
||||
raise ValueError("Audio is required and cannot be empty")
|
||||
|
||||
if len(audio_bytes) > 15 * 1024 * 1024:
|
||||
raise ValueError("Audio file too large. Maximum is 15MB.")
|
||||
|
||||
if not text or not isinstance(text, str) or len(text.strip()) == 0:
|
||||
raise ValueError("Text is required and cannot be empty")
|
||||
text = text.strip()
|
||||
if len(text) > 4000:
|
||||
raise ValueError("Text too long. Please keep it under 4000 characters.")
|
||||
|
||||
char_count = len(text)
|
||||
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
tokens_requested=char_count,
|
||||
actual_provider_name="wavespeed",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": message,
|
||||
"message": message,
|
||||
"provider": "wavespeed",
|
||||
"usage_info": usage_info if usage_info else {},
|
||||
},
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
preview_audio_bytes = client.qwen3_voice_clone(
|
||||
audio_bytes=bytes(audio_bytes),
|
||||
text=text,
|
||||
audio_mime_type=audio_mime_type or "audio/wav",
|
||||
language=language or "auto",
|
||||
reference_text=reference_text,
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text as sql_text
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
current_calls_before = getattr(summary, "audio_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + float(estimated_cost)
|
||||
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET audio_calls = :new_calls,
|
||||
audio_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
"new_calls": new_calls,
|
||||
"new_cost": new_cost,
|
||||
"user_id": user_id,
|
||||
"period": current_period
|
||||
})
|
||||
|
||||
summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.AUDIO,
|
||||
model_name="wavespeed-ai/qwen3-tts/voice-clone",
|
||||
endpoint="/audio-generation/wavespeed/qwen3-tts/voice-clone",
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
endpoint="/audio-generation/wavespeed/qwen3-tts/voice-clone",
|
||||
method="POST",
|
||||
model_used="wavespeed-ai/qwen3-tts/voice-clone",
|
||||
actual_provider_name=actual_provider,
|
||||
tokens_input=char_count,
|
||||
tokens_output=0,
|
||||
tokens_total=char_count,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=float(estimated_cost),
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=len(audio_bytes) + len(text.encode("utf-8")),
|
||||
response_size=len(preview_audio_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
db_track.commit()
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Qwen3 Voice Clone
|
||||
├─ User: {user_id}
|
||||
├─ Provider: wavespeed
|
||||
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
|
||||
├─ Calls: {current_calls_before} → {new_calls}
|
||||
├─ Text chars: {char_count}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
except Exception as track_error:
|
||||
logger.error(f"[qwen3_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[qwen3_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return VoiceCloneResult(
|
||||
preview_audio_bytes=preview_audio_bytes,
|
||||
provider="wavespeed",
|
||||
model="wavespeed-ai/qwen3-tts/voice-clone",
|
||||
custom_voice_id="",
|
||||
file_size=len(preview_audio_bytes),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[qwen3_voice_clone] Error cloning voice: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Qwen3 voice cloning failed",
|
||||
"message": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -2,8 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
@@ -104,13 +106,13 @@ def _validate_image_operation(
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
@@ -162,8 +164,8 @@ def _track_image_operation_usage(
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
@@ -706,3 +708,65 @@ def generate_face_swap(
|
||||
return result
|
||||
|
||||
|
||||
async def generate_image_with_provider(
|
||||
prompt: str,
|
||||
user_id: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Async wrapper for generate_image to support step4_asset_routes.
|
||||
"""
|
||||
# Construct options from kwargs
|
||||
options = kwargs.copy()
|
||||
|
||||
try:
|
||||
# Run in threadpool since generate_image is blocking
|
||||
result = await run_in_threadpool(
|
||||
generate_image,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
image_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_base64": image_base64,
|
||||
"image_url": None,
|
||||
"error": None,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_with_provider: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Enhance image prompt using LLM.
|
||||
Placeholder implementation.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
async def generate_image_variation(
|
||||
image: Any,
|
||||
prompt: str,
|
||||
user_id: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate variation of an existing image.
|
||||
Placeholder implementation.
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Not implemented yet"
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -119,11 +119,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.error(f"[llm_text_gen] Could not get database session for user {user_id}")
|
||||
raise RuntimeError("Database connection failed")
|
||||
try:
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
@@ -257,7 +260,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
@@ -658,7 +661,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
|
||||
Reference in New Issue
Block a user