AI Researcher and Video Studio implementation complete

This commit is contained in:
ajaysi
2026-01-05 15:49:51 +05:30
parent b134e9dc7e
commit 0b63ae7fc1
200 changed files with 39535 additions and 1375 deletions

View File

@@ -12,7 +12,7 @@ from datetime import datetime
from services.database import get_db
from middleware.auth_middleware import get_current_user
from services.content_asset_service import ContentAssetService
from models.content_asset_models import AssetType, AssetSource
from models.content_asset_models import AssetType, AssetSource, AssetCollection
router = APIRouter(prefix="/api/content-assets", tags=["Content Assets"])
@@ -62,6 +62,11 @@ async def get_assets(
search: Optional[str] = Query(None, description="Search query"),
tags: Optional[str] = Query(None, description="Comma-separated tags"),
favorites_only: bool = Query(False, description="Only favorites"),
collection_id: Optional[int] = Query(None, description="Filter by collection ID"),
date_from: Optional[str] = Query(None, description="Filter from date (ISO format)"),
date_to: Optional[str] = Query(None, description="Filter to date (ISO format)"),
sort_by: str = Query("created_at", description="Sort by: created_at, updated_at, cost, file_size, title"),
sort_order: str = Query("desc", description="Sort order: asc or desc"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
@@ -95,6 +100,29 @@ async def get_assets(
if tags:
tags_list = [tag.strip() for tag in tags.split(",")]
# Parse date filters
date_from_obj = None
if date_from:
try:
date_from_obj = datetime.fromisoformat(date_from.replace('Z', '+00:00'))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date_from format. Use ISO format.")
date_to_obj = None
if date_to:
try:
date_to_obj = datetime.fromisoformat(date_to.replace('Z', '+00:00'))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date_to format. Use ISO format.")
# Validate sort parameters
valid_sort_by = ["created_at", "updated_at", "cost", "file_size", "title"]
if sort_by not in valid_sort_by:
raise HTTPException(status_code=400, detail=f"Invalid sort_by. Must be one of: {', '.join(valid_sort_by)}")
if sort_order not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Invalid sort_order. Must be 'asc' or 'desc'")
assets, total = service.get_user_assets(
user_id=user_id,
asset_type=asset_type_enum,
@@ -102,6 +130,11 @@ async def get_assets(
search_query=search,
tags=tags_list,
favorites_only=favorites_only,
collection_id=collection_id,
date_from=date_from_obj,
date_to=date_to_obj,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset,
)
@@ -330,3 +363,303 @@ async def get_statistics(
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching statistics: {str(e)}")
# ==================== Collection Endpoints ====================
class CollectionResponse(BaseModel):
"""Response model for collection data."""
id: int
user_id: str
name: str
description: Optional[str] = None
is_public: bool = False
cover_asset_id: Optional[int] = None
asset_count: int = 0
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class CollectionListResponse(BaseModel):
"""Response model for collection list."""
collections: List[CollectionResponse]
total: int
limit: int
offset: int
class CollectionCreateRequest(BaseModel):
"""Request model for creating a collection."""
name: str = Field(..., description="Collection name")
description: Optional[str] = Field(None, description="Collection description")
is_public: bool = Field(False, description="Whether collection is public")
class CollectionUpdateRequest(BaseModel):
"""Request model for updating a collection."""
name: Optional[str] = None
description: Optional[str] = None
is_public: Optional[bool] = None
cover_asset_id: Optional[int] = None
@router.post("/collections", response_model=CollectionResponse)
async def create_collection(
collection_data: CollectionCreateRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create a new asset collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
collection = service.create_collection(
user_id=user_id,
name=collection_data.name,
description=collection_data.description,
is_public=collection_data.is_public,
)
# Get asset count
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
asset_count = len(assets)
response = CollectionResponse.model_validate(collection)
response.asset_count = asset_count
return response
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}")
@router.get("/collections", response_model=CollectionListResponse)
async def get_collections(
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get user's collections."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
collections, total = service.get_user_collections(user_id, limit=limit, offset=offset)
# Get asset counts for each collection
collection_responses = []
for collection in collections:
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
response = CollectionResponse.model_validate(collection)
response.asset_count = len(assets)
collection_responses.append(response)
return CollectionListResponse(
collections=collection_responses,
total=total,
limit=limit,
offset=offset,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching collections: {str(e)}")
@router.get("/collections/{collection_id}", response_model=CollectionResponse)
async def get_collection(
collection_id: int,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get a specific collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
collection = service.get_collection_by_id(collection_id, user_id)
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
response = CollectionResponse.model_validate(collection)
response.asset_count = len(assets)
return response
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching collection: {str(e)}")
@router.put("/collections/{collection_id}", response_model=CollectionResponse)
async def update_collection(
collection_id: int,
update_data: CollectionUpdateRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update collection metadata."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
collection = service.update_collection(
collection_id=collection_id,
user_id=user_id,
name=update_data.name,
description=update_data.description,
is_public=update_data.is_public,
cover_asset_id=update_data.cover_asset_id,
)
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
response = CollectionResponse.model_validate(collection)
response.asset_count = len(assets)
return response
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating collection: {str(e)}")
@router.delete("/collections/{collection_id}", response_model=Dict[str, Any])
async def delete_collection(
collection_id: int,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Delete a collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
success = service.delete_collection(collection_id, user_id)
if not success:
raise HTTPException(status_code=404, detail="Collection not found")
return {"collection_id": collection_id, "deleted": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
@router.get("/collections/{collection_id}/assets", response_model=AssetListResponse)
async def get_collection_assets(
collection_id: int,
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get all assets in a collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
collection = service.get_collection_by_id(collection_id, user_id)
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
assets, total = service.get_collection_assets(collection_id, user_id, limit=limit, offset=offset)
return AssetListResponse(
assets=[AssetResponse.model_validate(asset) for asset in assets],
total=total,
limit=limit,
offset=offset,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching collection assets: {str(e)}")
class CollectionAssetsRequest(BaseModel):
"""Request model for adding/removing assets from collection."""
asset_ids: List[int] = Field(..., description="List of asset IDs")
@router.post("/collections/{collection_id}/assets", response_model=Dict[str, Any])
async def add_assets_to_collection(
collection_id: int,
request: CollectionAssetsRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Add assets to a collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
count = service.add_assets_to_collection(collection_id, user_id, request.asset_ids)
return {
"collection_id": collection_id,
"assets_added": count,
"asset_ids": request.asset_ids,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error adding assets to collection: {str(e)}")
@router.delete("/collections/{collection_id}/assets", response_model=Dict[str, Any])
async def remove_assets_from_collection(
collection_id: int,
request: CollectionAssetsRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Remove assets from a collection."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
count = service.remove_assets_from_collection(collection_id, user_id, request.asset_ids)
return {
"collection_id": collection_id,
"assets_removed": count,
"asset_ids": request.asset_ids,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error removing assets from collection: {str(e)}")

View File

@@ -19,6 +19,7 @@ from typing import Optional, List, Dict, Any
from loguru import logger
import uuid
import asyncio
from models.research_intent_models import TrendAnalysis
from services.database import get_db
from services.research.core import (
@@ -379,7 +380,7 @@ class AnalyzeIntentRequest(BaseModel):
class AnalyzeIntentResponse(BaseModel):
"""Response from intent analysis."""
"""Response from intent analysis with optimized provider parameters."""
success: bool
intent: Dict[str, Any]
analysis_summary: str
@@ -387,7 +388,16 @@ class AnalyzeIntentResponse(BaseModel):
suggested_keywords: List[str]
suggested_angles: List[str]
quick_options: List[Dict[str, Any]]
confidence_reason: Optional[str] = None
great_example: Optional[str] = None
error_message: Optional[str] = None
# Unified: Optimized provider parameters based on intent
optimized_config: Optional[Dict[str, Any]] = None # Provider settings auto-configured from intent
recommended_provider: Optional[str] = None # Best provider for this intent (exa, tavily, google)
# Google Trends configuration (if trends in deliverables)
trends_config: Optional[Dict[str, Any]] = None # Trends keywords and settings with justifications
class IntentDrivenResearchRequest(BaseModel):
@@ -406,6 +416,9 @@ class IntentDrivenResearchRequest(BaseModel):
include_domains: List[str] = Field(default_factory=list)
exclude_domains: List[str] = Field(default_factory=list)
# Google Trends configuration (from intent analysis)
trends_config: Optional[Dict[str, Any]] = None # Trends keywords and settings
# Skip intent inference (for re-runs with same intent)
skip_inference: bool = False
@@ -445,6 +458,9 @@ class IntentDrivenResearchResponse(BaseModel):
# The inferred/confirmed intent
intent: Optional[Dict[str, Any]] = None
# Google Trends data (if trends were analyzed)
google_trends_data: Optional[Dict[str, Any]] = None
# Error handling
error_message: Optional[str] = None
@@ -480,14 +496,14 @@ async def analyze_research_intent(
if request.use_persona or request.use_competitor_data:
from services.research.research_persona_service import ResearchPersonaService
from services.onboarding_service import OnboardingService
from services.onboarding.database_service import OnboardingDatabaseService
from sqlalchemy.orm import Session
# Get database session
db = next(get_db())
try:
persona_service = ResearchPersonaService(db)
onboarding_service = OnboardingService()
onboarding_service = OnboardingDatabaseService(db=db)
if request.use_persona:
research_persona = persona_service.get_or_generate(user_id)
@@ -497,37 +513,91 @@ async def analyze_research_intent(
finally:
db.close()
# Infer intent
intent_service = ResearchIntentInference()
response = await intent_service.infer_intent(
# Use Unified Research Analyzer (single AI call for intent + queries + params)
from services.research.intent.unified_research_analyzer import UnifiedResearchAnalyzer
analyzer = UnifiedResearchAnalyzer()
unified_result = await analyzer.analyze(
user_input=request.user_input,
keywords=request.keywords,
research_persona=research_persona,
competitor_data=competitor_data,
industry=research_persona.default_industry if research_persona else None,
target_audience=research_persona.default_target_audience if research_persona else None,
user_id=user_id,
)
# Generate targeted queries
query_generator = IntentQueryGenerator()
query_result = await query_generator.generate_queries(
intent=response.intent,
research_persona=research_persona,
)
if not unified_result.get("success", False):
logger.warning("Unified analysis failed, using fallback")
# Update response with queries
response.suggested_queries = [q.dict() for q in query_result.get("queries", [])]
response.suggested_keywords = query_result.get("enhanced_keywords", [])
response.suggested_angles = query_result.get("research_angles", [])
# Extract results
intent = unified_result.get("intent")
queries = unified_result.get("queries", [])
exa_config = unified_result.get("exa_config", {})
tavily_config = unified_result.get("tavily_config", {})
trends_config = unified_result.get("trends_config", {}) # NEW: Google Trends config
# Build optimized config with AI-driven justifications
optimized_config = {
"provider": unified_result.get("recommended_provider", "exa"),
"provider_justification": unified_result.get("provider_justification", ""),
# Exa settings with justifications
"exa_type": exa_config.get("type", "auto"),
"exa_type_justification": exa_config.get("type_justification", ""),
"exa_category": exa_config.get("category"),
"exa_category_justification": exa_config.get("category_justification", ""),
"exa_include_domains": exa_config.get("includeDomains", []),
"exa_include_domains_justification": exa_config.get("includeDomains_justification", ""),
"exa_num_results": exa_config.get("numResults", 10),
"exa_num_results_justification": exa_config.get("numResults_justification", ""),
"exa_date_filter": exa_config.get("startPublishedDate"),
"exa_date_justification": exa_config.get("date_justification", ""),
"exa_highlights": exa_config.get("highlights", True),
"exa_highlights_justification": exa_config.get("highlights_justification", ""),
"exa_context": exa_config.get("context", True),
"exa_context_justification": exa_config.get("context_justification", ""),
# Tavily settings with justifications
"tavily_topic": tavily_config.get("topic", "general"),
"tavily_topic_justification": tavily_config.get("topic_justification", ""),
"tavily_search_depth": tavily_config.get("search_depth", "advanced"),
"tavily_search_depth_justification": tavily_config.get("search_depth_justification", ""),
"tavily_include_answer": tavily_config.get("include_answer", True),
"tavily_include_answer_justification": tavily_config.get("include_answer_justification", ""),
"tavily_time_range": tavily_config.get("time_range"),
"tavily_time_range_justification": tavily_config.get("time_range_justification", ""),
"tavily_max_results": tavily_config.get("max_results", 10),
"tavily_max_results_justification": tavily_config.get("max_results_justification", ""),
"tavily_raw_content": tavily_config.get("include_raw_content", "markdown"),
"tavily_raw_content_justification": tavily_config.get("include_raw_content_justification", ""),
}
# Build trends config response (if enabled)
trends_config_response = None
if trends_config.get("enabled", False):
trends_config_response = {
"enabled": True,
"keywords": trends_config.get("keywords", []),
"keywords_justification": trends_config.get("keywords_justification", ""),
"timeframe": trends_config.get("timeframe", "today 12-m"),
"timeframe_justification": trends_config.get("timeframe_justification", ""),
"geo": trends_config.get("geo", "US"),
"geo_justification": trends_config.get("geo_justification", ""),
"expected_insights": trends_config.get("expected_insights", []),
}
return AnalyzeIntentResponse(
success=True,
intent=response.intent.dict(),
analysis_summary=response.analysis_summary,
suggested_queries=response.suggested_queries,
suggested_keywords=response.suggested_keywords,
suggested_angles=response.suggested_angles,
quick_options=response.quick_options,
intent=intent.dict() if hasattr(intent, 'dict') else intent,
analysis_summary=unified_result.get("analysis_summary", ""),
suggested_queries=[q.dict() if hasattr(q, 'dict') else q for q in queries],
suggested_keywords=unified_result.get("enhanced_keywords", []),
suggested_angles=unified_result.get("research_angles", []),
quick_options=[], # Deprecated in unified approach
confidence_reason=intent.confidence_reason if hasattr(intent, 'confidence_reason') else "",
great_example=intent.great_example if hasattr(intent, 'great_example') else "",
optimized_config=optimized_config,
recommended_provider=unified_result.get("recommended_provider", "exa"),
trends_config=trends_config_response, # NEW: Google Trends configuration
)
except Exception as e:
@@ -540,6 +610,8 @@ async def analyze_research_intent(
suggested_keywords=[],
suggested_angles=[],
quick_options=[],
confidence_reason=None,
great_example=None,
error_message=str(e),
)
@@ -591,6 +663,7 @@ async def execute_intent_driven_research(
intent_response = await intent_service.infer_intent(
user_input=request.user_input,
research_persona=research_persona,
user_id=user_id,
)
intent = intent_response.intent
else:
@@ -613,6 +686,7 @@ async def execute_intent_driven_research(
query_result = await query_generator.generate_queries(
intent=intent,
research_persona=research_persona,
user_id=user_id,
)
queries = query_result.get("queries", [])
@@ -648,8 +722,35 @@ async def execute_intent_driven_research(
exclude_domains=request.exclude_domains,
)
# Execute research
raw_result = await engine.research(context)
# Execute research and trends in parallel
research_task = asyncio.create_task(engine.research(context))
# Execute Google Trends analysis in parallel (if enabled)
trends_task = None
trends_data = None
if request.trends_config and request.trends_config.get("enabled"):
from services.research.trends.google_trends_service import GoogleTrendsService
trends_service = GoogleTrendsService()
trends_task = asyncio.create_task(
trends_service.analyze_trends(
keywords=request.trends_config.get("keywords", []),
timeframe=request.trends_config.get("timeframe", "today 12-m"),
geo=request.trends_config.get("geo", "US"),
user_id=user_id
)
)
# Wait for research to complete
raw_result = await research_task
# Wait for trends if it was started
if trends_task:
try:
trends_data = await trends_task
logger.info(f"Google Trends data fetched: {len(trends_data.get('interest_over_time', []))} time points")
except Exception as e:
logger.error(f"Google Trends analysis failed: {e}")
trends_data = None
# Analyze results using intent-aware analyzer
analyzer = IntentAwareAnalyzer()
@@ -661,8 +762,13 @@ async def execute_intent_driven_research(
},
intent=intent,
research_persona=research_persona,
user_id=user_id, # Required for subscription checking
)
# Merge Google Trends data into trends analysis
if trends_data and analyzed_result.trends:
analyzed_result = _merge_trends_data(analyzed_result, trends_data)
# Build response
return IntentDrivenResearchResponse(
success=True,
@@ -687,6 +793,7 @@ async def execute_intent_driven_research(
gaps_identified=analyzed_result.gaps_identified,
follow_up_queries=analyzed_result.follow_up_queries,
intent=intent.dict(),
google_trends_data=trends_data, # Include Google Trends data in response
)
finally:
@@ -737,3 +844,67 @@ def _map_provider_to_preference(provider: str) -> ProviderPreference:
}
return mapping.get(provider, ProviderPreference.AUTO)
def _merge_trends_data(
analyzed_result: Any,
trends_data: Dict[str, Any]
) -> Any:
"""
Merge Google Trends data into analyzed result trends.
Enhances AI-extracted trends with Google Trends data.
"""
from services.research.intent.intent_aware_analyzer import IntentDrivenResearchResult
from models.research_intent_models import TrendAnalysis
if not analyzed_result.trends:
return analyzed_result
# Enhance each trend with Google Trends data
enhanced_trends = []
for trend in analyzed_result.trends:
# Create enhanced trend with Google Trends data
trend_dict = trend.dict() if hasattr(trend, 'dict') else trend
trend_dict["google_trends_data"] = trends_data
# Add interest score if available
if trends_data.get("interest_over_time"):
# Calculate average interest score
interest_values = []
for point in trends_data["interest_over_time"]:
for key, value in point.items():
if key not in ["date", "isPartial"] and isinstance(value, (int, float)):
interest_values.append(value)
if interest_values:
trend_dict["interest_score"] = sum(interest_values) / len(interest_values)
# Add related topics/queries
if trends_data.get("related_topics"):
top_topics = [t.get("topic_title", "") for t in trends_data["related_topics"].get("top", [])[:5]]
rising_topics = [t.get("topic_title", "") for t in trends_data["related_topics"].get("rising", [])[:5]]
trend_dict["related_topics"] = {"top": top_topics, "rising": rising_topics}
if trends_data.get("related_queries"):
top_queries = [q.get("query", "") for q in trends_data["related_queries"].get("top", [])[:5]]
rising_queries = [q.get("query", "") for q in trends_data["related_queries"].get("rising", [])[:5]]
trend_dict["related_queries"] = {"top": top_queries, "rising": rising_queries}
# Add regional interest
if trends_data.get("interest_by_region"):
regional_interest = {}
for region in trends_data["interest_by_region"][:10]: # Top 10 regions
region_name = region.get("geoName", "")
if region_name:
# Get interest value (first numeric column)
for key, value in region.items():
if key != "geoName" and isinstance(value, (int, float)):
regional_interest[region_name] = value
break
trend_dict["regional_interest"] = regional_interest
enhanced_trends.append(TrendAnalysis(**trend_dict))
# Update analyzed result with enhanced trends
analyzed_result.trends = enhanced_trends
return analyzed_result

View File

@@ -236,6 +236,56 @@ async def get_persona_defaults(
)
@router.get("/research-persona/verify")
async def verify_research_persona_exists(
current_user: Dict = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Verify if research persona exists in database (for debugging).
Returns detailed information about persona status.
"""
try:
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User not authenticated")
if not db:
raise HTTPException(status_code=500, detail="Database not available")
persona_service = ResearchPersonaService(db_session=db)
persona_data = persona_service._get_persona_data_record(user_id)
if not persona_data:
return {
"exists": False,
"reason": "No persona_data record found",
"user_id": user_id
}
has_persona = (
persona_data.research_persona is not None
and persona_data.research_persona != {}
and persona_data.research_persona != ""
and isinstance(persona_data.research_persona, dict)
and len(persona_data.research_persona) > 0
)
cache_valid = persona_service.is_cache_valid(persona_data) if has_persona else False
return {
"exists": has_persona,
"cache_valid": cache_valid,
"generated_at": persona_data.research_persona_generated_at.isoformat() if persona_data.research_persona_generated_at else None,
"persona_type": type(persona_data.research_persona).__name__ if persona_data.research_persona else None,
"persona_keys": list(persona_data.research_persona.keys()) if has_persona and isinstance(persona_data.research_persona, dict) else None,
"user_id": user_id
}
except Exception as e:
logger.error(f"[ResearchConfig] Error verifying research persona: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to verify research persona: {str(e)}")
@router.get("/research-persona")
async def get_research_persona(
current_user: Dict = Depends(get_current_user),
@@ -261,6 +311,14 @@ async def get_research_persona(
raise HTTPException(status_code=500, detail="Database not available")
persona_service = ResearchPersonaService(db_session=db)
# First check if persona exists (without generating)
existing_persona = persona_service.get_cached_only(user_id)
if existing_persona and not force_refresh:
logger.info(f"[ResearchConfig] Returning existing research persona for user {user_id} (force_refresh={force_refresh})")
return existing_persona.dict()
# Only generate if persona doesn't exist or force_refresh is True
research_persona = persona_service.get_or_generate(user_id, force_refresh=force_refresh)
if not research_persona:
@@ -286,8 +344,9 @@ async def get_research_config(
):
"""
Get complete research configuration including provider availability and persona defaults.
Requires authentication - user must be logged in.
"""
user_id = None
try:
user_id = str(current_user.get('id'))
logger.info(f"[ResearchConfig] Starting get_research_config for user {user_id}")
@@ -378,19 +437,30 @@ async def get_research_config(
# Get research persona (optional, may not exist for all users)
# CRITICAL: Use get_cached_only() to avoid triggering rate limit checks
# Only return persona if it's already cached - don't generate on config load
# Returns persona if it exists in database, regardless of cache validity
research_persona = None
persona_scheduled = False
try:
logger.debug(f"[ResearchConfig] Getting cached research persona for user {user_id}")
logger.info(f"[ResearchConfig] 🔍 Getting research persona for user {user_id}")
persona_service = ResearchPersonaService(db_session=db)
research_persona = persona_service.get_cached_only(user_id)
logger.info(
f"[ResearchConfig] Research persona check for user {user_id}: "
f"persona_exists={research_persona is not None}, "
f"onboarding_completed={onboarding_completed}"
)
if research_persona:
# Check cache validity for logging
persona_data_record = persona_service._get_persona_data_record(user_id)
cache_valid = persona_data_record and persona_service.is_cache_valid(persona_data_record) if persona_data_record else False
logger.info(
f"[ResearchConfig] ✅ Research persona FOUND for user {user_id}: "
f"exists=True, cache_valid={cache_valid}, "
f"industry={research_persona.default_industry}, "
f"onboarding_completed={onboarding_completed}"
)
else:
logger.warning(
f"[ResearchConfig] ⚠️ Research persona NOT FOUND for user {user_id}: "
f"persona_exists=False, "
f"onboarding_completed={onboarding_completed}"
)
# If onboarding is completed but persona doesn't exist, schedule generation
if onboarding_completed and not research_persona:
@@ -412,13 +482,24 @@ async def get_research_config(
# get_cached_only() never raises HTTPException, but catch any unexpected errors
logger.warning(f"[ResearchConfig] Could not load cached research persona for user {user_id}: {e}", exc_info=True)
# FastAPI will automatically serialize the ResearchPersona Pydantic model
# Convert ResearchPersona to dict for proper serialization
# FastAPI should handle Pydantic models, but explicit conversion ensures compatibility
research_persona_dict = None
if research_persona:
try:
research_persona_dict = research_persona.dict() if hasattr(research_persona, 'dict') else research_persona
logger.debug(f"[ResearchConfig] Converted research persona to dict for user {user_id}")
except Exception as e:
logger.warning(f"[ResearchConfig] Failed to convert research persona to dict: {e}")
research_persona_dict = None
# FastAPI will automatically serialize the ResearchPersona dict
# If there's a serialization issue, we catch it and log it
try:
response = ResearchConfigResponse(
provider_availability=provider_availability,
persona_defaults=persona_defaults,
research_persona=research_persona,
research_persona=research_persona_dict,
onboarding_completed=onboarding_completed,
persona_scheduled=persona_scheduled
)
@@ -434,9 +515,10 @@ async def get_research_config(
)
logger.info(
f"[ResearchConfig] Response for user {user_id}: "
f"[ResearchConfig] 📤 Response for user {user_id}: "
f"onboarding_completed={onboarding_completed}, "
f"persona_exists={research_persona is not None}, "
f"persona_dict_exists={research_persona_dict is not None}, "
f"persona_scheduled={persona_scheduled}"
)

View File

@@ -230,6 +230,14 @@ class ResearchIntent(BaseModel):
le=1.0,
description="Confidence in the intent inference"
)
confidence_reason: Optional[str] = Field(
None,
description="Reason for the confidence level"
)
great_example: Optional[str] = Field(
None,
description="Example of what a great input would look like (if confidence is low)"
)
needs_clarification: bool = Field(
False,
description="True if AI is uncertain and needs user clarification"
@@ -281,6 +289,8 @@ class IntentInferenceResponse(BaseModel):
default_factory=list,
description="Quick options for user to confirm/modify intent"
)
confidence_reason: Optional[str] = Field(None, description="Reason for confidence level")
great_example: Optional[str] = Field(None, description="Example of great input (if confidence is low)")
# ============================================================================

View File

@@ -0,0 +1,51 @@
"""
Google Trends Data Models
Pydantic models for Google Trends API responses.
"""
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
from datetime import datetime
class GoogleTrendsData(BaseModel):
"""Structured Google Trends data."""
interest_over_time: List[Dict[str, Any]] = Field(default_factory=list, description="Time series interest data")
interest_by_region: List[Dict[str, Any]] = Field(default_factory=list, description="Geographic interest data")
related_topics: Dict[str, List[Dict[str, Any]]] = Field(
default_factory=dict,
description="Related topics: {top: [...], rising: [...]}"
)
related_queries: Dict[str, List[Dict[str, Any]]] = Field(
default_factory=dict,
description="Related queries: {top: [...], rising: [...]}"
)
trending_searches: Optional[List[str]] = Field(None, description="Current trending searches")
timeframe: str = Field(..., description="Timeframe used (e.g., 'today 12-m')")
geo: str = Field(..., description="Geographic region (e.g., 'US', 'GB')")
keywords: List[str] = Field(..., description="Keywords analyzed")
timestamp: datetime = Field(default_factory=datetime.utcnow, description="When data was fetched")
class TrendsConfig(BaseModel):
"""Google Trends configuration with AI-driven justifications."""
enabled: bool = Field(True, description="Whether trends analysis is enabled")
keywords: List[str] = Field(..., description="AI-optimized keywords for trends analysis")
keywords_justification: str = Field(..., description="Why these keywords were chosen")
timeframe: str = Field(default="today 12-m", description="Timeframe: 'today 1-y', 'today 12-m', 'all', etc.")
timeframe_justification: str = Field(..., description="Why this timeframe was chosen")
geo: str = Field(default="US", description="Country code (e.g., 'US', 'GB', 'IN')")
geo_justification: str = Field(..., description="Why this geographic region was chosen")
expected_insights: List[str] = Field(
default_factory=list,
description="What insights trends will uncover for content generation"
)
class TrendsAnalysisResponse(BaseModel):
"""Response from trends analysis endpoint."""
success: bool
data: Optional[GoogleTrendsData] = None
error_message: Optional[str] = None
cached: bool = Field(False, description="Whether data was served from cache")

View File

@@ -3,7 +3,7 @@
import base64
from pathlib import Path
from typing import Optional, List, Dict, Any, Literal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
@@ -16,6 +16,7 @@ from services.image_studio import (
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from services.image_studio.face_swap_service import FaceSwapStudioRequest
from services.image_studio.upscale_service import UpscaleStudioRequest
from services.image_studio.templates import Platform, TemplateCategory
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
@@ -97,6 +98,27 @@ class EditImageRequest(BaseModel):
)
class EditModelsResponse(BaseModel):
"""Response model for available editing models."""
models: List[Dict[str, Any]]
total: int
class EditModelRecommendationRequest(BaseModel):
"""Request model for model recommendations."""
operation: str
image_resolution: Optional[Dict[str, int]] = None
user_tier: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
class EditModelRecommendationResponse(BaseModel):
"""Response model for model recommendations."""
recommended_model: str
reason: str
alternatives: List[Dict[str, Any]]
class EditImageResponse(BaseModel):
success: bool
operation: str
@@ -512,6 +534,173 @@ async def get_edit_operations(
raise HTTPException(status_code=500, detail="Failed to load edit operations")
@router.get("/edit/models", response_model=EditModelsResponse, summary="List available editing models")
async def get_edit_models(
operation: Optional[str] = None,
tier: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get available WaveSpeed editing models with metadata.
Query Parameters:
- operation: Filter by operation type (e.g., "general_edit")
- tier: Filter by tier ("budget", "mid", "premium")
"""
try:
result = studio_manager.get_edit_models(operation=operation, tier=tier)
return EditModelsResponse(**result)
except Exception as e:
logger.error(f"[Edit Models] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to load editing models")
@router.post("/edit/recommend", response_model=EditModelRecommendationResponse, summary="Get model recommendation")
async def recommend_edit_model(
request: EditModelRecommendationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get recommended editing model based on operation, image resolution, and user preferences.
Auto-detects best model when user doesn't specify one.
"""
try:
# Get user tier from current_user if available
user_tier = request.user_tier
if not user_tier and current_user:
# Try to extract from user data (adjust based on your user model)
user_tier = current_user.get("tier") or current_user.get("subscription_tier")
result = studio_manager.recommend_edit_model(
operation=request.operation,
image_resolution=request.image_resolution,
user_tier=user_tier,
preferences=request.preferences,
)
return EditModelRecommendationResponse(**result)
except Exception as e:
logger.error(f"[Edit Recommend] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
# ====================
# FACE SWAP STUDIO ENDPOINTS
# ====================
class FaceSwapRequest(BaseModel):
base_image_base64: str
face_image_base64: str
model: Optional[str] = None
target_face_index: Optional[int] = None
target_gender: Optional[str] = None
options: Optional[Dict[str, Any]] = None
class FaceSwapResponse(BaseModel):
success: bool
image_base64: str
width: int
height: int
provider: str
model: str
metadata: Dict[str, Any]
class FaceSwapModelsResponse(BaseModel):
"""Response model for available face swap models."""
models: List[Dict[str, Any]]
total: int
class FaceSwapModelRecommendationRequest(BaseModel):
"""Request model for face swap model recommendations."""
base_image_resolution: Optional[Dict[str, int]] = None
face_image_resolution: Optional[Dict[str, int]] = None
user_tier: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
class FaceSwapModelRecommendationResponse(BaseModel):
"""Response model for face swap model recommendations."""
recommended_model: str
reason: str
alternatives: List[Dict[str, Any]]
@router.post("/face-swap/process", response_model=FaceSwapResponse, summary="Process Face Swap")
async def process_face_swap(
request: FaceSwapRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Process face swap request with auto-detection and model selection."""
try:
user_id = _require_user_id(current_user, "face swap")
face_swap_request = FaceSwapStudioRequest(
base_image_base64=request.base_image_base64,
face_image_base64=request.face_image_base64,
model=request.model,
target_face_index=request.target_face_index,
target_gender=request.target_gender,
options=request.options,
)
result = await studio_manager.face_swap(face_swap_request, user_id=user_id)
return FaceSwapResponse(**result)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Face Swap] ❌ Error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Face swap failed: {e}")
@router.get("/face-swap/models", response_model=FaceSwapModelsResponse, summary="List available face swap models")
async def get_face_swap_models(
tier: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get available WaveSpeed face swap models with metadata.
Query Parameters:
- tier: Filter by tier ("budget", "mid", "premium")
"""
try:
result = studio_manager.get_face_swap_models(tier=tier)
return FaceSwapModelsResponse(**result)
except Exception as e:
logger.error(f"[Face Swap Models] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to load face swap models")
@router.post("/face-swap/recommend", response_model=FaceSwapModelRecommendationResponse, summary="Get face swap model recommendation")
async def recommend_face_swap_model(
request: FaceSwapModelRecommendationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get recommended face swap model based on image resolutions and user preferences.
Auto-detects best model when user doesn't specify one.
"""
try:
# Get user tier from current_user if available
user_tier = request.user_tier
if not user_tier and current_user:
user_tier = current_user.get("tier") or current_user.get("subscription_tier")
result = studio_manager.recommend_face_swap_model(
base_image_resolution=request.base_image_resolution,
face_image_resolution=request.face_image_resolution,
user_tier=user_tier,
preferences=request.preferences,
)
return FaceSwapModelRecommendationResponse(**result)
except Exception as e:
logger.error(f"[Face Swap Recommend] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
# ====================
# UPSCALE STUDIO ENDPOINTS
# ====================
@@ -1009,6 +1198,403 @@ async def serve_transform_video(
raise HTTPException(status_code=500, detail=str(e))
# ====================
# COMPRESSION STUDIO ENDPOINTS
# ====================
class CompressImageRequest(BaseModel):
"""Request payload for image compression."""
image_base64: str = Field(..., description="Image in base64 or data URL format")
quality: int = Field(85, ge=1, le=100, description="Compression quality (1-100)")
format: str = Field("jpeg", description="Output format: jpeg, png, webp")
target_size_kb: Optional[int] = Field(None, ge=10, description="Target file size in KB")
strip_metadata: bool = Field(True, description="Remove EXIF metadata")
progressive: bool = Field(True, description="Progressive JPEG encoding")
optimize: bool = Field(True, description="Optimize encoding")
class CompressImageResponse(BaseModel):
success: bool
image_base64: str
original_size_kb: float
compressed_size_kb: float
compression_ratio: float
format: str
width: int
height: int
quality_used: int
metadata_stripped: bool
class CompressBatchRequest(BaseModel):
"""Request payload for batch compression."""
images: List[CompressImageRequest] = Field(..., description="List of images to compress")
class CompressBatchResponse(BaseModel):
success: bool
results: List[CompressImageResponse]
total_images: int
successful: int
failed: int
class CompressionEstimateRequest(BaseModel):
"""Request for compression estimation."""
image_base64: str = Field(..., description="Image in base64 or data URL format")
format: str = Field("jpeg", description="Output format")
quality: int = Field(85, ge=1, le=100, description="Quality level")
class CompressionEstimateResponse(BaseModel):
original_size_kb: float
estimated_size_kb: float
estimated_reduction_percent: float
width: int
height: int
format: str
class CompressionFormatsResponse(BaseModel):
formats: List[Dict[str, Any]]
class CompressionPresetsResponse(BaseModel):
presets: List[Dict[str, Any]]
@router.post("/compress", response_model=CompressImageResponse, summary="Compress an image")
async def compress_image(
request: CompressImageRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Compress an image with specified quality and format settings.
Features:
- Quality control (1-100)
- Format conversion (JPEG, PNG, WebP)
- Target size compression
- Metadata stripping
- Progressive JPEG support
"""
try:
user_id = _require_user_id(current_user, "image compression")
logger.info(f"[Compression] Request from user {user_id}: format={request.format}, quality={request.quality}")
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
compression_request = ServiceRequest(
image_base64=request.image_base64,
quality=request.quality,
format=request.format,
target_size_kb=request.target_size_kb,
strip_metadata=request.strip_metadata,
progressive=request.progressive,
optimize=request.optimize,
)
result = await studio_manager.compress_image(compression_request, user_id=user_id)
return CompressImageResponse(
success=result.success,
image_base64=result.image_base64,
original_size_kb=result.original_size_kb,
compressed_size_kb=result.compressed_size_kb,
compression_ratio=result.compression_ratio,
format=result.format,
width=result.width,
height=result.height,
quality_used=result.quality_used,
metadata_stripped=result.metadata_stripped,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Compression] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Image compression failed: {e}")
@router.post("/compress/batch", response_model=CompressBatchResponse, summary="Compress multiple images")
async def compress_batch(
request: CompressBatchRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Compress multiple images with the same or individual settings."""
try:
user_id = _require_user_id(current_user, "batch compression")
logger.info(f"[Compression] Batch request from user {user_id}: {len(request.images)} images")
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
compression_requests = [
ServiceRequest(
image_base64=img.image_base64,
quality=img.quality,
format=img.format,
target_size_kb=img.target_size_kb,
strip_metadata=img.strip_metadata,
progressive=img.progressive,
optimize=img.optimize,
)
for img in request.images
]
results = await studio_manager.compress_batch(compression_requests, user_id=user_id)
successful = sum(1 for r in results if r.success)
failed = len(results) - successful
return CompressBatchResponse(
success=failed == 0,
results=[
CompressImageResponse(
success=r.success,
image_base64=r.image_base64,
original_size_kb=r.original_size_kb,
compressed_size_kb=r.compressed_size_kb,
compression_ratio=r.compression_ratio,
format=r.format,
width=r.width,
height=r.height,
quality_used=r.quality_used,
metadata_stripped=r.metadata_stripped,
)
for r in results
],
total_images=len(results),
successful=successful,
failed=failed,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Compression] ❌ Batch error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Batch compression failed: {e}")
@router.post("/compress/estimate", response_model=CompressionEstimateResponse, summary="Estimate compression results")
async def estimate_compression(
request: CompressionEstimateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Estimate compression results without actually compressing the image."""
try:
result = await studio_manager.estimate_compression(
request.image_base64,
request.format,
request.quality,
)
return CompressionEstimateResponse(**result)
except Exception as e:
logger.error(f"[Compression] ❌ Estimate error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Compression estimation failed: {e}")
@router.get("/compress/formats", response_model=CompressionFormatsResponse, summary="Get supported compression formats")
async def get_compression_formats(
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get list of supported compression formats with their capabilities."""
formats = studio_manager.get_compression_formats()
return CompressionFormatsResponse(formats=formats)
@router.get("/compress/presets", response_model=CompressionPresetsResponse, summary="Get compression presets")
async def get_compression_presets(
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get predefined compression presets for common use cases."""
presets = studio_manager.get_compression_presets()
return CompressionPresetsResponse(presets=presets)
# ====================
# FORMAT CONVERTER ENDPOINTS
# ====================
class ConvertFormatRequest(BaseModel):
"""Request payload for format conversion."""
image_base64: str = Field(..., description="Image in base64 or data URL format")
target_format: str = Field(..., description="Target format: png, jpeg, jpg, webp, gif, bmp, tiff")
preserve_transparency: bool = Field(True, description="Preserve transparency when possible")
quality: Optional[int] = Field(None, ge=1, le=100, description="Quality for lossy formats (1-100)")
color_space: Optional[str] = Field(None, description="Color space: sRGB, Adobe RGB")
strip_metadata: bool = Field(False, description="Remove EXIF metadata")
optimize: bool = Field(True, description="Optimize encoding")
progressive: bool = Field(True, description="Progressive JPEG encoding")
class ConvertFormatResponse(BaseModel):
success: bool
image_base64: str
original_format: str
target_format: str
original_size_kb: float
converted_size_kb: float
width: int
height: int
transparency_preserved: bool
metadata_preserved: bool
color_space: Optional[str] = None
class ConvertFormatBatchRequest(BaseModel):
"""Request payload for batch format conversion."""
images: List[ConvertFormatRequest] = Field(..., description="List of images to convert")
class ConvertFormatBatchResponse(BaseModel):
success: bool
results: List[ConvertFormatResponse]
total_images: int
successful: int
failed: int
class SupportedFormatsResponse(BaseModel):
formats: List[Dict[str, Any]]
class FormatRecommendationsResponse(BaseModel):
recommendations: List[Dict[str, Any]]
@router.post("/convert-format", response_model=ConvertFormatResponse, summary="Convert image format")
async def convert_format(
request: ConvertFormatRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Convert an image to a different format.
Features:
- Multi-format support (PNG, JPEG, WebP, GIF, BMP, TIFF)
- Transparency preservation
- Color space conversion
- Metadata handling
"""
try:
user_id = _require_user_id(current_user, "format conversion")
logger.info(f"[Format Converter] Request from user {user_id}: {request.target_format}")
from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest
conversion_request = ServiceRequest(
image_base64=request.image_base64,
target_format=request.target_format,
preserve_transparency=request.preserve_transparency,
quality=request.quality,
color_space=request.color_space,
strip_metadata=request.strip_metadata,
optimize=request.optimize,
progressive=request.progressive,
)
result = await studio_manager.convert_format(conversion_request, user_id=user_id)
return ConvertFormatResponse(
success=result.success,
image_base64=result.image_base64,
original_format=result.original_format,
target_format=result.target_format,
original_size_kb=result.original_size_kb,
converted_size_kb=result.converted_size_kb,
width=result.width,
height=result.height,
transparency_preserved=result.transparency_preserved,
metadata_preserved=result.metadata_preserved,
color_space=result.color_space,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Format Converter] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Format conversion failed: {e}")
@router.post("/convert-format/batch", response_model=ConvertFormatBatchResponse, summary="Convert multiple images")
async def convert_format_batch(
request: ConvertFormatBatchRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Convert multiple images to different formats."""
try:
user_id = _require_user_id(current_user, "batch format conversion")
logger.info(f"[Format Converter] Batch request from user {user_id}: {len(request.images)} images")
from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest
conversion_requests = [
ServiceRequest(
image_base64=img.image_base64,
target_format=img.target_format,
preserve_transparency=img.preserve_transparency,
quality=img.quality,
color_space=img.color_space,
strip_metadata=img.strip_metadata,
optimize=img.optimize,
progressive=img.progressive,
)
for img in request.images
]
results = await studio_manager.convert_format_batch(conversion_requests, user_id=user_id)
successful = sum(1 for r in results if r.success)
failed = len(results) - successful
return ConvertFormatBatchResponse(
success=failed == 0,
results=[
ConvertFormatResponse(
success=r.success,
image_base64=r.image_base64,
original_format=r.original_format,
target_format=r.target_format,
original_size_kb=r.original_size_kb,
converted_size_kb=r.converted_size_kb,
width=r.width,
height=r.height,
transparency_preserved=r.transparency_preserved,
metadata_preserved=r.metadata_preserved,
color_space=r.color_space,
)
for r in results
],
total_images=len(results),
successful=successful,
failed=failed,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Format Converter] ❌ Batch error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Batch format conversion failed: {e}")
@router.get("/convert-format/supported", response_model=SupportedFormatsResponse, summary="Get supported formats")
async def get_supported_formats(
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get list of supported conversion formats with their capabilities."""
formats = studio_manager.get_supported_formats()
return SupportedFormatsResponse(formats=formats)
@router.get("/convert-format/recommendations", response_model=FormatRecommendationsResponse, summary="Get format recommendations")
async def get_format_recommendations(
source_format: str = Query(..., description="Source format"),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get format recommendations based on source format."""
recommendations = studio_manager.get_format_recommendations(source_format)
return FormatRecommendationsResponse(recommendations=recommendations)
# ====================
# HEALTH CHECK
# ====================
@@ -1028,6 +1614,7 @@ async def health_check():
"create_studio": "available",
"templates": "available",
"providers": "available",
"compression": "available",
}
}

View File

@@ -9,6 +9,12 @@ from services.product_marketing import (
BrandDNASyncService,
AssetAuditService,
ChannelPackService,
ProductAnimationService,
ProductAnimationRequest,
ProductVideoService,
ProductVideoRequest,
ProductAvatarService,
ProductAvatarRequest,
)
from services.product_marketing.campaign_storage import CampaignStorageService
from services.product_marketing.product_image_service import ProductImageService, ProductImageRequest
@@ -268,6 +274,7 @@ async def generate_asset(
- Applies specialized marketing prompts
- Automatically tracks assets in Asset Library
- Validates subscription limits
- Updates campaign status after generation
"""
try:
user_id = _require_user_id(current_user, "asset generation")
@@ -279,6 +286,51 @@ async def generate_asset(
product_context=request.product_context,
)
# Update campaign status if asset was generated successfully
if result.get('success'):
campaign_id = request.asset_proposal.get('campaign_id')
if not campaign_id:
# Try to extract from asset_id
asset_id = request.asset_proposal.get('asset_id', '')
if asset_id and '_' in asset_id:
parts = asset_id.split('_')
phase_indicators = ['teaser', 'launch', 'nurture', 'prelaunch', 'postlaunch']
for i, part in enumerate(parts):
if part.lower() in phase_indicators and i > 0:
campaign_id = '_'.join(parts[:i])
break
if campaign_id:
try:
campaign_storage = get_campaign_storage()
campaign = campaign_storage.get_campaign(user_id, campaign_id)
if campaign:
# Update proposal status to 'generating' or 'ready'
asset_node_id = request.asset_proposal.get('asset_id', '')
if asset_node_id:
from models.product_marketing_models import CampaignProposal
from services.database import SessionLocal
db = SessionLocal()
try:
proposal = db.query(CampaignProposal).filter(
CampaignProposal.campaign_id == campaign_id,
CampaignProposal.asset_node_id == asset_node_id,
CampaignProposal.user_id == user_id
).first()
if proposal:
proposal.status = 'ready'
db.commit()
logger.info(f"[Product Marketing] ✅ Updated proposal status for {asset_node_id}")
finally:
db.close()
# Check if all assets are ready and update campaign status
# (This could be enhanced to check all proposals)
logger.info(f"[Product Marketing] ✅ Asset generated for campaign {campaign_id}")
except Exception as update_error:
logger.warning(f"[Product Marketing] ⚠️ Could not update campaign status: {str(update_error)}")
# Don't fail the request if status update fails
logger.info(f"[Product Marketing] ✅ Asset generated successfully")
return result
@@ -617,6 +669,474 @@ async def serve_product_image(
raise HTTPException(status_code=500, detail=str(e))
# ====================
# PRODUCT ANIMATION ENDPOINTS
# ====================
class ProductAnimationRequestModel(BaseModel):
"""Request for product animation."""
product_image_base64: str = Field(..., description="Base64 encoded product image")
animation_type: str = Field(..., description="Animation type: reveal, rotation, demo, lifestyle")
product_name: str = Field(..., description="Product name")
product_description: Optional[str] = Field(None, description="Product description")
resolution: str = Field(default="720p", description="Video resolution: 480p, 720p, 1080p")
duration: int = Field(default=5, description="Video duration: 5 or 10 seconds")
audio_base64: Optional[str] = Field(None, description="Optional audio for synchronization")
additional_context: Optional[str] = Field(None, description="Additional context for animation")
def get_product_animation_service() -> ProductAnimationService:
"""Get Product Animation Service instance."""
return ProductAnimationService()
@router.post("/products/animate", summary="Animate Product Image")
async def animate_product(
request: ProductAnimationRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
animation_service: ProductAnimationService = Depends(get_product_animation_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Animate a product image into a video.
This endpoint:
- Uses WAN 2.5 Image-to-Video via Transform Studio
- Supports multiple animation types (reveal, rotation, demo, lifestyle)
- Applies brand DNA for consistent styling
- Returns video URL and metadata
"""
try:
user_id = _require_user_id(current_user, "product animation")
logger.info(f"[Product Marketing] Animating product '{request.product_name}' with type '{request.animation_type}'")
# Get brand DNA for personalization
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception as brand_error:
logger.warning(f"[Product Marketing] Could not load brand DNA: {str(brand_error)}")
# Create animation request
animation_request = ProductAnimationRequest(
product_image_base64=request.product_image_base64,
animation_type=request.animation_type,
product_name=request.product_name,
product_description=request.product_description,
resolution=request.resolution,
duration=request.duration,
audio_base64=request.audio_base64,
brand_context=brand_context,
additional_context=request.additional_context,
)
# Generate animation
result = await animation_service.animate_product(animation_request, user_id)
logger.info(f"[Product Marketing] ✅ Product animation completed: cost=${result.get('cost', 0):.2f}")
return {
"success": True,
"product_name": result.get("product_name"),
"animation_type": result.get("animation_type"),
"video_url": result.get("video_url"),
"video_filename": result.get("filename"),
"cost": result.get("cost", 0.0),
"resolution": request.resolution,
"duration": request.duration,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error animating product: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Product animation failed: {str(e)}")
@router.post("/products/animate/reveal", summary="Create Product Reveal Animation")
async def create_product_reveal(
request: ProductAnimationRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
animation_service: ProductAnimationService = Depends(get_product_animation_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product reveal animation (elegant product unveiling)."""
try:
user_id = _require_user_id(current_user, "product reveal animation")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await animation_service.create_product_reveal(
product_image_base64=request.product_image_base64,
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution,
duration=request.duration,
brand_context=brand_context
)
return {
"success": True,
"animation_type": "reveal",
"video_url": result.get("video_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating reveal: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/products/animate/rotation", summary="Create Product Rotation Animation")
async def create_product_rotation(
request: ProductAnimationRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
animation_service: ProductAnimationService = Depends(get_product_animation_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create 360° product rotation animation."""
try:
user_id = _require_user_id(current_user, "product rotation animation")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await animation_service.create_product_rotation(
product_image_base64=request.product_image_base64,
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution,
duration=request.duration or 10, # Default 10s for rotation
brand_context=brand_context
)
return {
"success": True,
"animation_type": "rotation",
"video_url": result.get("video_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating rotation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/products/animate/demo", summary="Create Product Demo Animation")
async def create_product_demo_animation(
request: ProductAnimationRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
animation_service: ProductAnimationService = Depends(get_product_animation_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product demo animation (image-to-video: product in use, demonstrating features)."""
try:
user_id = _require_user_id(current_user, "product demo animation")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await animation_service.create_product_demo(
product_image_base64=request.product_image_base64,
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution,
duration=request.duration or 10, # Default 10s for demo
audio_base64=request.audio_base64,
brand_context=brand_context
)
return {
"success": True,
"animation_type": "demo",
"video_subtype": "animation", # Image-to-video
"video_url": result.get("video_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating demo animation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# PRODUCT VIDEO ENDPOINTS (Text-to-Video)
# ====================
class ProductVideoRequestModel(BaseModel):
"""Request for product demo video (text-to-video)."""
product_name: str = Field(..., description="Product name")
product_description: str = Field(..., description="Product description")
video_type: str = Field(default="demo", description="Video type: demo, storytelling, feature_highlight, launch")
resolution: str = Field(default="720p", description="Video resolution: 480p, 720p, 1080p")
duration: int = Field(default=10, description="Video duration: 5 or 10 seconds")
audio_base64: Optional[str] = Field(None, description="Optional audio for synchronization")
additional_context: Optional[str] = Field(None, description="Additional context for video")
def get_product_video_service() -> ProductVideoService:
"""Get Product Video Service instance."""
return ProductVideoService()
@router.post("/products/video/demo", summary="Create Product Demo Video (Text-to-Video)")
async def create_product_demo_video(
request: ProductVideoRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
video_service: ProductVideoService = Depends(get_product_video_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product demo video using WAN 2.5 Text-to-Video.
This endpoint:
- Uses WAN 2.5 Text-to-Video via main_video_generation
- Generates video from product description (no image required)
- Applies brand DNA for consistent styling
- Returns video URL and metadata
"""
try:
user_id = _require_user_id(current_user, "product demo video")
logger.info(f"[Product Marketing] Creating {request.video_type} video for product '{request.product_name}'")
# Get brand DNA for personalization
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception as brand_error:
logger.warning(f"[Product Marketing] Could not load brand DNA: {str(brand_error)}")
# Create video request
video_request = ProductVideoRequest(
product_name=request.product_name,
product_description=request.product_description,
video_type=request.video_type,
resolution=request.resolution,
duration=request.duration,
audio_base64=request.audio_base64,
brand_context=brand_context,
additional_context=request.additional_context,
)
# Generate video using unified ai_video_generate()
result = await video_service.generate_product_video(video_request, user_id)
logger.info(f"[Product Marketing] ✅ Product demo video completed: cost=${result.get('cost', 0):.2f}")
return {
"success": True,
"product_name": result.get("product_name"),
"video_type": result.get("video_type"),
"video_url": result.get("file_url"),
"video_filename": result.get("filename"),
"cost": result.get("cost", 0.0),
"resolution": request.resolution,
"duration": request.duration,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating product demo video: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Product demo video generation failed: {str(e)}")
@router.post("/products/video/storytelling", summary="Create Product Storytelling Video")
async def create_product_storytelling(
request: ProductVideoRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
video_service: ProductVideoService = Depends(get_product_video_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product storytelling video (narrative-driven product showcase)."""
try:
user_id = _require_user_id(current_user, "product storytelling video")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await video_service.create_product_storytelling(
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution,
duration=request.duration,
audio_base64=request.audio_base64,
brand_context=brand_context
)
return {
"success": True,
"video_type": "storytelling",
"video_url": result.get("file_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating storytelling video: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/products/video/feature-highlight", summary="Create Product Feature Highlight Video")
async def create_product_feature_highlight(
request: ProductVideoRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
video_service: ProductVideoService = Depends(get_product_video_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product feature highlight video (close-up shots of key features)."""
try:
user_id = _require_user_id(current_user, "product feature highlight video")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await video_service.create_product_feature_highlight(
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution,
duration=request.duration,
audio_base64=request.audio_base64,
brand_context=brand_context
)
return {
"success": True,
"video_type": "feature_highlight",
"video_url": result.get("file_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating feature highlight video: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/products/video/launch", summary="Create Product Launch Video")
async def create_product_launch(
request: ProductVideoRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
video_service: ProductVideoService = Depends(get_product_video_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Create product launch video (exciting unveiling, launch event aesthetic)."""
try:
user_id = _require_user_id(current_user, "product launch video")
# Get brand DNA
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception:
pass
result = await video_service.create_product_launch(
product_name=request.product_name,
product_description=request.product_description,
user_id=user_id,
resolution=request.resolution or "1080p", # Higher quality for launch
duration=request.duration,
audio_base64=request.audio_base64,
brand_context=brand_context
)
return {
"success": True,
"video_type": "launch",
"video_url": result.get("file_url"),
"cost": result.get("cost", 0.0),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating launch video: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/products/videos/{user_id}/{filename}", summary="Serve Product Video")
async def serve_product_video(
user_id: str,
filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve generated product videos."""
try:
from fastapi.responses import FileResponse
from pathlib import Path
# Verify user owns the video
current_user_id = _require_user_id(current_user, "serving product video")
if current_user_id != user_id:
raise HTTPException(status_code=403, detail="Access denied")
# Locate video file
base_dir = Path(__file__).parent.parent.parent
video_path = base_dir / "product_videos" / user_id / filename
if not video_path.exists():
raise HTTPException(status_code=404, detail="Video not found")
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=filename
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error serving product video: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# ====================
# HEALTH CHECK
# ====================
@@ -635,6 +1155,8 @@ async def health_check():
"asset_audit": "available",
"channel_pack": "available",
"product_image_service": "available",
"product_animation_service": "available",
"product_video_service": "available",
}
}

View File

@@ -12,7 +12,7 @@ Uses WaveSpeed AI models for high-quality video generation.
from fastapi import APIRouter
from .endpoints import create, avatar, enhance, extend, transform, models, serve, tasks, prompt, social, face_swap, video_translate, video_background_remover, add_audio_to_video
from .endpoints import create, avatar, enhance, extend, transform, models, serve, tasks, prompt, social, face_swap, video_translate, video_background_remover, add_audio_to_video, edit
# Create main router
router = APIRouter(
@@ -32,6 +32,7 @@ router.include_router(face_swap.router)
router.include_router(video_translate.router)
router.include_router(video_background_remover.router)
router.include_router(add_audio_to_video.router)
router.include_router(edit.router)
router.include_router(models.router)
router.include_router(serve.router)
router.include_router(tasks.router)

View File

@@ -0,0 +1,418 @@
"""
Edit Studio API endpoints.
Phase 1: Basic FFmpeg operations (Trim/Cut, Speed Control, Stabilization)
"""
from typing import Dict, Any, Optional
from fastapi import APIRouter, File, UploadFile, Form, Depends, HTTPException
from sqlalchemy.orm import Session
from backend.middleware.auth import get_current_user, require_authenticated_user
from backend.database.database import get_db
from backend.services.video_studio.edit_service import EditService
router = APIRouter()
@router.post("/edit/trim")
async def trim_video(
file: UploadFile = File(..., description="Video file to trim"),
start_time: float = Form(0.0, description="Start time in seconds"),
end_time: Optional[float] = Form(None, description="End time in seconds (optional)"),
max_duration: Optional[float] = Form(None, description="Maximum duration in seconds (trims if video is longer)"),
trim_mode: str = Form("beginning", description="How to trim if max_duration is set: beginning, middle, end"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Trim video to specified duration or time range.
Supports:
- Trim by start/end time
- Trim to maximum duration
- Trim modes: beginning, middle, end
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
# Validate trim_mode
valid_modes = ["beginning", "middle", "end"]
if trim_mode not in valid_modes:
raise HTTPException(
status_code=400,
detail=f"Invalid trim_mode. Must be one of: {', '.join(valid_modes)}"
)
# Initialize service
edit_service = EditService()
# Read video file
video_data = await file.read()
# Trim video
result = await edit_service.trim_video(
video_data=video_data,
start_time=start_time,
end_time=end_time,
max_duration=max_duration,
trim_mode=trim_mode,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Video trimming failed: {str(e)}"
)
@router.post("/edit/speed")
async def adjust_video_speed(
file: UploadFile = File(..., description="Video file to adjust speed"),
speed_factor: float = Form(..., description="Speed multiplier (0.25, 0.5, 1.0, 1.5, 2.0, 4.0)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Adjust video playback speed.
Supports:
- Slow motion: 0.25x, 0.5x
- Normal: 1.0x
- Fast forward: 1.5x, 2.0x, 4.0x
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
# Validate speed factor
if speed_factor <= 0 or speed_factor > 4.0:
raise HTTPException(
status_code=400,
detail="Speed factor must be between 0.25 and 4.0"
)
# Initialize service
edit_service = EditService()
# Read video file
video_data = await file.read()
# Adjust speed
result = await edit_service.adjust_speed(
video_data=video_data,
speed_factor=speed_factor,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Video speed adjustment failed: {str(e)}"
)
@router.post("/edit/stabilize")
async def stabilize_video(
file: UploadFile = File(..., description="Video file to stabilize"),
smoothing: int = Form(10, description="Smoothing window size (1-100, default: 10)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Stabilize shaky video using FFmpeg's vidstab filters.
Uses two-pass stabilization:
1. Detect camera shake (vidstabdetect)
2. Apply stabilization (vidstabtransform)
Note: Requires FFmpeg with vidstab filters enabled.
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
# Validate smoothing
if smoothing < 1 or smoothing > 100:
raise HTTPException(
status_code=400,
detail="Smoothing must be between 1 and 100"
)
# Initialize service
edit_service = EditService()
# Read video file
video_data = await file.read()
# Stabilize video
result = await edit_service.stabilize_video(
video_data=video_data,
smoothing=smoothing,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Video stabilization failed: {str(e)}"
)
@router.post("/edit/estimate-cost")
async def estimate_edit_cost(
edit_type: str = Form(..., description="Type of edit: trim, speed, stabilize, text, volume, normalize, denoise"),
duration: float = Form(10.0, description="Estimated video duration in seconds"),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Estimate cost for video editing operation.
Note: FFmpeg-based operations are free.
AI-based operations will have costs (Phase 3).
"""
try:
require_authenticated_user(current_user)
edit_service = EditService()
estimated_cost = edit_service.calculate_cost(edit_type, duration)
return {
"estimated_cost": estimated_cost,
"edit_type": edit_type,
"estimated_duration": duration,
"pricing_model": "free", # FFmpeg operations are free
"note": "FFmpeg-based editing operations are free. AI-based operations may have costs.",
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Cost estimation failed: {str(e)}"
)
# ==================== Phase 2: Text & Audio Endpoints ====================
@router.post("/edit/text")
async def add_text_overlay(
file: UploadFile = File(..., description="Video file to add text overlay"),
text: str = Form(..., description="Text to overlay on video"),
position: str = Form("center", description="Text position: top, center, bottom, top-left, top-right, bottom-left, bottom-right"),
font_size: int = Form(48, description="Font size in pixels"),
font_color: str = Form("white", description="Font color (e.g., white, #FFFFFF)"),
background_color: str = Form("black@0.5", description="Background color with opacity (e.g., black@0.5)"),
start_time: float = Form(0.0, description="When to start showing text (seconds)"),
end_time: Optional[float] = Form(None, description="When to stop showing text (None = end of video)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Add text overlay to video using FFmpeg drawtext filter.
Supports:
- Multiple positions (center, top, bottom, corners)
- Custom font size and colors
- Background box with opacity
- Time-limited display (show text only during specific time range)
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
valid_positions = ["top", "center", "bottom", "top-left", "top-right", "bottom-left", "bottom-right"]
if position not in valid_positions:
raise HTTPException(
status_code=400,
detail=f"Invalid position. Must be one of: {', '.join(valid_positions)}"
)
edit_service = EditService()
video_data = await file.read()
result = await edit_service.add_text_overlay(
video_data=video_data,
text=text,
position=position,
font_size=font_size,
font_color=font_color,
background_color=background_color,
start_time=start_time,
end_time=end_time,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Text overlay failed: {str(e)}"
)
@router.post("/edit/volume")
async def adjust_volume(
file: UploadFile = File(..., description="Video file to adjust volume"),
volume_factor: float = Form(..., description="Volume multiplier (0.0 = mute, 1.0 = original, 2.0 = double)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Adjust video audio volume.
Supports:
- Mute (0.0)
- Reduce volume (0.0 - 1.0)
- Original (1.0)
- Increase volume (1.0 - 3.0+)
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
if volume_factor < 0:
raise HTTPException(status_code=400, detail="Volume factor must be non-negative")
if volume_factor > 5.0:
raise HTTPException(status_code=400, detail="Volume factor cannot exceed 5.0 to prevent distortion")
edit_service = EditService()
video_data = await file.read()
result = await edit_service.adjust_volume(
video_data=video_data,
volume_factor=volume_factor,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Volume adjustment failed: {str(e)}"
)
@router.post("/edit/normalize")
async def normalize_audio(
file: UploadFile = File(..., description="Video file to normalize audio"),
target_level: float = Form(-14.0, description="Target integrated loudness in LUFS (default: -14 for streaming)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Normalize audio levels using EBU R128 standard (loudnorm filter).
Common target levels:
- -14 LUFS: YouTube, Spotify, general streaming
- -16 LUFS: Podcast standard
- -23 LUFS: Broadcast TV (EBU R128)
- -24 LUFS: US Broadcast (ATSC A/85)
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
if target_level > 0 or target_level < -50:
raise HTTPException(
status_code=400,
detail="Target level must be between -50 and 0 LUFS"
)
edit_service = EditService()
video_data = await file.read()
result = await edit_service.normalize_audio(
video_data=video_data,
target_level=target_level,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Audio normalization failed: {str(e)}"
)
@router.post("/edit/denoise")
async def reduce_noise(
file: UploadFile = File(..., description="Video file to reduce audio noise"),
strength: float = Form(0.5, description="Noise reduction strength (0.0 - 1.0)"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""
Reduce audio noise using FFmpeg's noise reduction filters.
Supports:
- Light noise reduction (0.0 - 0.3): Subtle cleanup
- Moderate reduction (0.3 - 0.6): Good for background noise
- Strong reduction (0.6 - 1.0): Heavy noise, may affect audio quality
"""
try:
user_id = require_authenticated_user(current_user)
if not file.content_type.startswith('video/'):
raise HTTPException(status_code=400, detail="File must be a video")
if strength < 0 or strength > 1:
raise HTTPException(
status_code=400,
detail="Strength must be between 0.0 and 1.0"
)
edit_service = EditService()
video_data = await file.read()
result = await edit_service.reduce_noise(
video_data=video_data,
noise_reduction_strength=strength,
user_id=user_id,
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Noise reduction failed: {str(e)}"
)

View File

@@ -110,6 +110,11 @@ class ContentAssetService:
search_query: Optional[str] = None,
tags: Optional[List[str]] = None,
favorites_only: bool = False,
collection_id: Optional[int] = None,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: int = 100,
offset: int = 0,
) -> Tuple[List[ContentAsset], int]:
@@ -157,11 +162,37 @@ class ContentAssetService:
tag_filters = [ContentAsset.tags.contains([tag]) for tag in tags]
query = query.filter(or_(*tag_filters))
if collection_id:
query = query.filter(ContentAsset.collection_id == collection_id)
if date_from:
query = query.filter(ContentAsset.created_at >= date_from)
if date_to:
query = query.filter(ContentAsset.created_at <= date_to)
# Get total count before pagination
total_count = query.count()
# Apply ordering and pagination
query = query.order_by(desc(ContentAsset.created_at))
# Apply ordering
order_column = ContentAsset.created_at
if sort_by == "created_at":
order_column = ContentAsset.created_at
elif sort_by == "updated_at":
order_column = ContentAsset.updated_at
elif sort_by == "cost":
order_column = ContentAsset.cost
elif sort_by == "file_size":
order_column = ContentAsset.file_size
elif sort_by == "title":
order_column = ContentAsset.title
if sort_order.lower() == "asc":
query = query.order_by(order_column)
else:
query = query.order_by(desc(order_column))
# Apply pagination
query = query.limit(limit).offset(offset)
return query.all(), total_count
@@ -319,4 +350,231 @@ class ContentAssetService:
"total_cost": 0.0,
"favorites_count": 0,
}
# ==================== Collection Management ====================
def create_collection(
self,
user_id: str,
name: str,
description: Optional[str] = None,
is_public: bool = False,
) -> AssetCollection:
"""Create a new asset collection."""
try:
collection = AssetCollection(
user_id=user_id,
name=name,
description=description,
is_public=is_public,
)
self.db.add(collection)
self.db.commit()
self.db.refresh(collection)
logger.info(f"Created collection {collection.id} '{name}' for user {user_id}")
return collection
except Exception as e:
self.db.rollback()
logger.error(f"Error creating collection: {str(e)}", exc_info=True)
raise
def get_user_collections(
self,
user_id: str,
limit: int = 100,
offset: int = 0,
) -> Tuple[List[AssetCollection], int]:
"""Get all collections for a user."""
try:
query = self.db.query(AssetCollection).filter(
AssetCollection.user_id == user_id
)
total_count = query.count()
query = query.order_by(desc(AssetCollection.created_at))
query = query.limit(limit).offset(offset)
return query.all(), total_count
except Exception as e:
logger.error(f"Error fetching collections: {str(e)}", exc_info=True)
raise
def get_collection_by_id(self, collection_id: int, user_id: str) -> Optional[AssetCollection]:
"""Get a specific collection by ID."""
try:
return self.db.query(AssetCollection).filter(
and_(
AssetCollection.id == collection_id,
AssetCollection.user_id == user_id
)
).first()
except Exception as e:
logger.error(f"Error fetching collection {collection_id}: {str(e)}", exc_info=True)
return None
def update_collection(
self,
collection_id: int,
user_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
is_public: Optional[bool] = None,
cover_asset_id: Optional[int] = None,
) -> Optional[AssetCollection]:
"""Update collection metadata."""
try:
collection = self.get_collection_by_id(collection_id, user_id)
if not collection:
return None
if name is not None:
collection.name = name
if description is not None:
collection.description = description
if is_public is not None:
collection.is_public = is_public
if cover_asset_id is not None:
# Verify asset belongs to user
asset = self.get_asset_by_id(cover_asset_id, user_id)
if asset:
collection.cover_asset_id = cover_asset_id
else:
collection.cover_asset_id = None
collection.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(collection)
logger.info(f"Updated collection {collection_id} for user {user_id}")
return collection
except Exception as e:
self.db.rollback()
logger.error(f"Error updating collection: {str(e)}", exc_info=True)
return None
def delete_collection(self, collection_id: int, user_id: str) -> bool:
"""Delete a collection (assets are not deleted, just removed from collection)."""
try:
collection = self.get_collection_by_id(collection_id, user_id)
if not collection:
return False
# Remove assets from collection before deleting
self.db.query(ContentAsset).filter(
ContentAsset.collection_id == collection_id
).update({ContentAsset.collection_id: None})
self.db.delete(collection)
self.db.commit()
logger.info(f"Deleted collection {collection_id} for user {user_id}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"Error deleting collection: {str(e)}", exc_info=True)
return False
def add_assets_to_collection(
self,
collection_id: int,
user_id: str,
asset_ids: List[int],
) -> int:
"""Add assets to a collection. Returns number of assets added."""
try:
collection = self.get_collection_by_id(collection_id, user_id)
if not collection:
return 0
# Verify all assets belong to user
assets = self.db.query(ContentAsset).filter(
and_(
ContentAsset.id.in_(asset_ids),
ContentAsset.user_id == user_id
)
).all()
count = 0
for asset in assets:
asset.collection_id = collection_id
count += 1
collection.updated_at = datetime.utcnow()
self.db.commit()
logger.info(f"Added {count} assets to collection {collection_id}")
return count
except Exception as e:
self.db.rollback()
logger.error(f"Error adding assets to collection: {str(e)}", exc_info=True)
return 0
def remove_assets_from_collection(
self,
collection_id: int,
user_id: str,
asset_ids: List[int],
) -> int:
"""Remove assets from a collection. Returns number of assets removed."""
try:
collection = self.get_collection_by_id(collection_id, user_id)
if not collection:
return 0
# Remove assets from collection
count = self.db.query(ContentAsset).filter(
and_(
ContentAsset.id.in_(asset_ids),
ContentAsset.collection_id == collection_id,
ContentAsset.user_id == user_id
)
).update({ContentAsset.collection_id: None})
collection.updated_at = datetime.utcnow()
self.db.commit()
logger.info(f"Removed {count} assets from collection {collection_id}")
return count
except Exception as e:
self.db.rollback()
logger.error(f"Error removing assets from collection: {str(e)}", exc_info=True)
return 0
def get_collection_assets(
self,
collection_id: int,
user_id: str,
limit: int = 100,
offset: int = 0,
) -> Tuple[List[ContentAsset], int]:
"""Get all assets in a collection."""
try:
collection = self.get_collection_by_id(collection_id, user_id)
if not collection:
return [], 0
query = self.db.query(ContentAsset).filter(
and_(
ContentAsset.collection_id == collection_id,
ContentAsset.user_id == user_id
)
)
total_count = query.count()
query = query.order_by(desc(ContentAsset.created_at))
query = query.limit(limit).offset(offset)
return query.all(), total_count
except Exception as e:
logger.error(f"Error fetching collection assets: {str(e)}", exc_info=True)
raise

View File

@@ -6,6 +6,8 @@ from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .compression_service import ImageCompressionService, CompressionRequest, CompressionResult
from .format_converter_service import ImageFormatConverterService, FormatConversionRequest, FormatConversionResult
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
@@ -25,6 +27,12 @@ __all__ = [
"ControlStudioRequest",
"SocialOptimizerService",
"SocialOptimizerRequest",
"ImageCompressionService",
"CompressionRequest",
"CompressionResult",
"ImageFormatConverterService",
"FormatConversionRequest",
"FormatConversionResult",
"TransformStudioService",
"TransformImageToVideoRequest",
"TalkingAvatarRequest",

View File

@@ -0,0 +1,367 @@
"""Image Compression Service for optimizing image file sizes."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Literal
from PIL import Image, ExifTags
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.compression")
@dataclass
class CompressionRequest:
"""Request model for image compression."""
image_base64: str
quality: int = 85 # 1-100, where 100 is best quality
format: str = "jpeg" # jpeg, png, webp, avif
target_size_kb: Optional[int] = None # Target file size in KB
strip_metadata: bool = True
progressive: bool = True # Progressive JPEG
optimize: bool = True # Optimize encoding
@dataclass
class CompressionResult:
"""Result of compression operation."""
success: bool
image_base64: str
original_size_kb: float
compressed_size_kb: float
compression_ratio: float
format: str
width: int
height: int
quality_used: int
metadata_stripped: bool
class ImageCompressionService:
"""Service for image compression and optimization."""
SUPPORTED_FORMATS = ["jpeg", "jpg", "png", "webp"]
# Format-specific options
FORMAT_OPTIONS = {
"jpeg": {"quality": (1, 100), "progressive": True, "optimize": True},
"jpg": {"quality": (1, 100), "progressive": True, "optimize": True},
"png": {"compress_level": (0, 9), "optimize": True},
"webp": {"quality": (1, 100), "lossless": False},
}
def __init__(self):
logger.info("[Compression] ImageCompressionService initialized")
def _decode_image(self, image_base64: str) -> tuple[Image.Image, int]:
"""Decode base64 image and return PIL Image and original size."""
# Handle data URL format
if "," in image_base64:
image_base64 = image_base64.split(",", 1)[1]
image_bytes = base64.b64decode(image_base64)
original_size = len(image_bytes)
image = Image.open(io.BytesIO(image_bytes))
return image, original_size
def _strip_exif(self, image: Image.Image) -> Image.Image:
"""Remove EXIF metadata from image."""
# Create a new image without EXIF data
data = list(image.getdata())
image_without_exif = Image.new(image.mode, image.size)
image_without_exif.putdata(data)
return image_without_exif
def _compress_to_target_size(
self,
image: Image.Image,
target_size_kb: int,
format: str,
min_quality: int = 10,
max_quality: int = 95,
) -> tuple[bytes, int]:
"""Compress image to target file size using binary search."""
target_bytes = target_size_kb * 1024
low, high = min_quality, max_quality
best_result = None
best_quality = max_quality
while low <= high:
mid = (low + high) // 2
compressed = self._compress_image(image, format, mid, True, True)
if len(compressed) <= target_bytes:
best_result = compressed
best_quality = mid
low = mid + 1 # Try higher quality
else:
high = mid - 1 # Try lower quality
if best_result is None:
# Even minimum quality exceeds target, return min quality result
best_result = self._compress_image(image, format, min_quality, True, True)
best_quality = min_quality
return best_result, best_quality
def _compress_image(
self,
image: Image.Image,
format: str,
quality: int,
progressive: bool,
optimize: bool,
) -> bytes:
"""Compress image with given settings."""
buffer = io.BytesIO()
# Handle format-specific options
save_kwargs: Dict[str, Any] = {}
format_lower = format.lower()
if format_lower in ["jpeg", "jpg"]:
# Convert to RGB if necessary (JPEG doesn't support alpha)
if image.mode in ("RGBA", "P"):
image = image.convert("RGB")
save_kwargs["format"] = "JPEG"
save_kwargs["quality"] = quality
save_kwargs["optimize"] = optimize
if progressive:
save_kwargs["progressive"] = True
elif format_lower == "png":
save_kwargs["format"] = "PNG"
save_kwargs["optimize"] = optimize
# PNG uses compress_level (0-9) instead of quality
compress_level = max(0, min(9, (100 - quality) // 11))
save_kwargs["compress_level"] = compress_level
elif format_lower == "webp":
save_kwargs["format"] = "WEBP"
save_kwargs["quality"] = quality
save_kwargs["method"] = 6 # Best compression
else:
raise ValueError(f"Unsupported format: {format}")
image.save(buffer, **save_kwargs)
return buffer.getvalue()
async def compress(
self,
request: CompressionRequest,
user_id: Optional[str] = None,
) -> CompressionResult:
"""Compress an image with specified settings."""
logger.info(f"[Compression] Processing compression request for user: {user_id}")
try:
# Decode image
image, original_size = self._decode_image(request.image_base64)
original_size_kb = original_size / 1024
logger.info(f"[Compression] Original size: {original_size_kb:.2f} KB, dimensions: {image.size}")
# Strip metadata if requested
if request.strip_metadata:
image = self._strip_exif(image)
# Validate format
format_lower = request.format.lower()
if format_lower not in self.SUPPORTED_FORMATS:
raise ValueError(f"Unsupported format: {request.format}. Supported: {self.SUPPORTED_FORMATS}")
# Compress to target size or with quality setting
if request.target_size_kb:
compressed_bytes, quality_used = self._compress_to_target_size(
image,
request.target_size_kb,
format_lower,
)
else:
compressed_bytes = self._compress_image(
image,
format_lower,
request.quality,
request.progressive,
request.optimize,
)
quality_used = request.quality
compressed_size_kb = len(compressed_bytes) / 1024
compression_ratio = (1 - compressed_size_kb / original_size_kb) * 100 if original_size_kb > 0 else 0
# Encode result
mime_type = "image/jpeg" if format_lower in ["jpeg", "jpg"] else f"image/{format_lower}"
result_base64 = f"data:{mime_type};base64,{base64.b64encode(compressed_bytes).decode()}"
logger.info(f"[Compression] Compressed: {original_size_kb:.2f}KB → {compressed_size_kb:.2f}KB ({compression_ratio:.1f}% reduction)")
return CompressionResult(
success=True,
image_base64=result_base64,
original_size_kb=round(original_size_kb, 2),
compressed_size_kb=round(compressed_size_kb, 2),
compression_ratio=round(compression_ratio, 2),
format=format_lower,
width=image.width,
height=image.height,
quality_used=quality_used,
metadata_stripped=request.strip_metadata,
)
except Exception as e:
logger.error(f"[Compression] Failed to compress image: {e}")
raise
async def compress_batch(
self,
requests: List[CompressionRequest],
user_id: Optional[str] = None,
) -> List[CompressionResult]:
"""Compress multiple images with same or individual settings."""
logger.info(f"[Compression] Processing batch of {len(requests)} images for user: {user_id}")
results = []
for i, request in enumerate(requests):
try:
result = await self.compress(request, user_id)
results.append(result)
logger.info(f"[Compression] Batch item {i+1}/{len(requests)} complete")
except Exception as e:
logger.error(f"[Compression] Batch item {i+1} failed: {e}")
# Return partial success
results.append(CompressionResult(
success=False,
image_base64="",
original_size_kb=0,
compressed_size_kb=0,
compression_ratio=0,
format="",
width=0,
height=0,
quality_used=0,
metadata_stripped=False,
))
return results
async def estimate_compression(
self,
image_base64: str,
format: str = "jpeg",
quality: int = 85,
) -> Dict[str, Any]:
"""Estimate compression results without actually compressing."""
try:
image, original_size = self._decode_image(image_base64)
original_size_kb = original_size / 1024
# Quick estimation based on format and quality
if format.lower() in ["jpeg", "jpg"]:
# JPEG compression ratio estimate
estimated_ratio = 0.1 + (quality / 100) * 0.4 # 10-50% of original
elif format.lower() == "webp":
# WebP is typically 25-34% smaller than JPEG
estimated_ratio = 0.08 + (quality / 100) * 0.35
else: # PNG
estimated_ratio = 0.7 + (quality / 100) * 0.2 # PNG is less compressible
estimated_size_kb = original_size_kb * estimated_ratio
return {
"original_size_kb": round(original_size_kb, 2),
"estimated_size_kb": round(estimated_size_kb, 2),
"estimated_reduction_percent": round((1 - estimated_ratio) * 100, 1),
"width": image.width,
"height": image.height,
"format": format.lower(),
}
except Exception as e:
logger.error(f"[Compression] Estimation failed: {e}")
raise
def get_supported_formats(self) -> List[Dict[str, Any]]:
"""Get list of supported compression formats with details."""
return [
{
"id": "jpeg",
"name": "JPEG",
"extension": ".jpg",
"description": "Best for photos. Lossy compression with excellent size reduction.",
"supports_transparency": False,
"quality_range": [1, 100],
"recommended_quality": 85,
"use_cases": ["Photos", "Blog images", "Email", "Social media"],
},
{
"id": "png",
"name": "PNG",
"extension": ".png",
"description": "Best for graphics with transparency. Lossless compression.",
"supports_transparency": True,
"quality_range": [1, 100],
"recommended_quality": 90,
"use_cases": ["Logos", "Icons", "Graphics", "Screenshots"],
},
{
"id": "webp",
"name": "WebP",
"extension": ".webp",
"description": "Modern format with excellent compression. 25-34% smaller than JPEG.",
"supports_transparency": True,
"quality_range": [1, 100],
"recommended_quality": 80,
"use_cases": ["Web images", "Fast loading", "Modern browsers"],
},
]
def get_presets(self) -> List[Dict[str, Any]]:
"""Get compression presets for common use cases."""
return [
{
"id": "web",
"name": "Web Optimized",
"description": "Balanced quality and size for web pages",
"format": "webp",
"quality": 80,
"strip_metadata": True,
},
{
"id": "email",
"name": "Email Friendly",
"description": "Small file size for email attachments (<200KB target)",
"format": "jpeg",
"quality": 70,
"target_size_kb": 200,
"strip_metadata": True,
},
{
"id": "social",
"name": "Social Media",
"description": "Optimized for social platforms",
"format": "jpeg",
"quality": 85,
"strip_metadata": True,
},
{
"id": "high_quality",
"name": "High Quality",
"description": "Minimal compression for quality-critical images",
"format": "png",
"quality": 95,
"strip_metadata": False,
},
{
"id": "maximum",
"name": "Maximum Compression",
"description": "Smallest possible file size",
"format": "webp",
"quality": 60,
"strip_metadata": True,
},
]

View File

@@ -1,17 +1,10 @@
"""Create Studio service for AI-powered image generation."""
import os
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass
from services.llm_providers.image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.image_generation import ImageGenerationResult
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
from utils.logger_utils import get_service_logger
@@ -75,29 +68,8 @@ class CreateStudioService:
self.template_manager = TemplateManager()
logger.info("[Create Studio] Initialized with template manager")
def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
"""Get provider instance by name.
Args:
provider_name: Name of the provider
api_key: Optional API key (uses env vars if not provided)
Returns:
Provider instance
Raises:
ValueError: If provider is not supported
"""
if provider_name == "stability":
return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
elif provider_name == "wavespeed":
return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
elif provider_name == "huggingface":
return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
elif provider_name == "gemini":
return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
else:
raise ValueError(f"Unsupported provider: {provider_name}")
# Removed _get_provider_instance() - now using unified entry point
# Provider selection is handled by main_image_generation.generate_image()
def _select_provider_and_model(
self,
@@ -289,30 +261,17 @@ class CreateStudioService:
logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
request.prompt[:100], request.template_id)
# Pre-flight validation: Check subscription and usage limits
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=request.num_variations
)
logger.info(f"[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
except HTTPException as http_ex:
logger.error(f"[Create Studio] ❌ Pre-flight validation failed - blocking generation")
raise
finally:
db.close()
else:
logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")
# Pre-flight validation: Reuse unified helper
# Note: Validation for num_variations will be done per-image in generate_image()
# We validate once upfront to fail fast if user has no credits
if user_id and request.num_variations > 0:
from services.llm_providers.main_image_generation import _validate_image_operation
_validate_image_operation(
user_id=user_id,
operation_type="create-studio-generation",
num_operations=request.num_variations,
log_prefix="[Create Studio]"
)
# Load template if specified
template = None
@@ -337,36 +296,37 @@ class CreateStudioService:
# Select provider and model
provider_name, model = self._select_provider_and_model(request, template)
# Get provider instance
try:
provider = self._get_provider_instance(provider_name)
except Exception as e:
logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
provider_name, str(e))
raise RuntimeError(f"Provider initialization failed: {str(e)}")
# Generate images
# Generate images using unified entry point
# This ensures consistent validation, tracking, and error handling
results = []
for i in range(request.num_variations):
logger.info("[Create Studio] Generating variation %d/%d",
i + 1, request.num_variations)
try:
# Prepare options
options = ImageGenerationOptions(
prompt=prompt,
negative_prompt=request.negative_prompt,
width=width,
height=height,
guidance_scale=request.guidance_scale,
steps=request.steps,
seed=request.seed + i if request.seed else None,
model=model,
extra={"style_preset": request.style_preset} if request.style_preset else {}
)
# Prepare options for unified entry point
options = {
"provider": provider_name,
"model": model,
"width": width,
"height": height,
"negative_prompt": request.negative_prompt,
"guidance_scale": request.guidance_scale,
"steps": request.steps,
"seed": request.seed + i if request.seed else None,
}
# Generate image
result: ImageGenerationResult = provider.generate(options)
# Add style preset to extra if specified
if request.style_preset:
options["extra"] = {"style_preset": request.style_preset}
# Generate image using unified entry point
# This handles validation, provider selection, generation, and tracking automatically
result: ImageGenerationResult = generate_image(
prompt=prompt,
options=options,
user_id=user_id
)
results.append({
"image_bytes": result.image_bytes,

View File

@@ -11,6 +11,7 @@ from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
from services.llm_providers.main_image_generation import generate_image_edit
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
@@ -213,6 +214,249 @@ class EditStudioService:
def list_operations(self) -> Dict[str, Dict[str, Any]]:
"""Expose supported operations for UI rendering."""
return self.SUPPORTED_OPERATIONS
def get_available_models(
self,
operation: Optional[str] = None,
tier: Optional[str] = None,
) -> Dict[str, Any]:
"""Get available WaveSpeed editing models.
Args:
operation: Filter by operation type (e.g., "general_edit")
tier: Filter by tier ("budget", "mid", "premium")
Returns:
Dictionary with models and metadata
"""
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
provider = WaveSpeedEditProvider()
all_models = provider.get_available_models()
# Filter by operation if specified
if operation:
filtered = provider.get_models_by_operation(operation)
all_models = {k: v for k, v in all_models.items() if k in filtered}
# Filter by tier if specified
if tier:
filtered = provider.get_models_by_tier(tier)
all_models = {k: v for k, v in all_models.items() if k in filtered}
# Format for API response
models_list = []
for model_id, model_info in all_models.items():
models_list.append({
"id": model_id,
"name": model_info.get("name", model_id),
"description": model_info.get("description", ""),
"cost": model_info.get("cost", 0.02),
"cost_8k": model_info.get("cost_8k"), # Optional
"tier": model_info.get("tier", "mid"),
"max_resolution": model_info.get("max_resolution", [2048, 2048]),
"capabilities": model_info.get("capabilities", []),
"use_cases": self._get_use_cases_for_model(model_id, model_info),
"features": self._get_features_for_model(model_info),
"supports_multi_image": model_info.get("supports_multi_image", False),
"supports_controlnet": model_info.get("supports_controlnet", False),
"languages": model_info.get("languages", ["en"]),
})
return {
"models": models_list,
"total": len(models_list),
}
def recommend_model(
self,
operation: str,
image_resolution: Optional[Dict[str, int]] = None,
user_tier: Optional[str] = None,
preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Recommend best model for given operation and context.
Args:
operation: Operation type (e.g., "general_edit")
image_resolution: Dict with "width" and "height"
user_tier: User subscription tier ("free", "pro", "enterprise")
preferences: Dict with "prioritize_cost" or "prioritize_quality"
Returns:
Dictionary with recommended model and alternatives
"""
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
provider = WaveSpeedEditProvider()
available_models = provider.get_models_by_operation(operation)
if not available_models:
# Fallback to all models if operation doesn't match
available_models = provider.get_available_models()
# Filter by resolution if provided
if image_resolution:
width = image_resolution.get("width", 0)
height = image_resolution.get("height", 0)
max_dimension = max(width, height)
# Filter models that support this resolution
filtered = {}
for model_id, model_info in available_models.items():
max_res = model_info.get("max_resolution", (2048, 2048))
max_supported = max(max_res[0], max_res[1])
if max_dimension <= max_supported:
filtered[model_id] = model_info
available_models = filtered
if not available_models:
# No models match, return first available
all_models = provider.get_available_models()
if all_models:
first_model_id = list(all_models.keys())[0]
return {
"recommended_model": first_model_id,
"reason": "No specific match found, using default model",
"alternatives": [],
}
else:
raise ValueError("No models available")
# Apply preferences
prioritize_cost = preferences and preferences.get("prioritize_cost", False)
prioritize_quality = preferences and preferences.get("prioritize_quality", False)
# Score models
scored_models = []
for model_id, model_info in available_models.items():
score = 0
cost = model_info.get("cost", 0.02)
tier = model_info.get("tier", "mid")
max_res = model_info.get("max_resolution", (2048, 2048))
max_resolution = max(max_res[0], max_res[1])
# Cost scoring (lower is better)
if prioritize_cost:
score += (1.0 / cost) * 100 # Invert cost for scoring
else:
score += (1.0 / cost) * 50 # Less weight if not prioritizing
# Quality scoring (higher resolution = better)
if prioritize_quality:
score += max_resolution / 10 # Higher weight for quality
else:
score += max_resolution / 20 # Lower weight
# Tier preference based on user tier
if user_tier == "free":
if tier == "budget":
score += 50
elif tier == "mid":
score += 20
elif user_tier in ["pro", "enterprise"]:
if tier == "premium":
score += 50
elif tier == "mid":
score += 30
scored_models.append((model_id, model_info, score))
# Sort by score (highest first)
scored_models.sort(key=lambda x: x[2], reverse=True)
# Get recommended model
recommended_id, recommended_info, recommended_score = scored_models[0]
# Build reason
reasons = []
if prioritize_cost:
reasons.append("Lowest cost option")
if prioritize_quality:
reasons.append("Best quality")
if image_resolution:
reasons.append(f"Supports {image_resolution.get('width')}×{image_resolution.get('height')} resolution")
if user_tier == "free" and recommended_info.get("tier") == "budget":
reasons.append("Budget-friendly for free tier")
reason = ", ".join(reasons) if reasons else "Best match for your requirements"
# Get alternatives (top 2-3)
alternatives = []
for model_id, model_info, score in scored_models[1:4]:
alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
if model_info.get("cost", 0) < recommended_info.get("cost", 0):
alt_reason += ", lower cost"
elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
alt_reason += ", higher quality"
alternatives.append({
"model_id": model_id,
"name": model_info.get("name", model_id),
"cost": model_info.get("cost", 0.02),
"reason": alt_reason,
})
return {
"recommended_model": recommended_id,
"reason": reason,
"alternatives": alternatives,
}
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
"""Get use cases for a model based on its capabilities."""
use_cases_map = {
"general_edit": ["Quick edits", "Style changes", "Background replacement"],
"style_transfer": ["Apply artistic styles", "Style transformations"],
"text_edit": ["Add text to images", "Edit text in images"],
"multi_image": ["Batch editing", "Consistent character work"],
"high_res": ["Professional work", "Print materials", "4K/8K editing"],
"professional": ["Marketing campaigns", "Brand assets"],
"typography": ["Text-heavy edits", "Typography generation"],
"portrait_retouching": ["Portrait edits", "Beauty retouching"],
"fashion_edit": ["Fashion photography", "Outfit changes"],
"product_edit": ["E-commerce", "Product photography"],
}
capabilities = model_info.get("capabilities", [])
use_cases = []
for cap in capabilities:
if cap in use_cases_map:
use_cases.extend(use_cases_map[cap])
# Remove duplicates
return list(set(use_cases)) if use_cases else ["General image editing"]
def _get_features_for_model(self, model_info: Dict[str, Any]) -> list:
"""Get feature list for a model."""
features = []
if model_info.get("supports_multi_image"):
max_images = model_info.get("api_params", {}).get("max_images", 0)
if max_images:
features.append(f"Multi-image ({max_images} images)")
else:
features.append("Multi-image support")
if model_info.get("supports_controlnet"):
features.append("ControlNet support")
languages = model_info.get("languages", [])
if len(languages) > 1:
features.append(f"Multilingual ({', '.join(languages)})")
elif "multilingual" in languages:
features.append("Multilingual support")
max_res = model_info.get("max_resolution", (2048, 2048))
if max(max_res) >= 4096:
features.append("4K/8K support")
elif max(max_res) >= 2048:
features.append("2K support")
api_params = model_info.get("api_params", {})
if api_params.get("supports_guidance_scale"):
features.append("Guidance scale control")
return features if features else ["Standard editing"]
async def process_edit(
self,
@@ -221,6 +465,9 @@ class EditStudioService:
) -> Dict[str, Any]:
"""Process edit request and return normalized response."""
# Pre-flight validation: Use specific validator for editing operations
# Note: Editing uses validate_image_editing_operations (different from generation)
# This is intentional as editing may have different subscription limits
if user_id:
from services.database import get_db
from services.subscription import PricingService
@@ -386,29 +633,109 @@ class EditStudioService:
mask_bytes: Optional[bytes],
user_id: Optional[str],
) -> bytes:
"""Execute Hugging Face powered general editing (synchronous API)."""
"""Execute general editing - routes to WaveSpeed (unified entry) or HuggingFace (legacy).
If model is a WaveSpeed model (qwen-edit-plus, nano-banana-pro-edit-ultra, seedream-v4.5-edit),
uses unified entry point. Otherwise falls back to HuggingFace for backward compatibility.
"""
if not request.prompt:
raise ValueError("Prompt is required for general edits")
options = {
"provider": request.provider or "huggingface",
"model": request.model,
"guidance_scale": request.guidance_scale,
"steps": request.steps,
"seed": request.seed,
}
# huggingface edit is synchronous - run in thread
result = await asyncio.to_thread(
huggingface_edit_image,
image_bytes,
request.prompt,
options,
user_id,
mask_bytes, # Optional mask for selective editing
# Check if model is a WaveSpeed editing model
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
provider = WaveSpeedEditProvider()
wavespeed_models = set(provider.get_available_models().keys())
# Also check if provider is explicitly set to "wavespeed"
is_wavespeed = (
request.provider == "wavespeed" or
(request.model and request.model in wavespeed_models)
)
# Auto-detect: If no model specified and operation is general_edit, recommend one
if not request.model and not is_wavespeed and request.operation == "general_edit":
# Auto-select recommended model
try:
# Get image dimensions for recommendation
with Image.open(io.BytesIO(image_bytes)) as img:
image_resolution = {"width": img.width, "height": img.height}
recommendation = self.recommend_model(
operation=request.operation,
image_resolution=image_resolution,
preferences={"prioritize_cost": True}, # Default to cost-optimized
)
recommended_model = recommendation.get("recommended_model")
if recommended_model and recommended_model in wavespeed_models:
logger.info(f"[Edit Studio] Auto-selected model: {recommended_model} (reason: {recommendation.get('reason')})")
request.model = recommended_model
is_wavespeed = True
except Exception as e:
logger.warning(f"[Edit Studio] Auto-detection failed: {e}, falling back to HuggingFace")
if is_wavespeed:
# Use unified entry point for WaveSpeed models
logger.info(f"[Edit Studio] Using WaveSpeed unified entry for model={request.model}")
# Convert image bytes to base64
import base64
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
# Prepare options for unified entry point
edit_options = {
"mask_base64": None,
"negative_prompt": request.negative_prompt,
"width": None, # Will be determined from image if needed
"height": None,
"guidance_scale": request.guidance_scale,
"steps": request.steps,
"seed": request.seed,
}
# Add mask if provided
if mask_bytes:
edit_options["mask_base64"] = base64.b64encode(mask_bytes).decode("utf-8")
# Extract dimensions from image if needed
with Image.open(io.BytesIO(image_bytes)) as img:
edit_options["width"] = img.width
edit_options["height"] = img.height
# Call unified entry point (synchronous, so run in thread)
result = await asyncio.to_thread(
generate_image_edit,
image_base64=image_base64,
prompt=request.prompt,
operation=request.operation or "general_edit",
model=request.model, # Will auto-select if None
options=edit_options,
user_id=user_id,
)
return result.image_bytes
else:
# Fall back to HuggingFace for backward compatibility
logger.info("[Edit Studio] Using HuggingFace (legacy) for general edit")
options = {
"provider": request.provider or "huggingface",
"model": request.model,
"guidance_scale": request.guidance_scale,
"steps": request.steps,
"seed": request.seed,
}
return result.image_bytes
# huggingface edit is synchronous - run in thread
result = await asyncio.to_thread(
huggingface_edit_image,
image_bytes,
request.prompt,
options,
user_id,
mask_bytes, # Optional mask for selective editing
)
return result.image_bytes
@staticmethod
def _extract_image_bytes(result: Any) -> bytes:

View File

@@ -0,0 +1,266 @@
"""Face Swap Studio service for AI-powered face swapping."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, Optional
from PIL import Image
from services.llm_providers.main_image_generation import generate_face_swap
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.face_swap")
@dataclass
class FaceSwapStudioRequest:
"""Request model for face swap operations."""
base_image_base64: str
face_image_base64: str
model: Optional[str] = None
target_face_index: Optional[int] = None
target_gender: Optional[str] = None
options: Optional[Dict[str, Any]] = None
class FaceSwapService:
"""Service for face swap operations."""
def __init__(self):
pass
def get_available_models(
self,
tier: Optional[str] = None,
) -> Dict[str, Any]:
"""Get available WaveSpeed face swap models.
Args:
tier: Filter by tier ("budget", "mid", "premium")
Returns:
Dictionary with models and metadata
"""
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
all_models = provider.get_available_models()
# Filter by tier if specified
if tier:
filtered = provider.get_models_by_tier(tier)
all_models = {k: v for k, v in all_models.items() if k in filtered}
# Format for API response
models_list = []
for model_id, model_info in all_models.items():
models_list.append({
"id": model_id,
"name": model_info.get("name", model_id),
"description": model_info.get("description", ""),
"cost": model_info.get("cost", 0.025),
"tier": model_info.get("tier", "mid"),
"capabilities": model_info.get("capabilities", []),
"use_cases": self._get_use_cases_for_model(model_id, model_info),
"features": model_info.get("features", []),
"max_faces": model_info.get("max_faces", 1),
})
return {
"models": models_list,
"total": len(models_list),
}
def recommend_model(
self,
base_image_resolution: Optional[Dict[str, int]] = None,
face_image_resolution: Optional[Dict[str, int]] = None,
user_tier: Optional[str] = None,
preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Recommend best model for face swap.
Args:
base_image_resolution: Dict with "width" and "height" of base image
face_image_resolution: Dict with "width" and "height" of face image
user_tier: User subscription tier ("free", "pro", "enterprise")
preferences: Dict with "prioritize_cost" or "prioritize_quality"
Returns:
Dictionary with recommended model and alternatives
"""
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
available_models = provider.get_available_models()
if not available_models:
raise ValueError("No models available")
# Apply preferences
prioritize_cost = preferences and preferences.get("prioritize_cost", False)
prioritize_quality = preferences and preferences.get("prioritize_quality", False)
# Score models
scored_models = []
for model_id, model_info in available_models.items():
score = 0
cost = model_info.get("cost", 0.025)
tier = model_info.get("tier", "mid")
# Cost scoring (lower is better)
if prioritize_cost:
score += (1.0 / cost) * 100
else:
score += (1.0 / cost) * 50
# Quality scoring (higher cost = better quality for face swap)
if prioritize_quality:
score += cost * 20
else:
score += cost * 10
# Tier preference based on user tier
if user_tier == "free":
if tier == "budget":
score += 50
elif tier == "mid":
score += 20
elif user_tier in ["pro", "enterprise"]:
if tier == "premium":
score += 50
elif tier == "mid":
score += 30
scored_models.append((model_id, model_info, score))
# Sort by score (highest first)
scored_models.sort(key=lambda x: x[2], reverse=True)
# Get recommended model
recommended_id, recommended_info, recommended_score = scored_models[0]
# Build reason
reasons = []
if prioritize_cost:
reasons.append("Lowest cost option")
if prioritize_quality:
reasons.append("Best quality")
if user_tier == "free" and recommended_info.get("tier") == "budget":
reasons.append("Budget-friendly for free tier")
reason = ", ".join(reasons) if reasons else "Best match for your requirements"
# Get alternatives (top 2-3)
alternatives = []
for model_id, model_info, score in scored_models[1:4]:
alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
if model_info.get("cost", 0) < recommended_info.get("cost", 0):
alt_reason += ", lower cost"
elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
alt_reason += ", higher quality"
alternatives.append({
"model_id": model_id,
"name": model_info.get("name", model_id),
"cost": model_info.get("cost", 0.025),
"reason": alt_reason,
})
return {
"recommended_model": recommended_id,
"reason": reason,
"alternatives": alternatives,
}
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
"""Get use cases for a model based on its capabilities."""
use_cases_map = {
"face_swap": ["Portrait editing", "Fun swaps", "Social media"],
"head_swap": ["Casting and concept design", "Privacy and anonymization", "Photo exploration"],
"full_head_replacement": ["Full head replacement", "Hair included", "Casting mockups"],
"realistic_blending": ["Professional work", "Marketing", "Entertainment"],
"multi_face": ["Group photos", "Family photos", "Team photos", "Creative projects", "Content creation"],
"face_enhancement": ["High-quality results", "Professional work", "Marketing campaigns"],
"identity_preservation": ["Character consistency", "Brand identity"],
}
capabilities = model_info.get("capabilities", [])
use_cases = []
for cap in capabilities:
if cap in use_cases_map:
use_cases.extend(use_cases_map[cap])
return list(set(use_cases)) if use_cases else ["General face swapping"]
async def process_face_swap(
self,
request: FaceSwapStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Process face swap request.
Args:
request: Face swap request
user_id: User ID for tracking
Returns:
Dictionary with result image and metadata
"""
# Auto-detect model if not specified
selected_model = request.model
if not selected_model:
try:
# Get image dimensions for recommendation
base_img = Image.open(io.BytesIO(base64.b64decode(request.base_image_base64.split(",", 1)[1] if "," in request.base_image_base64 else request.base_image_base64)))
face_img = Image.open(io.BytesIO(base64.b64decode(request.face_image_base64.split(",", 1)[1] if "," in request.face_image_base64 else request.face_image_base64)))
base_resolution = {"width": base_img.width, "height": base_img.height}
face_resolution = {"width": face_img.width, "height": face_img.height}
recommendation = self.recommend_model(
base_image_resolution=base_resolution,
face_image_resolution=face_resolution,
preferences={"prioritize_cost": True},
)
selected_model = recommendation.get("recommended_model")
logger.info(f"[Face Swap] Auto-selected model: {selected_model} (reason: {recommendation.get('reason')})")
except Exception as e:
logger.warning(f"[Face Swap] Auto-detection failed: {e}, using default model")
# Use first available model as fallback
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
all_models = provider.get_available_models()
if all_models:
selected_model = list(all_models.keys())[0]
# Prepare options
options = request.options or {}
if request.target_face_index is not None:
options["target_face_index"] = request.target_face_index
if request.target_gender:
options["target_gender"] = request.target_gender
# Call unified entry point
result = generate_face_swap(
base_image_base64=request.base_image_base64,
face_image_base64=request.face_image_base64,
model=selected_model,
options=options,
user_id=user_id,
)
# Convert result to base64
result_base64 = base64.b64encode(result.image_bytes).decode("utf-8")
result_data_url = f"data:image/png;base64,{result_base64}"
return {
"success": True,
"image_base64": result_data_url,
"width": result.width,
"height": result.height,
"provider": result.provider,
"model": result.model,
"metadata": result.metadata or {},
}

View File

@@ -0,0 +1,403 @@
"""Image Format Converter Service for converting between image formats."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from PIL import Image, ImageCms
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.format_converter")
@dataclass
class FormatConversionRequest:
"""Request model for format conversion."""
image_base64: str
target_format: str # png, jpeg, jpg, webp, gif, bmp, tiff
preserve_transparency: bool = True
quality: Optional[int] = None # For lossy formats (1-100)
color_space: Optional[str] = None # sRGB, Adobe RGB, etc.
strip_metadata: bool = False # Keep metadata by default for conversion
optimize: bool = True
progressive: bool = True # For JPEG
@dataclass
class FormatConversionResult:
"""Result of format conversion."""
success: bool
image_base64: str
original_format: str
target_format: str
original_size_kb: float
converted_size_kb: float
width: int
height: int
transparency_preserved: bool
metadata_preserved: bool
color_space: Optional[str] = None
class ImageFormatConverterService:
"""Service for converting images between formats."""
SUPPORTED_FORMATS = {
"png": {
"name": "PNG",
"description": "Lossless format with transparency support",
"supports_transparency": True,
"supports_lossy": False,
"mime_type": "image/png",
},
"jpeg": {
"name": "JPEG",
"description": "Lossy format, best for photos",
"supports_transparency": False,
"supports_lossy": True,
"mime_type": "image/jpeg",
},
"jpg": {
"name": "JPEG",
"description": "Lossy format, best for photos",
"supports_transparency": False,
"supports_lossy": True,
"mime_type": "image/jpeg",
},
"webp": {
"name": "WebP",
"description": "Modern format with excellent compression",
"supports_transparency": True,
"supports_lossy": True,
"mime_type": "image/webp",
},
"gif": {
"name": "GIF",
"description": "Supports animation and transparency",
"supports_transparency": True,
"supports_lossy": False,
"mime_type": "image/gif",
},
"bmp": {
"name": "BMP",
"description": "Uncompressed bitmap format",
"supports_transparency": False,
"supports_lossy": False,
"mime_type": "image/bmp",
},
"tiff": {
"name": "TIFF",
"description": "High-quality format for print",
"supports_transparency": True,
"supports_lossy": False,
"mime_type": "image/tiff",
},
}
def __init__(self):
logger.info("[Format Converter] ImageFormatConverterService initialized")
def _decode_image(self, image_base64: str) -> tuple[Image.Image, int, str]:
"""Decode base64 image and return PIL Image, size, and format."""
# Handle data URL format
if "," in image_base64:
image_base64 = image_base64.split(",", 1)[1]
image_bytes = base64.b64decode(image_base64)
original_size = len(image_bytes)
image = Image.open(io.BytesIO(image_bytes))
original_format = image.format.lower() if image.format else "unknown"
return image, original_size, original_format
def _strip_exif(self, image: Image.Image) -> Image.Image:
"""Remove EXIF metadata from image."""
data = list(image.getdata())
image_without_exif = Image.new(image.mode, image.size)
image_without_exif.putdata(data)
return image_without_exif
def _convert_color_space(
self,
image: Image.Image,
target_color_space: str,
) -> Image.Image:
"""Convert image color space."""
try:
# Get current color space
if hasattr(image, 'info') and 'icc_profile' in image.info:
# Image has ICC profile
try:
src_profile = ImageCms.ImageCmsProfile(io.BytesIO(image.info['icc_profile']))
if target_color_space.lower() == "srgb":
dst_profile = ImageCms.createProfile("sRGB")
elif target_color_space.lower() == "adobe rgb":
dst_profile = ImageCms.createProfile("Adobe RGB")
else:
return image # Unknown color space
transform = ImageCms.ImageCmsTransform(src_profile, dst_profile, image.mode, image.mode)
image = ImageCms.applyTransform(image, transform)
except Exception as e:
logger.warning(f"[Format Converter] Color space conversion failed: {e}")
else:
# No ICC profile, assume sRGB
logger.info("[Format Converter] No ICC profile found, assuming sRGB")
except Exception as e:
logger.warning(f"[Format Converter] Color space conversion error: {e}")
return image
def _convert_image(
self,
image: Image.Image,
target_format: str,
quality: Optional[int],
preserve_transparency: bool,
optimize: bool,
progressive: bool,
) -> bytes:
"""Convert image to target format."""
buffer = io.BytesIO()
format_lower = target_format.lower()
# Handle format-specific conversions
save_kwargs: Dict[str, Any] = {}
# Check if source has transparency and target doesn't support it
has_transparency = image.mode in ("RGBA", "LA", "P") and (
"transparency" in image.info or image.mode == "RGBA"
)
if format_lower in ["jpeg", "jpg"]:
# JPEG doesn't support transparency
if has_transparency and preserve_transparency:
# Convert to RGB, losing transparency
if image.mode in ("RGBA", "LA"):
# Create white background
rgb_image = Image.new("RGB", image.size, (255, 255, 255))
if image.mode == "RGBA":
rgb_image.paste(image, mask=image.split()[3]) # Use alpha channel as mask
else:
rgb_image.paste(image)
image = rgb_image
elif image.mode == "P":
image = image.convert("RGB")
else:
image = image.convert("RGB")
save_kwargs["format"] = "JPEG"
if quality:
save_kwargs["quality"] = quality
else:
save_kwargs["quality"] = 95 # Default high quality
save_kwargs["optimize"] = optimize
if progressive:
save_kwargs["progressive"] = True
elif format_lower == "png":
save_kwargs["format"] = "PNG"
save_kwargs["optimize"] = optimize
# PNG compression level (0-9)
if quality:
compress_level = max(0, min(9, (100 - quality) // 11))
save_kwargs["compress_level"] = compress_level
else:
save_kwargs["compress_level"] = 6 # Default
elif format_lower == "webp":
save_kwargs["format"] = "WEBP"
if quality:
save_kwargs["quality"] = quality
else:
save_kwargs["quality"] = 80 # Default
save_kwargs["method"] = 6 # Best compression
if preserve_transparency and has_transparency:
# WebP supports transparency
if image.mode not in ("RGBA", "LA"):
image = image.convert("RGBA")
elif format_lower == "gif":
save_kwargs["format"] = "GIF"
# GIF conversion
if image.mode != "P":
# Convert to palette mode for GIF
image = image.convert("P", palette=Image.ADAPTIVE)
save_kwargs["optimize"] = optimize
if preserve_transparency and has_transparency:
save_kwargs["transparency"] = 255 # Preserve transparency
elif format_lower == "bmp":
save_kwargs["format"] = "BMP"
if image.mode in ("RGBA", "LA", "P") and has_transparency:
# BMP doesn't support transparency, convert to RGB
if image.mode == "RGBA":
rgb_image = Image.new("RGB", image.size, (255, 255, 255))
rgb_image.paste(image, mask=image.split()[3])
image = rgb_image
else:
image = image.convert("RGB")
elif format_lower == "tiff":
save_kwargs["format"] = "TIFF"
save_kwargs["compression"] = "tiff_lzw" # Lossless compression
if preserve_transparency and has_transparency:
# TIFF supports transparency
if image.mode not in ("RGBA", "LA"):
image = image.convert("RGBA")
else:
raise ValueError(f"Unsupported target format: {target_format}")
image.save(buffer, **save_kwargs)
return buffer.getvalue()
async def convert(
self,
request: FormatConversionRequest,
user_id: Optional[str] = None,
) -> FormatConversionResult:
"""Convert an image to target format."""
logger.info(f"[Format Converter] Processing conversion request for user: {user_id}")
try:
# Decode image
image, original_size, original_format = self._decode_image(request.image_base64)
original_size_kb = original_size / 1024
logger.info(f"[Format Converter] Original: {original_format}, Target: {request.target_format}, Size: {original_size_kb:.2f} KB")
# Validate target format
format_lower = request.target_format.lower()
if format_lower not in self.SUPPORTED_FORMATS:
raise ValueError(f"Unsupported format: {request.target_format}. Supported: {list(self.SUPPORTED_FORMATS.keys())}")
# Check transparency preservation
has_transparency = image.mode in ("RGBA", "LA", "P") and (
"transparency" in image.info or image.mode == "RGBA"
)
target_supports_transparency = self.SUPPORTED_FORMATS[format_lower]["supports_transparency"]
transparency_preserved = (
has_transparency and
target_supports_transparency and
request.preserve_transparency
)
# Color space conversion
if request.color_space:
image = self._convert_color_space(image, request.color_space)
# Strip metadata if requested
metadata_preserved = not request.strip_metadata
if request.strip_metadata:
image = self._strip_exif(image)
# Convert format
converted_bytes = self._convert_image(
image,
format_lower,
request.quality,
request.preserve_transparency,
request.optimize,
request.progressive,
)
converted_size_kb = len(converted_bytes) / 1024
# Encode result
mime_type = self.SUPPORTED_FORMATS[format_lower]["mime_type"]
result_base64 = f"data:{mime_type};base64,{base64.b64encode(converted_bytes).decode()}"
logger.info(f"[Format Converter] Converted: {original_size_kb:.2f}KB → {converted_size_kb:.2f}KB")
return FormatConversionResult(
success=True,
image_base64=result_base64,
original_format=original_format,
target_format=format_lower,
original_size_kb=round(original_size_kb, 2),
converted_size_kb=round(converted_size_kb, 2),
width=image.width,
height=image.height,
transparency_preserved=transparency_preserved,
metadata_preserved=metadata_preserved,
color_space=request.color_space,
)
except Exception as e:
logger.error(f"[Format Converter] Failed to convert image: {e}")
raise
async def convert_batch(
self,
requests: List[FormatConversionRequest],
user_id: Optional[str] = None,
) -> List[FormatConversionResult]:
"""Convert multiple images."""
logger.info(f"[Format Converter] Processing batch of {len(requests)} images for user: {user_id}")
results = []
for i, request in enumerate(requests):
try:
result = await self.convert(request, user_id)
results.append(result)
logger.info(f"[Format Converter] Batch item {i+1}/{len(requests)} complete")
except Exception as e:
logger.error(f"[Format Converter] Batch item {i+1} failed: {e}")
results.append(FormatConversionResult(
success=False,
image_base64="",
original_format="",
target_format="",
original_size_kb=0,
converted_size_kb=0,
width=0,
height=0,
transparency_preserved=False,
metadata_preserved=False,
))
return results
def get_supported_formats(self) -> List[Dict[str, Any]]:
"""Get list of supported formats with details."""
return [
{
"id": fmt_id,
"name": fmt_info["name"],
"description": fmt_info["description"],
"supports_transparency": fmt_info["supports_transparency"],
"supports_lossy": fmt_info["supports_lossy"],
"mime_type": fmt_info["mime_type"],
}
for fmt_id, fmt_info in self.SUPPORTED_FORMATS.items()
]
def get_format_recommendations(self, source_format: str) -> List[Dict[str, Any]]:
"""Get format recommendations based on source format."""
recommendations = {
"png": [
{"format": "webp", "reason": "60% smaller file size, maintains transparency"},
{"format": "jpeg", "reason": "Best for photos, smaller file size"},
],
"jpeg": [
{"format": "webp", "reason": "25-34% smaller with similar quality"},
{"format": "png", "reason": "Lossless, supports transparency"},
],
"jpg": [
{"format": "webp", "reason": "25-34% smaller with similar quality"},
{"format": "png", "reason": "Lossless, supports transparency"},
],
"webp": [
{"format": "png", "reason": "Better compatibility, lossless"},
{"format": "jpeg", "reason": "Universal compatibility"},
],
}
source_lower = source_format.lower()
return recommendations.get(source_lower, [])

View File

@@ -7,6 +7,9 @@ from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .face_swap_service import FaceSwapService, FaceSwapStudioRequest
from .compression_service import ImageCompressionService, CompressionRequest, CompressionResult
from .format_converter_service import ImageFormatConverterService, FormatConversionRequest, FormatConversionResult
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
@@ -29,6 +32,9 @@ class ImageStudioManager:
self.upscale_service = UpscaleStudioService()
self.control_service = ControlStudioService()
self.social_optimizer_service = SocialOptimizerService()
self.face_swap_service = FaceSwapService()
self.compression_service = ImageCompressionService()
self.format_converter_service = ImageFormatConverterService()
self.transform_service = TransformStudioService()
logger.info("[Image Studio Manager] Initialized successfully")
@@ -69,6 +75,99 @@ class ImageStudioManager:
def get_edit_operations(self) -> Dict[str, Any]:
"""Expose edit operations for UI."""
return self.edit_service.list_operations()
def get_edit_models(
self,
operation: Optional[str] = None,
tier: Optional[str] = None,
) -> Dict[str, Any]:
"""Get available editing models.
Args:
operation: Filter by operation type
tier: Filter by tier (budget, mid, premium)
Returns:
Dictionary with models and metadata
"""
return self.edit_service.get_available_models(operation=operation, tier=tier)
def recommend_edit_model(
self,
operation: str,
image_resolution: Optional[Dict[str, int]] = None,
user_tier: Optional[str] = None,
preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Recommend best editing model for given context.
Args:
operation: Operation type
image_resolution: Image dimensions
user_tier: User subscription tier
preferences: User preferences (prioritize_cost, prioritize_quality)
Returns:
Dictionary with recommended model and alternatives
"""
return self.edit_service.recommend_model(
operation=operation,
image_resolution=image_resolution,
user_tier=user_tier,
preferences=preferences,
)
# ====================
# FACE SWAP STUDIO
# ====================
async def face_swap(
self,
request: FaceSwapStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Run Face Swap Studio operations."""
logger.info("[Image Studio] Face swap request from user: %s", user_id)
return await self.face_swap_service.process_face_swap(request, user_id=user_id)
def get_face_swap_models(
self,
tier: Optional[str] = None,
) -> Dict[str, Any]:
"""Get available face swap models.
Args:
tier: Filter by tier (budget, mid, premium)
Returns:
Dictionary with models and metadata
"""
return self.face_swap_service.get_available_models(tier=tier)
def recommend_face_swap_model(
self,
base_image_resolution: Optional[Dict[str, int]] = None,
face_image_resolution: Optional[Dict[str, int]] = None,
user_tier: Optional[str] = None,
preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Recommend best face swap model for given context.
Args:
base_image_resolution: Base image dimensions
face_image_resolution: Face image dimensions
user_tier: User subscription tier
preferences: User preferences (prioritize_cost, prioritize_quality)
Returns:
Dictionary with recommended model and alternatives
"""
return self.face_swap_service.recommend_model(
base_image_resolution=base_image_resolution,
face_image_resolution=face_image_resolution,
user_tier=user_tier,
preferences=preferences,
)
# ====================
# UPSCALE STUDIO
@@ -377,3 +476,72 @@ class ImageStudioManager:
"""Estimate cost for transform operation."""
return self.transform_service.estimate_cost(operation, resolution, duration)
# ====================
# COMPRESSION STUDIO
# ====================
async def compress_image(
self,
request: CompressionRequest,
user_id: Optional[str] = None,
) -> CompressionResult:
"""Compress an image with specified settings."""
logger.info("[Image Studio] Compress image request from user: %s", user_id)
return await self.compression_service.compress(request, user_id=user_id)
async def compress_batch(
self,
requests: List[CompressionRequest],
user_id: Optional[str] = None,
) -> List[CompressionResult]:
"""Compress multiple images."""
logger.info("[Image Studio] Batch compress request (%d images) from user: %s", len(requests), user_id)
return await self.compression_service.compress_batch(requests, user_id=user_id)
async def estimate_compression(
self,
image_base64: str,
format: str = "jpeg",
quality: int = 85,
) -> Dict[str, Any]:
"""Estimate compression results without compressing."""
return await self.compression_service.estimate_compression(image_base64, format, quality)
def get_compression_formats(self) -> List[Dict[str, Any]]:
"""Get supported compression formats."""
return self.compression_service.get_supported_formats()
def get_compression_presets(self) -> List[Dict[str, Any]]:
"""Get compression presets for common use cases."""
return self.compression_service.get_presets()
# ====================
# FORMAT CONVERTER
# ====================
async def convert_format(
self,
request: FormatConversionRequest,
user_id: Optional[str] = None,
) -> FormatConversionResult:
"""Convert an image to target format."""
logger.info("[Image Studio] Convert format request from user: %s", user_id)
return await self.format_converter_service.convert(request, user_id=user_id)
async def convert_format_batch(
self,
requests: List[FormatConversionRequest],
user_id: Optional[str] = None,
) -> List[FormatConversionResult]:
"""Convert multiple images."""
logger.info("[Image Studio] Batch convert format request (%d images) from user: %s", len(requests), user_id)
return await self.format_converter_service.convert_batch(requests, user_id=user_id)
def get_supported_formats(self) -> List[Dict[str, Any]]:
"""Get supported conversion formats."""
return self.format_converter_service.get_supported_formats()
def get_format_recommendations(self, source_format: str) -> List[Dict[str, Any]]:
"""Get format recommendations based on source format."""
return self.format_converter_service.get_format_recommendations(source_format)

View File

@@ -36,18 +36,16 @@ class UpscaleStudioService:
request: UpscaleStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
# Pre-flight validation: Reuse unified helper
# Note: Using image-generation validation since upscaling uses same subscription limits
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_upscale_operations
db = next(get_db())
try:
pricing_service = PricingService(db)
logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
from services.llm_providers.main_image_generation import _validate_image_operation
_validate_image_operation(
user_id=user_id,
operation_type="image-upscale",
num_operations=1,
log_prefix="[Upscale Studio]"
)
image_bytes = self._decode_base64(request.image_base64)
if not image_bytes:

View File

@@ -1,4 +1,12 @@
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from .base import (
ImageGenerationOptions,
ImageGenerationResult,
ImageGenerationProvider,
ImageEditOptions,
ImageEditProvider,
FaceSwapOptions,
FaceSwapProvider,
)
from .hf_provider import HuggingFaceImageProvider
from .gemini_provider import GeminiImageProvider
from .stability_provider import StabilityImageProvider
@@ -8,6 +16,10 @@ __all__ = [
"ImageGenerationOptions",
"ImageGenerationResult",
"ImageGenerationProvider",
"ImageEditOptions",
"ImageEditProvider",
"FaceSwapOptions",
"FaceSwapProvider",
"HuggingFaceImageProvider",
"GeminiImageProvider",
"StabilityImageProvider",

View File

@@ -28,6 +28,50 @@ class ImageGenerationResult:
metadata: Optional[Dict[str, Any]] = None
@dataclass
class ImageEditOptions:
"""Options for image editing operations."""
image_base64: str
prompt: str
operation: str # "general_edit", "inpaint", "outpaint", "remove_background", etc.
mask_base64: Optional[str] = None
negative_prompt: Optional[str] = None
model: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
extra: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls."""
result = {
"image_base64": self.image_base64,
"prompt": self.prompt,
"operation": self.operation,
}
if self.mask_base64:
result["mask_base64"] = self.mask_base64
if self.negative_prompt:
result["negative_prompt"] = self.negative_prompt
if self.model:
result["model"] = self.model
if self.width:
result["width"] = self.width
if self.height:
result["height"] = self.height
if self.guidance_scale is not None:
result["guidance_scale"] = self.guidance_scale
if self.steps:
result["steps"] = self.steps
if self.seed is not None:
result["seed"] = self.seed
if self.extra:
result.update(self.extra)
return result
class ImageGenerationProvider(Protocol):
"""Protocol for image generation providers."""
@@ -35,3 +79,44 @@ class ImageGenerationProvider(Protocol):
...
@dataclass
class FaceSwapOptions:
"""Options for face swap operations."""
base_image_base64: str # Image to swap face into
face_image_base64: str # Face to swap
model: Optional[str] = None
target_face_index: Optional[int] = None # For multi-face images (0 = largest)
target_gender: Optional[str] = None # "all", "female", "male" (for some models)
extra: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls."""
result = {
"base_image_base64": self.base_image_base64,
"face_image_base64": self.face_image_base64,
}
if self.model:
result["model"] = self.model
if self.target_face_index is not None:
result["target_face_index"] = self.target_face_index
if self.target_gender:
result["target_gender"] = self.target_gender
if self.extra:
result.update(self.extra)
return result
class ImageEditProvider(Protocol):
"""Protocol for image editing providers."""
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
...
class FaceSwapProvider(Protocol):
"""Protocol for face swap providers."""
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
...

View File

@@ -0,0 +1,691 @@
"""WaveSpeed AI image editing provider (14 editing models)."""
import io
import os
import requests
from typing import Optional
from PIL import Image
from fastapi import HTTPException
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.edit_provider")
class WaveSpeedEditProvider(ImageEditProvider):
"""WaveSpeed AI image editing provider supporting 14 editing models.
REUSES: WaveSpeedClient, model registry pattern, result format
"""
# Model registry - populated with WaveSpeed editing models
SUPPORTED_MODELS = {
"qwen-edit": {
"model_path": "wavespeed-ai/qwen-image/edit",
"name": "Qwen Image Edit",
"description": "20B MMDiT image-to-image model offering precise bilingual (Chinese & English) text edits while preserving style. Single-image editing with style preservation.",
"cost": 0.02, # Same as Plus version
"max_resolution": (1536, 1536), # Based on docs: similar to Plus
"capabilities": ["general_edit", "style_transfer", "text_edit"],
"tier": "budget",
"supports_multi_image": False, # Single image only (uses "image" not "images")
"supports_controlnet": False, # Not mentioned in docs
"languages": ["en", "zh"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height)
"uses_aspect_ratio": False,
"uses_resolution": False,
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
"default_output_format": "jpeg", # Per API docs: default is "jpeg"
"supports_seed": True, # Per API docs: seed parameter supported
}
},
"qwen-edit-plus": {
"model_path": "wavespeed-ai/qwen-image/edit-plus",
"name": "Qwen Image Edit Plus",
"description": "20B MMDiT image editor with multi-image editing, single-image consistency and native ControlNet support. Bilingual (CN/EN) text editing, appearance-level and semantic-level edits.",
"cost": 0.02,
"max_resolution": (1536, 1536), # Based on docs: 256-1536 per dimension
"capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
"tier": "budget",
"supports_multi_image": True, # Up to 3 reference images
"supports_controlnet": True,
"languages": ["en", "zh"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height)
"uses_aspect_ratio": False,
"uses_resolution": False,
"uses_image_singular": False, # Uses "images" (array)
"supports_seed": True, # Seed parameter supported (default for Qwen models)
}
},
"nano-banana-pro-edit-ultra": {
"model_path": "google/nano-banana-pro/edit-ultra",
"name": "Google Nano Banana Pro Edit Ultra",
"description": "High-resolution image editing with 4K/8K native output. Natural language instructions, multilingual text support. Premium quality editing for professional marketing and high-res work.",
"cost": 0.15, # 4K - from enhancement proposal
"cost_8k": 0.18, # 8K - from enhancement proposal
"max_resolution": (8192, 8192), # 8K support
"capabilities": ["general_edit", "high_res", "professional", "typography"],
"tier": "premium",
"supports_multi_image": True, # Up to 14 reference images
"supports_controlnet": False,
"languages": ["en", "multilingual"],
"api_params": {
"uses_size": False, # Uses aspect_ratio and resolution instead
"uses_aspect_ratio": True, # "1:1", "16:9", etc.
"uses_resolution": True, # "4k" or "8k"
"max_images": 14,
"default_output_format": "png", # Per API docs: default is "png"
"supports_seed": False, # Per API docs: no seed parameter
}
},
"seedream-v4.5-edit": {
"model_path": "bytedance/seedream-v4.5/edit",
"name": "Bytedance Seedream V4.5 Edit",
"description": "Preserves facial features, lighting, and color tone from reference images, delivering professional, high-fidelity edits up to 4K with strong prompt adherence. Reference-faithful editing with multi-image support.",
"cost": 0.04, # Per generated image
"max_resolution": (4096, 4096), # 4K support (1024-4096 per dimension)
"capabilities": ["general_edit", "portrait_retouching", "fashion_edit", "product_edit", "multi_image"],
"tier": "mid",
"supports_multi_image": True, # Up to 10 reference images
"supports_controlnet": False,
"languages": ["en"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height format, 1024-4096 per dimension)
"uses_aspect_ratio": False,
"uses_resolution": False,
"max_images": 10,
"default_output_format": "png",
"supports_seed": False, # No seed parameter in API docs (Seedream V4.5)
}
},
"flux-kontext-pro": {
"model_path": "wavespeed-ai/flux-kontext-pro",
"name": "FLUX Kontext Pro",
"description": "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate typography generation for consistent, high-quality edits at speed. Typography-focused editing with improved prompt adherence.",
"cost": 0.04, # From enhancement proposal
"max_resolution": (2048, 2048), # Estimated, not specified in docs
"capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
"tier": "mid",
"supports_multi_image": False, # Single image only (uses "image" not "images")
"supports_controlnet": False,
"languages": ["en"],
"api_params": {
"uses_size": False, # Uses aspect_ratio instead
"uses_aspect_ratio": True, # Aspect ratio as string (e.g., "16:9", "1:1")
"uses_resolution": False,
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
"supports_guidance_scale": True, # Has guidance_scale parameter (default 3.5, range 1-20)
"default_guidance_scale": 3.5, # Per API docs
"supports_seed": False, # No seed parameter in API docs
}
},
# TODO: Add remaining 9 models once docs are provided
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize WaveSpeed edit provider.
Args:
api_key: WaveSpeed API key (falls back to env var if not provided)
"""
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
if not self.api_key:
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
# REUSE: Same client as generation provider
self.client = WaveSpeedClient(api_key=self.api_key)
logger.info("[WaveSpeed Edit Provider] Initialized with %d models",
len(self.SUPPORTED_MODELS))
def _validate_options(self, options: ImageEditOptions) -> None:
"""Validate editing options.
Args:
options: Image editing options
Raises:
ValueError: If options are invalid
"""
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
if not model:
raise ValueError("No model specified and no default model available")
if model not in self.SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model: {model}. "
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
)
model_info = self.SUPPORTED_MODELS[model]
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
if options.width and options.width > max_width:
raise ValueError(
f"Width {options.width} exceeds maximum {max_width} for model {model}"
)
if options.height and options.height > max_height:
raise ValueError(
f"Height {options.height} exceeds maximum {max_height} for model {model}"
)
if not options.prompt or len(options.prompt.strip()) == 0:
raise ValueError("Prompt cannot be empty")
if not options.image_base64:
raise ValueError("Image base64 cannot be empty")
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
"""Edit image using WaveSpeed AI models.
Args:
options: Image editing options
Returns:
ImageGenerationResult with edited image
Raises:
ValueError: If options are invalid
RuntimeError: If editing fails
"""
# Validate options
self._validate_options(options)
# Determine model
model = options.model or (list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None)
if not model:
raise ValueError("No model available for editing")
model_info = self.SUPPORTED_MODELS[model]
model_path = model_info["model_path"]
logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
model, options.operation, options.prompt[:100])
try:
# Prepare extra parameters based on model capabilities
extra_params = options.extra or {}
# Add model-specific parameters if needed
api_params = model_info.get("api_params", {})
if api_params.get("uses_resolution", False):
# For Nano Banana: determine resolution from dimensions or use default
if options.width and options.height:
if options.width >= 4096 or options.height >= 4096:
extra_params["resolution"] = "8k"
else:
extra_params["resolution"] = "4k"
elif "resolution" not in extra_params:
extra_params["resolution"] = "4k" # Default to 4K
if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
# Calculate aspect ratio if dimensions provided
if options.width and options.height:
aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
if aspect_ratio:
extra_params["aspect_ratio"] = aspect_ratio
# Call WaveSpeed API for editing
result = self._call_wavespeed_edit_api(
model_path=model_path,
image_base64=options.image_base64,
prompt=options.prompt,
operation=options.operation,
mask_base64=options.mask_base64,
negative_prompt=options.negative_prompt,
width=options.width,
height=options.height,
guidance_scale=options.guidance_scale,
steps=options.steps,
seed=options.seed,
extra=extra_params
)
# Extract image bytes from result
if isinstance(result, bytes):
image_bytes = result
elif isinstance(result, dict) and "image" in result:
image_bytes = result["image"]
elif isinstance(result, dict) and "image_bytes" in result:
image_bytes = result["image_bytes"]
else:
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
# Load image to get dimensions
image = Image.open(io.BytesIO(image_bytes))
width, height = image.size
# Calculate estimated cost - handle resolution-based pricing
estimated_cost = model_info.get("cost", 0.02)
if api_params.get("uses_resolution", False):
# Check if 8K was requested
resolution = extra_params.get("resolution", "4k")
if resolution == "8k" and "cost_8k" in model_info:
estimated_cost = model_info["cost_8k"]
logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
len(image_bytes), width, height)
# REUSE: Same result format as generation
return ImageGenerationResult(
image_bytes=image_bytes,
width=width,
height=height,
provider="wavespeed",
model=model,
seed=options.seed,
metadata={
"provider": "wavespeed",
"model": model,
"model_name": model_info.get("name", model),
"operation": options.operation,
"prompt": options.prompt,
"negative_prompt": options.negative_prompt,
"estimated_cost": estimated_cost,
"tier": model_info.get("tier", "mid"),
}
)
except Exception as e:
logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
raise RuntimeError(f"WaveSpeed edit failed: {str(e)}")
def _call_wavespeed_edit_api(
self,
model_path: str,
image_base64: str,
prompt: str,
operation: str,
mask_base64: Optional[str] = None,
negative_prompt: Optional[str] = None,
width: Optional[int] = None,
height: Optional[int] = None,
guidance_scale: Optional[float] = None,
steps: Optional[int] = None,
seed: Optional[int] = None,
extra: Optional[dict] = None
) -> bytes:
"""Call WaveSpeed API for image editing.
REUSES: Same pattern as ImageGenerator.generate_image()
Args:
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
image_base64: Base64-encoded input image
prompt: Edit instruction prompt
operation: Type of operation
mask_base64: Optional mask for inpainting
negative_prompt: Optional negative prompt
width: Optional target width
height: Optional target height
guidance_scale: Optional guidance scale (not used by all models)
steps: Optional number of steps (not used by all models)
seed: Optional seed
extra: Optional extra parameters
Returns:
Edited image bytes
Raises:
RuntimeError: If API call fails
"""
import requests
from fastapi import HTTPException
# Build URL - REUSES same pattern as ImageGenerator
url = f"{self.client.BASE_URL}/{model_path}"
# Prepare images array - WaveSpeed expects array of image strings
# Format: base64 strings or data URIs (data:image/png;base64,...)
# For Qwen Image Edit Plus: supports up to 3 reference images
images = []
# Add main image - check if it's already a data URI or just base64
if image_base64.startswith("data:image"):
# Already a data URI
images.append(image_base64)
else:
# Assume it's base64, convert to data URI
# Try to detect format from base64 or default to PNG
images.append(f"data:image/png;base64,{image_base64}")
# If mask is provided, add it as second image
# Note: Some models may need mask in different format - will adjust per model
if mask_base64:
if mask_base64.startswith("data:image"):
images.append(mask_base64)
else:
images.append(f"data:image/png;base64,{mask_base64}")
# Get model info to determine API parameter structure
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
if not model_info:
# Fallback: try to find model by matching path
for model_id, info in self.SUPPORTED_MODELS.items():
if info["model_path"] == model_path:
model_info = info
break
if not model_info:
raise ValueError(f"Model info not found for: {model_path}")
api_params = model_info.get("api_params", {})
# Build payload - following WaveSpeed API structure
# Note: output_format default varies by model (PNG for most, but can be JPEG)
default_output_format = api_params.get("default_output_format", "png")
# Some models use "image" (singular) instead of "images" (array)
uses_image_singular = api_params.get("uses_image_singular", False)
payload = {
"prompt": prompt,
"enable_sync_mode": True, # Use sync mode for immediate results
"enable_base64_output": False, # Get URL, then download
"output_format": default_output_format,
}
# Add image(s) based on model API format
if uses_image_singular:
# Models like Qwen Edit (basic) use "image" (singular)
# Use first image only (single image editing)
if images:
payload["image"] = images[0]
else:
raise ValueError("At least one image is required")
else:
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
payload["images"] = images
# Allow override of output_format from extra params
if extra and "output_format" in extra:
payload["output_format"] = extra["output_format"]
# Model-specific parameter handling
if api_params.get("uses_size", True):
# Models like Qwen Edit Plus use "size" parameter (width*height format)
if width and height:
payload["size"] = f"{width}*{height}"
elif width:
payload["size"] = f"{width}*{width}" # Square if only width provided
elif height:
payload["size"] = f"{height}*{height}" # Square if only height provided
if api_params.get("uses_aspect_ratio", False):
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
if width and height:
# Calculate aspect ratio from dimensions
aspect_ratio = self._calculate_aspect_ratio(width, height)
if aspect_ratio:
payload["aspect_ratio"] = aspect_ratio
elif extra and "aspect_ratio" in extra:
payload["aspect_ratio"] = extra["aspect_ratio"]
if api_params.get("uses_resolution", False):
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
if extra and "resolution" in extra:
payload["resolution"] = extra["resolution"]
else:
# Default to 4K, or 8K if dimensions suggest high-res
if width and height and (width >= 4096 or height >= 4096):
payload["resolution"] = "8k"
else:
payload["resolution"] = "4k" # Default to 4K per API docs
# Add optional parameters (model-agnostic)
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
if api_params.get("supports_guidance_scale", False):
default_guidance = api_params.get("default_guidance_scale", 3.5)
if guidance_scale is not None:
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
payload["guidance_scale"] = max(1, min(20, guidance_scale))
elif extra and "guidance_scale" in extra:
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
else:
payload["guidance_scale"] = default_guidance
# Seed parameter: Only add if model supports it
if api_params.get("supports_seed", True): # Default to True for backward compatibility
if seed is not None:
payload["seed"] = seed
else:
payload["seed"] = -1 # Random seed (per API docs default)
# Add any extra parameters
if extra:
# Filter out parameters we've already handled
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
for key, value in extra.items():
if key not in handled_params:
payload[key] = value
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
# Make API call - REUSES same pattern as ImageGenerator
try:
response = requests.post(
url,
headers=self.client._headers(),
json=payload,
timeout=120
)
if response.status_code != 200:
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed image editing failed",
"status_code": response.status_code,
"response": response.text[:500],
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Check status
status = data.get("status", "").lower()
outputs = data.get("outputs") or []
prediction_id = data.get("id")
logger.debug(
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
f"prediction_id={prediction_id}"
)
# Handle sync mode - result should be directly in outputs
if outputs and status == "completed":
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
image_url = self._extract_image_url(outputs)
return self._download_image(image_url, timeout=120)
# Sync mode returned "created" or "processing" - need to poll
if not prediction_id:
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
raise HTTPException(
status_code=502,
detail="WaveSpeed sync mode returned async response without prediction ID",
)
logger.info(
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
f"Polling for result (prediction_id: {prediction_id})"
)
# Poll for result - REUSES polling utility
result = self.client.poll_until_complete(
prediction_id,
timeout_seconds=180,
interval_seconds=2.0,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned no outputs after polling"
)
# Extract image URL from outputs - REUSE helper method
image_url = self._extract_image_url(outputs)
return self._download_image(image_url, timeout=120)
except HTTPException:
raise
except Exception as e:
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
def _extract_image_url(self, outputs: list) -> str:
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
Args:
outputs: Array of output URLs or objects
Returns:
Image URL string
Raises:
HTTPException: If output format is invalid
"""
if not isinstance(outputs, list) or len(outputs) == 0:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned no outputs",
)
first_output = outputs[0]
if isinstance(first_output, str):
image_url = first_output
elif isinstance(first_output, dict):
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
else:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit output format not recognized",
)
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned invalid image URL",
)
return image_url
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
"""Download image from URL - REUSES same pattern as ImageGenerator.
Args:
image_url: URL to download from
timeout: Request timeout in seconds
Returns:
Image bytes
Raises:
HTTPException: If download fails
"""
logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
image_response = requests.get(image_url, timeout=timeout)
if image_response.status_code != 200:
logger.error(f"[WaveSpeed Edit] Failed to download image: {image_response.status_code}")
raise HTTPException(
status_code=502,
detail=f"Failed to download edited image: {image_response.status_code}"
)
logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(image_response.content)} bytes)")
return image_response.content
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
"""Calculate aspect ratio string from dimensions.
Args:
width: Image width
height: Image height
Returns:
Aspect ratio string (e.g., "16:9") or None if not standard
"""
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
ratios = {
(1, 1): "1:1",
(3, 2): "3:2",
(2, 3): "2:3",
(3, 4): "3:4",
(4, 3): "4:3",
(4, 5): "4:5",
(5, 4): "5:4",
(9, 16): "9:16",
(16, 9): "16:9",
(21, 9): "21:9",
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
}
# Calculate GCD to simplify ratio
def gcd(a, b):
while b:
a, b = b, a % b
return a
divisor = gcd(width, height)
simplified = (width // divisor, height // divisor)
# Check if it matches a standard ratio (with some tolerance)
for (w, h), ratio_str in ratios.items():
# Allow small tolerance for rounding
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
return ratio_str
# If no match, return None (model may not support custom aspect ratios)
return None
@classmethod
def get_available_models(cls) -> dict:
"""Get available editing models and their information.
Returns:
Dictionary of available models
"""
return cls.SUPPORTED_MODELS
@classmethod
def get_models_by_tier(cls, tier: str) -> dict:
"""Get models filtered by tier (budget, mid, premium).
Args:
tier: Tier name ("budget", "mid", "premium")
Returns:
Dictionary of models in the specified tier
"""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if model_info.get("tier") == tier
}
@classmethod
def get_models_by_operation(cls, operation: str) -> dict:
"""Get models that support a specific operation.
Args:
operation: Operation type (e.g., "inpaint", "outpaint", "general_edit")
Returns:
Dictionary of models supporting the operation
"""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if operation in model_info.get("capabilities", [])
}

View File

@@ -0,0 +1,367 @@
"""WaveSpeed Face Swap Provider for Image Studio."""
from __future__ import annotations
import base64
import io
from typing import Optional, Dict, Any
from PIL import Image
from services.llm_providers.image_generation.base import (
FaceSwapOptions,
FaceSwapProvider,
ImageGenerationResult,
)
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("llm_providers.wavespeed_face_swap")
class WaveSpeedFaceSwapProvider:
"""WaveSpeed provider for face swap operations."""
SUPPORTED_MODELS = {
"image-face-swap-pro": {
"model_path": "wavespeed-ai/image-face-swap-pro",
"name": "Image Face Swap Pro",
"description": "Instant online AI face swap for photos with no watermark, delivering realistic, shareable results in seconds.",
"cost": 0.025,
"tier": "mid",
"capabilities": ["face_swap", "realistic_blending"],
"features": ["Enhanced blending", "Realistic results", "Watermark-free"],
"max_faces": 1,
"api_params": {
"output_format": "jpeg",
"supports_base64": True,
"supports_sync": True,
},
},
"image-head-swap": {
"model_path": "wavespeed-ai/image-head-swap",
"name": "Image Head Swap",
"description": "Instant online AI head & face swap for photos with no watermark. Replaces entire head (face + hair + outline) while preserving body, pose and background.",
"cost": 0.025,
"tier": "mid",
"capabilities": ["head_swap", "full_head_replacement", "realistic_blending"],
"features": ["Full head replacement", "Hair included", "Pose preservation", "Watermark-free"],
"max_faces": 1,
"api_params": {
"output_format": "jpeg",
"supports_base64": True,
"supports_sync": True,
},
},
"akool-face-swap": {
"model_path": "akool/image-face-swap",
"name": "Akool Image Face Swap",
"description": "Powerful AI-powered face swapping with multi-face replacement for group photos. Seamlessly replaces faces with natural lighting and skin tone matching.",
"cost": 0.16,
"tier": "premium",
"capabilities": ["face_swap", "multi_face", "realistic_blending", "face_enhancement"],
"features": ["Multi-face swapping (up to 5)", "Face enhancement", "Group photos", "High-quality blending"],
"max_faces": 5, # Supports 1-5 faces
"api_params": {
"uses_source_target_arrays": True, # Uses source_image and target_image arrays
"supports_face_enhance": True,
"supports_base64": True,
"supports_sync": False, # May need polling
},
},
"infinite-you": {
"model_path": "wavespeed-ai/infinite-you",
"name": "InfiniteYou",
"description": "High-quality face swapping powered by ByteDance's zero-shot identity preservation technology. Maintains facial identity characteristics with exceptional realism.",
"cost": 0.03,
"tier": "mid",
"capabilities": ["face_swap", "identity_preservation", "realistic_blending"],
"features": ["Zero-shot learning", "Identity preservation", "High-quality results", "Fast processing"],
"max_faces": 1,
"api_params": {
"uses_source_target_names": True, # Uses source_image and target_image (not image/face_image)
"target_is_base": True, # target_image is the base image (where face will be swapped)
"source_is_face": True, # source_image is the face to swap in
"supports_seed": True, # Supports seed parameter
"supports_base64": True,
"supports_sync": True,
},
},
# Placeholder for additional models (will be added as docs are provided)
# "image-face-swap": {...}, # Basic version ($0.01)
}
def __init__(self):
self.client = WaveSpeedClient()
def _validate_options(self, options: FaceSwapOptions) -> None:
"""Validate face swap options."""
if not options.base_image_base64:
raise ValueError("base_image_base64 is required")
if not options.face_image_base64:
raise ValueError("face_image_base64 is required")
# Validate model
if options.model and options.model not in self.SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model: {options.model}. "
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
)
def _extract_image_url(self, data_url: str) -> str:
"""Extract image URL from data URL or return as-is if already a URL."""
if data_url.startswith("data:image"):
# It's a data URL, we'll need to upload it
return data_url
return data_url
def _upload_image_if_needed(self, image_data: str) -> str:
"""Upload image if it's a base64 data URL, otherwise return URL."""
if image_data.startswith("data:image"):
# Extract base64 data
header, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
# Upload to temporary storage (or use WaveSpeed upload endpoint if available)
# For now, we'll return the data URL and let the API handle it
# In production, you might want to upload to S3/CloudFlare first
return image_data
return image_data
def _call_wavespeed_face_swap_api(
self, options: FaceSwapOptions, model_info: Dict[str, Any]
) -> ImageGenerationResult:
"""Call WaveSpeed face swap API."""
import requests
from fastapi import HTTPException
model_path = model_info["model_path"]
api_params = model_info.get("api_params", {})
uses_source_target_arrays = api_params.get("uses_source_target_arrays", False)
# Prepare images - extract base64 if data URI
base_image = options.base_image_base64
if base_image.startswith("data:image"):
# Keep as data URI - API should accept it
pass
elif not base_image.startswith("http"):
# Assume it's base64, convert to data URI
base_image = f"data:image/png;base64,{base_image}"
face_image = options.face_image_base64
if face_image.startswith("data:image"):
# Keep as data URI
pass
elif not face_image.startswith("http"):
# Assume it's base64, convert to data URI
face_image = f"data:image/png;base64,{face_image}"
# Build API payload - handle different API formats
uses_source_target_names = api_params.get("uses_source_target_names", False)
if uses_source_target_arrays:
# Akool format: uses source_image and target_image as arrays
# For single face swap: source_image is the new face, target_image is reference from main image
# Since we only have one face_image, we'll use it as source and the base_image as target reference
payload = {
"image": base_image,
"source_image": [face_image], # Array of source faces (1-5) - the new face to swap in
"target_image": [base_image], # Array of target faces (1-5) - reference from main image
"face_enhance": api_params.get("supports_face_enhance", True), # Default to True for Akool
"enable_base64_output": True,
}
# Allow override from extra params
if options.extra:
if "source_image" in options.extra:
payload["source_image"] = options.extra["source_image"]
if "target_image" in options.extra:
payload["target_image"] = options.extra["target_image"]
if "face_enhance" in options.extra:
payload["face_enhance"] = options.extra["face_enhance"]
elif uses_source_target_names:
# InfiniteYou format: uses source_image and target_image (single values, different names)
# target_image = base image (where face will be swapped)
# source_image = face image (face to swap in)
payload = {
"target_image": base_image, # Base image where face will be swapped
"source_image": face_image, # Face to swap in
"enable_base64_output": True,
}
# Add seed if supported
if api_params.get("supports_seed", False):
seed = options.extra.get("seed") if options.extra else None
payload["seed"] = seed if seed is not None else -1 # Default to -1 (random)
# Allow override from extra params
if options.extra:
if "source_image" in options.extra:
payload["source_image"] = options.extra["source_image"]
if "target_image" in options.extra:
payload["target_image"] = options.extra["target_image"]
if "seed" in options.extra and api_params.get("supports_seed", False):
payload["seed"] = options.extra["seed"]
else:
# Standard format: uses image and face_image (single values)
payload = {
"image": base_image,
"face_image": face_image,
"output_format": api_params.get("output_format", "jpeg"),
"enable_base64_output": True, # Always get base64 for our use case
"enable_sync_mode": True, # Use sync mode for immediate results
}
# Add any extra parameters (filter out already handled ones)
if options.extra:
handled_keys = {"source_image", "target_image", "face_enhance", "output_format", "enable_sync_mode", "seed"}
for key, value in options.extra.items():
if key not in handled_keys:
payload[key] = value
url = f"{self.client.BASE_URL}/{model_path}"
headers = self.client._headers()
logger.info(f"[Face Swap] Calling WaveSpeed API: {url}")
logger.debug(f"[Face Swap] Payload keys: {list(payload.keys())}")
try:
# Call API
response = requests.post(url, headers=headers, json=payload, timeout=120)
if response.status_code != 200:
logger.error(f"[Face Swap] API call failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed face swap failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Check status - Akool uses different status values
status = data.get("status", "").lower()
# Akool uses "output" (singular), others use "outputs" (plural)
outputs = data.get("outputs") or data.get("output") or []
# Normalize to list if it's a single value
if not isinstance(outputs, list):
outputs = [outputs] if outputs else []
prediction_id = data.get("id")
# Handle completed status - Akool uses "succeeded", others use "completed"
is_completed = status in ["completed", "succeeded"]
# Handle sync mode - result should be directly in outputs
if outputs and is_completed:
logger.info(f"[Face Swap] Got immediate results (status: {status})")
# Extract image URL or base64
output = outputs[0]
if output.startswith("data:image") or output.startswith("http"):
if output.startswith("http"):
# Download from URL
import requests
img_response = requests.get(output, timeout=60)
img_response.raise_for_status()
image_bytes = img_response.content
else:
# Extract base64 from data URI
image_bytes = base64.b64decode(output.split(",", 1)[1])
else:
# Assume it's base64 string
image_bytes = base64.b64decode(output)
elif prediction_id:
# Need to poll
logger.info(f"[Face Swap] Polling for result (prediction_id: {prediction_id}, status: {status})")
result = self.client.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=1.0)
# Check both outputs and output fields
outputs = result.get("outputs") or result.get("output") or []
if not isinstance(outputs, list):
outputs = [outputs] if outputs else []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed face swap returned no outputs")
output = outputs[0]
if output.startswith("http"):
import requests
img_response = requests.get(output, timeout=60)
img_response.raise_for_status()
image_bytes = img_response.content
elif output.startswith("data:image"):
image_bytes = base64.b64decode(output.split(",", 1)[1])
else:
image_bytes = base64.b64decode(output)
else:
raise HTTPException(status_code=502, detail="WaveSpeed face swap response missing outputs and prediction ID")
# Get image dimensions
img = Image.open(io.BytesIO(image_bytes))
width, height = img.size
logger.info(f"[Face Swap] ✅ Successfully swapped face: {len(image_bytes)} bytes, {width}x{height}")
return ImageGenerationResult(
image_bytes=image_bytes,
width=width,
height=height,
provider="wavespeed",
model=options.model or model_path,
metadata={
"model_path": model_path,
"status": status,
"created_at": data.get("created_at"),
},
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Face Swap] API call failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=502,
detail={
"error": "Face swap failed",
"message": str(e)
}
)
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
"""Swap face in image using WaveSpeed models."""
self._validate_options(options)
# Determine model
model_id = options.model
if not model_id:
# Default to first available model
model_id = list(self.SUPPORTED_MODELS.keys())[0]
logger.info(f"[Face Swap] No model specified, using default: {model_id}")
model_info = self.SUPPORTED_MODELS[model_id]
# Call API
return self._call_wavespeed_face_swap_api(options, model_info)
@classmethod
def get_available_models(cls) -> dict:
"""Get available face swap models and their information."""
return cls.SUPPORTED_MODELS
@classmethod
def get_models_by_tier(cls, tier: str) -> dict:
"""Get models filtered by tier (budget, mid, premium)."""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if model_info.get("tier") == tier
}
@classmethod
def get_models_by_capability(cls, capability: str) -> dict:
"""Get models that support a specific capability."""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if capability in model_info.get("capabilities", [])
}

View File

@@ -8,11 +8,16 @@ from typing import Optional, Dict, Any
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
ImageEditOptions,
ImageEditProvider,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
from utils.logger_utils import get_service_logger
@@ -47,6 +52,249 @@ def _get_provider(provider_name: str):
raise ValueError(f"Unknown image provider: {provider_name}")
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
"""Get face swap provider by name."""
if provider_name == "wavespeed":
return WaveSpeedFaceSwapProvider()
raise ValueError(f"Unknown face swap provider: {provider_name}")
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
"""Get editing provider instance.
Args:
provider_name: Provider name ("wavespeed", "stability", etc.)
Returns:
ImageEditProvider instance
Raises:
ValueError: If provider is not supported
"""
if provider_name == "wavespeed":
return WaveSpeedEditProvider()
# TODO: Add Stability edit provider if needed
# elif provider_name == "stability":
# return StabilityEditProvider()
else:
raise ValueError(f"Unknown edit provider: {provider_name}")
def _validate_image_operation(
user_id: Optional[str],
operation_type: str = "image-generation",
num_operations: int = 1,
log_prefix: str = "[Image Generation]"
) -> None:
"""
Reusable pre-flight validation helper for all image operations.
Extracted from generate_image() to be reused across all image operation functions.
Args:
user_id: User ID for subscription checking
operation_type: Type of operation (for logging)
num_operations: Number of operations to validate (default: 1)
log_prefix: Logging prefix for operation-specific logs
Raises:
HTTPException: If validation fails (subscription limits exceeded, etc.)
"""
if not user_id:
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
return
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=num_operations
)
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
def _track_image_operation_usage(
user_id: str,
provider: str,
model: str,
operation_type: str,
result_bytes: bytes,
cost: float,
prompt: Optional[str] = None,
endpoint: str = "/image-generation",
metadata: Optional[Dict[str, Any]] = None,
log_prefix: str = "[Image Generation]"
) -> Dict[str, Any]:
"""
Reusable usage tracking helper for all image operations.
Extracted from generate_image() to be reused across all image operation functions.
Args:
user_id: User ID for tracking
provider: Provider name (e.g., "wavespeed", "stability")
model: Model name used
operation_type: Type of operation (for logging)
result_bytes: Generated/processed image bytes
cost: Cost of the operation
prompt: Optional prompt text (for request size calculation)
endpoint: API endpoint path (for logging)
metadata: Optional additional metadata
log_prefix: Logging prefix for operation-specific logs
Returns:
Dictionary with tracking information (current_calls, cost, etc.)
"""
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get current values before update
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
# Update image calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Determine API provider based on actual provider
api_provider = APIProvider.STABILITY # Default for image generation
# Create usage log
request_size = len(prompt.encode("utf-8")) if prompt else 0
usage_log = APIUsageLog(
user_id=user_id,
provider=api_provider,
endpoint=endpoint,
method="POST",
model_used=model or "unknown",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost,
response_time=0.0,
status_code=200,
request_size=request_size,
response_size=len(result_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
operation_name = operation_type.replace("-", " ").title()
print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
return {
"current_calls": new_calls,
"cost": cost,
"total_cost": new_cost,
}
except Exception as track_error:
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
db_track.rollback()
return {}
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
return {}
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
"""Generate image with pre-flight validation.
@@ -55,32 +303,13 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
options: Image generation options (provider, model, width, height, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# PRE-FLIGHT VALIDATION: Reuse extracted helper
_validate_image_operation(
user_id=user_id,
operation_type="image-generation",
num_operations=1,
log_prefix="[Image Generation]"
)
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
@@ -114,151 +343,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
provider = _get_provider(provider_name)
result = provider.generate(image_options)
# TRACK USAGE after successful API call
has_image_bytes = bool(result.image_bytes) if result else False
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
# TRACK USAGE after successful API call - Reuse extracted helper
if user_id and result and result.image_bytes:
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get cost from result metadata or calculate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
# Calculate cost from result metadata or estimate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
if result.model and "qwen" in result.model.lower():
estimated_cost = 0.05
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
if result.model and "qwen" in result.model.lower():
estimated_cost = 0.05
else:
estimated_cost = 0.10 # ideogram-v3-turbo default
elif provider_name == "stability":
estimated_cost = 0.04
else:
estimated_cost = 0.05 # Default estimate
# Get current values before update
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
# Update image calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Determine API provider based on actual provider
api_provider = APIProvider.STABILITY # Default for image generation
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=api_provider,
endpoint="/image-generation",
method="POST",
model_used=result.model or "unknown",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(result.image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider_name}
├─ Actual Provider: {provider_name}
├─ Model: {result.model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
estimated_cost = 0.10 # ideogram-v3-turbo default
elif provider_name == "stability":
estimated_cost = 0.04
else:
estimated_cost = 0.05 # Default estimate
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=result.model or "unknown",
operation_type="image-generation",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation",
metadata=result.metadata,
log_prefix="[Image Generation]"
)
else:
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
@@ -290,32 +407,13 @@ def generate_character_image(
Returns:
bytes: Generated image bytes with consistent character
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1,
)
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# PRE-FLIGHT VALIDATION: Reuse extracted helper
_validate_image_operation(
user_id=user_id,
operation_type="character-image-generation",
num_operations=1,
log_prefix="[Character Image Generation]"
)
# Generate character image via WaveSpeed
from services.wavespeed.client import WaveSpeedClient
@@ -332,132 +430,26 @@ def generate_character_image(
timeout=timeout,
)
# TRACK USAGE after successful API call
has_image_bytes = bool(image_bytes) if image_bytes else False
image_bytes_len = len(image_bytes) if image_bytes else 0
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
# TRACK USAGE after successful API call - Reuse extracted helper
if user_id and image_bytes:
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Character image cost (same as ideogram-v3-turbo)
estimated_cost = 0.10
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
endpoint="/image-generation/character",
method="POST",
model_used="ideogram-character",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
# UNIFIED SUBSCRIPTION LOG
print(f"""
[SUBSCRIPTION] Image Generation (Character)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: ideogram-character
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
except Exception as track_error:
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
# Character image cost (same as ideogram-v3-turbo)
estimated_cost = 0.10
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model="ideogram-character",
operation_type="character-image-generation",
result_bytes=image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation/character",
metadata=None,
log_prefix="[Character Image Generation]"
)
else:
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
@@ -476,3 +468,210 @@ def generate_character_image(
)
def generate_image_edit(
image_base64: str,
prompt: str,
operation: str = "general_edit",
model: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
) -> ImageGenerationResult:
"""
Generate edited image - REUSES validation and tracking helpers.
Args:
image_base64: Base64-encoded input image (or data URI)
prompt: Edit instruction prompt
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
model: Model ID to use (default: auto-select based on provider)
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
user_id: User ID for validation and tracking
Returns:
ImageGenerationResult with edited image
Raises:
HTTPException: If validation fails or editing fails
ValueError: If options are invalid
"""
# 1. REUSE: Validation helper
_validate_image_operation(
user_id=user_id,
operation_type="image-edit",
num_operations=1,
log_prefix="[Image Edit]"
)
# 2. Determine provider from model or default to wavespeed
opts = options or {}
provider_name = opts.get("provider", "wavespeed")
# If model is specified and starts with "wavespeed", use wavespeed provider
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
provider_name = "wavespeed"
# 3. Get provider (REUSES provider pattern)
try:
provider = _get_edit_provider(provider_name)
except ValueError as e:
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
raise ValueError(f"Unsupported edit provider: {provider_name}")
# 4. Prepare edit options
edit_options = ImageEditOptions(
image_base64=image_base64,
prompt=prompt,
operation=operation,
mask_base64=opts.get("mask_base64"),
negative_prompt=opts.get("negative_prompt"),
model=model,
width=opts.get("width"),
height=opts.get("height"),
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
extra=opts.get("extra"),
)
# 5. Edit image
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
try:
result = provider.edit(edit_options)
except Exception as e:
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=502,
detail={
"error": "Image editing failed",
"message": str(e)
}
)
def generate_face_swap(
base_image_base64: str,
face_image_base64: str,
model: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
) -> ImageGenerationResult:
"""
Generate face swap - REUSES validation and tracking helpers.
Args:
base_image_base64: Base64-encoded base image (or data URI)
face_image_base64: Base64-encoded face image to swap (or data URI)
model: Model ID to use (default: auto-select)
options: Additional options (target_face_index, target_gender, etc.)
user_id: User ID for validation and tracking
Returns:
ImageGenerationResult with swapped face image
Raises:
HTTPException: If validation fails or face swap fails
ValueError: If options are invalid
"""
# 1. REUSE: Validation helper
_validate_image_operation(
user_id=user_id,
operation_type="face-swap",
image_base64=base_image_base64, # Use base image for validation
log_prefix="[Face Swap]"
)
# 2. Get provider (default to wavespeed)
provider_name = "wavespeed"
provider = _get_face_swap_provider(provider_name)
# 3. Prepare options
face_swap_options = FaceSwapOptions(
base_image_base64=base_image_base64,
face_image_base64=face_image_base64,
model=model,
target_face_index=options.get("target_face_index") if options else None,
target_gender=options.get("target_gender") if options else None,
extra=options,
)
# 4. Swap face
try:
result = provider.swap_face(face_swap_options)
# 5. REUSE: Tracking helper
if user_id and result and result.image_bytes:
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
# Get model cost
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=model_id,
operation_type="face-swap",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=None, # Face swap doesn't use prompts
endpoint="/image-studio/face-swap/process",
metadata={
"base_image_size": len(base_image_base64),
"face_image_size": len(face_image_base64),
},
log_prefix="[Face Swap]"
)
else:
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
return result
except HTTPException:
raise
except Exception as api_error:
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
raise HTTPException(
status_code=502,
detail={
"error": "Face swap failed",
"message": str(api_error)
}
)
# 6. REUSE: Tracking helper
if user_id and result and result.image_bytes:
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
# Get cost from result metadata or estimate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
# Default WaveSpeed edit cost
estimated_cost = 0.02 # Default for most editing models
else:
estimated_cost = 0.05 # Default estimate
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=result.model or model or "unknown",
operation_type="image-edit",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation/edit",
metadata=result.metadata,
log_prefix="[Image Edit]"
)
else:
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
return result

View File

@@ -7,6 +7,9 @@ from .asset_audit import AssetAuditService
from .channel_pack import ChannelPackService
from .campaign_storage import CampaignStorageService
from .product_image_service import ProductImageService
from .product_animation_service import ProductAnimationService, ProductAnimationRequest
from .product_video_service import ProductVideoService, ProductVideoRequest
from .product_avatar_service import ProductAvatarService, ProductAvatarRequest
__all__ = [
"ProductMarketingOrchestrator",
@@ -16,5 +19,11 @@ __all__ = [
"ChannelPackService",
"CampaignStorageService",
"ProductImageService",
"ProductAnimationService",
"ProductAnimationRequest",
"ProductVideoService",
"ProductVideoRequest",
"ProductAvatarService",
"ProductAvatarRequest",
]

View File

@@ -163,6 +163,7 @@ class ProductMarketingOrchestrator:
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"channel": asset_node.channel,
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
"proposed_prompt": enhanced_prompt,
"recommended_template": recommended_template.get('id') if recommended_template else None,
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
@@ -170,6 +171,67 @@ class ProductMarketingOrchestrator:
"concept_summary": self._generate_concept_summary(enhanced_prompt),
}
elif asset_node.asset_type == "video":
# Video asset proposals - determine if animation (image-to-video) or demo (text-to-video)
# Default to animation if we have product image, otherwise demo
video_subtype = asset_proposal.get('video_subtype', 'animation') if 'asset_proposal' in locals() else 'demo'
# For demo videos (text-to-video), we need product description
if video_subtype == "demo" or not product_context or not product_context.get('product_image_base64'):
# Text-to-video demo video
video_type = "demo" # Default, can be customized
if asset_node.channel in ["tiktok", "instagram"]:
video_type = "storytelling" # Storytelling for social media
elif asset_node.channel in ["linkedin", "youtube"]:
video_type = "feature_highlight" # Feature highlights for professional
# Estimate cost for text-to-video (WAN 2.5: $0.05-$0.15/second)
duration = 10 # Default 10s for demo videos
resolution = "720p" # Default
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
cost_estimate = duration * cost_per_second
proposals[asset_node.asset_id] = {
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"video_subtype": "demo", # Text-to-video
"channel": asset_node.channel,
"campaign_id": blueprint.campaign_id,
"video_type": video_type,
"duration": duration,
"resolution": resolution,
"cost_estimate": cost_estimate,
"concept_summary": f"Product {video_type} video optimized for {asset_node.channel}",
"note": "Text-to-video demo - requires product description",
}
else:
# Image-to-video animation
animation_type = "reveal" # Default
if asset_node.channel in ["tiktok", "instagram", "youtube"]:
animation_type = "demo" # Demo animations for social media
elif asset_node.channel in ["linkedin", "facebook"]:
animation_type = "reveal" # Professional reveal for B2B
# Estimate cost for image-to-video (WAN 2.5: $0.05-$0.15/second)
duration = 5 # Default 5s for animations
resolution = "720p" # Default
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
cost_estimate = duration * cost_per_second
proposals[asset_node.asset_id] = {
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"video_subtype": "animation", # Image-to-video
"channel": asset_node.channel,
"campaign_id": blueprint.campaign_id,
"animation_type": animation_type,
"duration": duration,
"resolution": resolution,
"cost_estimate": cost_estimate,
"concept_summary": f"Product {animation_type} animation optimized for {asset_node.channel}",
"note": "Requires product image - will be provided during generation",
}
elif asset_node.asset_type == "text":
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
@@ -184,6 +246,7 @@ class ProductMarketingOrchestrator:
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"channel": asset_node.channel,
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
"proposed_prompt": enhanced_prompt,
"cost_estimate": 0.0, # Text generation cost is minimal
"concept_summary": "Marketing copy optimized for channel and persona",
@@ -242,6 +305,124 @@ class ProductMarketingOrchestrator:
],
}
elif asset_type == "video":
# Check video subtype: "animation" (image-to-video) or "demo" (text-to-video)
video_subtype = asset_proposal.get('video_subtype', 'animation')
if video_subtype == "demo":
# Text-to-video: Product demo video from description
from .product_video_service import ProductVideoService, ProductVideoRequest
# Get product info from context
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
product_description = product_context.get('product_description', '') if product_context else ''
if not product_description:
raise ValueError("Product description required for text-to-video demo generation")
# Get brand context
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
# Get video type from proposal or default
video_type = asset_proposal.get('video_type', 'demo')
# Create video service
video_service = ProductVideoService()
# Create video request
video_request = ProductVideoRequest(
product_name=product_name,
product_description=product_description,
video_type=video_type,
resolution=asset_proposal.get('resolution', '720p'),
duration=asset_proposal.get('duration', 10),
audio_base64=asset_proposal.get('audio_base64'),
brand_context=brand_context,
additional_context=asset_proposal.get('additional_context'),
)
# Generate video using unified ai_video_generate()
result = await video_service.generate_product_video(video_request, user_id)
# Extract campaign_id for metadata
campaign_id = asset_proposal.get('campaign_id')
asset_id = asset_proposal.get('asset_id', '')
return {
"success": True,
"asset_type": "video",
"video_subtype": "demo",
"video_url": result.get('file_url'),
"video_filename": result.get('filename'),
"cost": result.get('cost', 0.0),
"video_type": video_type,
"campaign_id": campaign_id,
"asset_id": asset_id,
}
else:
# Image-to-video: Product animation
from .product_animation_service import ProductAnimationService, ProductAnimationRequest
# Get product image from proposal or product context
product_image_base64 = asset_proposal.get('product_image_base64')
if not product_image_base64 and product_context:
product_image_base64 = product_context.get('product_image_base64')
if not product_image_base64:
raise ValueError("Product image required for image-to-video animation generation")
# Get animation type from proposal or default to "reveal"
animation_type = asset_proposal.get('animation_type', 'reveal')
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
product_description = product_context.get('product_description') if product_context else None
# Get brand context
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
# Create animation service
animation_service = ProductAnimationService()
# Create animation request
animation_request = ProductAnimationRequest(
product_image_base64=product_image_base64,
animation_type=animation_type,
product_name=product_name,
product_description=product_description,
resolution=asset_proposal.get('resolution', '720p'),
duration=asset_proposal.get('duration', 5),
audio_base64=asset_proposal.get('audio_base64'),
brand_context=brand_context,
additional_context=asset_proposal.get('additional_context'),
)
# Generate video
result = await animation_service.animate_product(animation_request, user_id)
# Extract campaign_id for metadata
campaign_id = asset_proposal.get('campaign_id')
asset_id = asset_proposal.get('asset_id', '')
return {
"success": True,
"asset_type": "video",
"video_subtype": "animation",
"video_url": result.get('video_url'),
"video_filename": result.get('filename'),
"cost": result.get('cost', 0.0),
"animation_type": animation_type,
"campaign_id": campaign_id,
"asset_id": asset_id,
}
elif asset_type == "text":
# Import text generation service and tracker
import asyncio
@@ -457,6 +638,10 @@ Return only the final copy text without explanations or markdown formatting."""
if asset_type == "image":
# Premium quality image: ~5-6 credits
return 5.0
elif asset_type == "video":
# WAN 2.5 Image-to-Video: $0.05-$0.15/second
# Default: 5 seconds at 720p = $0.50
return 0.50
elif asset_type == "text":
return 0.0 # Text generation is typically included
else:

View File

@@ -0,0 +1,221 @@
"""
Product Animation Service
Handles product animation workflows using Transform Studio (WAN 2.5 Image-to-Video).
"""
from typing import Dict, Any, Optional
from loguru import logger
from dataclasses import dataclass
from services.image_studio.transform_service import TransformStudioService, TransformImageToVideoRequest
from services.image_studio.studio_manager import ImageStudioManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("product_marketing.animation")
@dataclass
class ProductAnimationRequest:
"""Request for product animation."""
product_image_base64: str
animation_type: str # "reveal", "rotation", "demo", "lifestyle"
product_name: str
product_description: Optional[str] = None
resolution: str = "720p" # 480p, 720p, 1080p
duration: int = 5 # 5 or 10 seconds
audio_base64: Optional[str] = None
brand_context: Optional[Dict[str, Any]] = None
additional_context: Optional[str] = None
class ProductAnimationService:
"""Service for product animation workflows."""
def __init__(self):
"""Initialize Product Animation Service."""
self.transform_service = TransformStudioService()
self.image_studio = ImageStudioManager()
logger.info("[Product Animation Service] Initialized")
def _build_animation_prompt(
self,
animation_type: str,
product_name: str,
product_description: Optional[str],
brand_context: Optional[Dict[str, Any]],
additional_context: Optional[str]
) -> str:
"""
Build animation prompt based on animation type and product context.
Args:
animation_type: Type of animation (reveal, rotation, demo, lifestyle)
product_name: Product name
product_description: Product description
brand_context: Brand DNA context
additional_context: Additional context
Returns:
Animation prompt
"""
base_prompt = f"{product_name}"
if product_description:
base_prompt += f": {product_description}"
# Animation-specific prompts
animation_prompts = {
"reveal": f"{base_prompt} elegantly revealing, smooth camera movement, professional product showcase, cinematic lighting",
"rotation": f"{base_prompt} slowly rotating 360 degrees, smooth rotation, professional product photography, studio lighting, clean background",
"demo": f"{base_prompt} in use, demonstrating features, dynamic movement, engaging presentation, professional product demo",
"lifestyle": f"{base_prompt} in realistic lifestyle setting, natural environment, authentic use case, relatable scenario",
}
prompt = animation_prompts.get(animation_type, base_prompt)
# Add brand context if available
if brand_context:
visual_identity = brand_context.get("visual_identity", {})
if visual_identity.get("color_palette"):
colors = ", ".join(visual_identity["color_palette"][:3]) # First 3 colors
prompt += f", {colors} color scheme"
if visual_identity.get("style_guidelines"):
style = visual_identity["style_guidelines"].get("aesthetic", "")
if style:
prompt += f", {style} style"
# Add additional context
if additional_context:
prompt += f", {additional_context}"
return prompt
async def animate_product(
self,
request: ProductAnimationRequest,
user_id: str
) -> Dict[str, Any]:
"""
Animate a product image into a video.
Args:
request: Product animation request
user_id: User ID for tracking
Returns:
Animation result with video URL and metadata
"""
try:
logger.info(
f"[Product Animation] Animating product '{request.product_name}' "
f"with type '{request.animation_type}' for user {user_id}"
)
# Build animation prompt
animation_prompt = self._build_animation_prompt(
animation_type=request.animation_type,
product_name=request.product_name,
product_description=request.product_description,
brand_context=request.brand_context,
additional_context=request.additional_context
)
# Create transform request
transform_request = TransformImageToVideoRequest(
image_base64=request.product_image_base64,
prompt=animation_prompt,
audio_base64=request.audio_base64,
resolution=request.resolution,
duration=request.duration,
enable_prompt_expansion=True, # Expand prompt for better results
)
# Generate video using Transform Studio
result = await self.transform_service.transform_image_to_video(
request=transform_request,
user_id=user_id
)
# Add product-specific metadata
result["product_name"] = request.product_name
result["animation_type"] = request.animation_type
result["source_module"] = "product_marketing"
logger.info(
f"[Product Animation] ✅ Product animation completed: "
f"cost=${result.get('cost', 0):.2f}, video_url={result.get('video_url', 'N/A')}"
)
return result
except Exception as e:
logger.error(f"[Product Animation] ❌ Error animating product: {str(e)}", exc_info=True)
raise
async def create_product_reveal(
self,
product_image_base64: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
duration: int = 5,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product reveal animation."""
request = ProductAnimationRequest(
product_image_base64=product_image_base64,
animation_type="reveal",
product_name=product_name,
product_description=product_description,
resolution=resolution,
duration=duration,
brand_context=brand_context
)
return await self.animate_product(request, user_id)
async def create_product_rotation(
self,
product_image_base64: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
duration: int = 10, # Longer for full rotation
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create 360° product rotation animation."""
request = ProductAnimationRequest(
product_image_base64=product_image_base64,
animation_type="rotation",
product_name=product_name,
product_description=product_description,
resolution=resolution,
duration=duration,
brand_context=brand_context
)
return await self.animate_product(request, user_id)
async def create_product_demo(
self,
product_image_base64: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
duration: int = 10,
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product demo video."""
request = ProductAnimationRequest(
product_image_base64=product_image_base64,
animation_type="demo",
product_name=product_name,
product_description=product_description,
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
brand_context=brand_context
)
return await self.animate_product(request, user_id)

View File

@@ -0,0 +1,380 @@
"""
Product Avatar Service
Handles product explainer video generation using InfiniteTalk (talking avatars).
"""
from typing import Dict, Any, Optional
from loguru import logger
from dataclasses import dataclass
from pathlib import Path
import uuid
import os
import base64
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from utils.logger_utils import get_service_logger
logger = get_service_logger("product_marketing.avatar")
@dataclass
class ProductAvatarRequest:
"""Request for product explainer video with talking avatar."""
avatar_image_base64: str # Product image, brand spokesperson, or brand mascot
script_text: Optional[str] = None # Text script to convert to audio
audio_base64: Optional[str] = None # Pre-generated audio (alternative to script_text)
product_name: str = "Product"
product_description: Optional[str] = None
explainer_type: str = "product_overview" # product_overview, feature_explainer, tutorial, brand_message
resolution: str = "720p" # 480p or 720p
prompt: Optional[str] = None # Optional prompt for expression/style
mask_image_base64: Optional[str] = None # Optional mask for animatable regions
seed: Optional[int] = None
brand_context: Optional[Dict[str, Any]] = None
additional_context: Optional[str] = None
class ProductAvatarService:
"""Service for product explainer video generation using InfiniteTalk."""
def __init__(self):
"""Initialize Product Avatar Service."""
self.infinitetalk_service = InfiniteTalkService()
self.audio_service = StoryAudioGenerationService()
logger.info("[Product Avatar Service] Initialized")
def _build_avatar_prompt(
self,
explainer_type: str,
product_name: str,
product_description: Optional[str],
brand_context: Optional[Dict[str, Any]],
additional_context: Optional[str]
) -> str:
"""
Build avatar prompt based on explainer type and product context.
Args:
explainer_type: Type of explainer (product_overview, feature_explainer, tutorial, brand_message)
product_name: Product name
product_description: Product description
brand_context: Brand DNA context
additional_context: Additional context
Returns:
Avatar animation prompt
"""
base_description = f"{product_name}"
if product_description:
base_description += f": {product_description}"
# Explainer type-specific prompts
explainer_prompts = {
"product_overview": (
f"Professional product presentation of {base_description}, "
f"engaging and informative, clear communication, confident expression, "
f"professional setting, modern and clean aesthetic"
),
"feature_explainer": (
f"Demonstrating features of {base_description}, "
f"detailed explanation, pointing gestures, clear visual communication, "
f"educational and informative, professional presentation"
),
"tutorial": (
f"Tutorial presentation for {base_description}, "
f"step-by-step explanation, instructional and clear, "
f"friendly and approachable, educational setting"
),
"brand_message": (
f"Brand message delivery for {base_description}, "
f"authentic and compelling, brand storytelling, "
f"emotional connection, professional brand representation"
),
}
prompt = explainer_prompts.get(explainer_type, base_description)
# Add brand context if available
if brand_context:
visual_identity = brand_context.get("visual_identity", {})
if visual_identity.get("style_guidelines"):
style = visual_identity["style_guidelines"].get("aesthetic", "")
if style:
prompt += f", {style} style"
# Add brand values if available
if visual_identity.get("brand_values"):
values = ", ".join(visual_identity["brand_values"][:2]) # First 2 values
prompt += f", embodying {values}"
# Add additional context
if additional_context:
prompt += f", {additional_context}"
return prompt
def _generate_audio_from_script(
self,
script_text: str,
user_id: str,
output_dir: Path
) -> bytes:
"""
Generate audio from script text using TTS.
Args:
script_text: Text to convert to speech
user_id: User ID for tracking
output_dir: Directory to save temporary audio file
Returns:
Audio bytes
"""
try:
# Create temporary audio file
audio_filename = f"avatar_audio_{uuid.uuid4().hex[:8]}.mp3"
audio_path = output_dir / audio_filename
# Generate audio using gTTS (free, always available)
# Note: For premium voices, we could integrate Minimax voice clone here
success = self.audio_service._generate_audio_gtts(
text=script_text,
output_path=audio_path,
lang="en",
slow=False
)
if not success:
raise RuntimeError("Failed to generate audio from script")
# Read audio bytes
with open(audio_path, 'rb') as f:
audio_bytes = f.read()
# Clean up temporary file
try:
os.remove(audio_path)
except Exception:
pass
logger.info(f"[Product Avatar] Generated audio from script: {len(audio_bytes)} bytes")
return audio_bytes
except Exception as e:
logger.error(f"[Product Avatar] Error generating audio: {str(e)}", exc_info=True)
raise
async def generate_product_explainer(
self,
request: ProductAvatarRequest,
user_id: str
) -> Dict[str, Any]:
"""
Generate product explainer video using InfiniteTalk.
Args:
request: Product avatar request
user_id: User ID for tracking
Returns:
Explainer video result with video URL and metadata
"""
try:
logger.info(
f"[Product Avatar] Generating {request.explainer_type} explainer for product '{request.product_name}' "
f"for user {user_id}"
)
# Prepare audio
audio_base64 = request.audio_base64
if not audio_base64 and request.script_text:
# Generate audio from script
base_dir = Path(__file__).parent.parent.parent.parent
temp_dir = base_dir / "temp_audio"
temp_dir.mkdir(parents=True, exist_ok=True)
audio_bytes = self._generate_audio_from_script(
script_text=request.script_text,
user_id=user_id,
output_dir=temp_dir
)
# Convert to base64
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
audio_base64 = f"data:audio/mpeg;base64,{audio_base64}"
if not audio_base64:
raise ValueError("Either audio_base64 or script_text must be provided")
# Build avatar prompt
avatar_prompt = request.prompt
if not avatar_prompt:
avatar_prompt = self._build_avatar_prompt(
explainer_type=request.explainer_type,
product_name=request.product_name,
product_description=request.product_description,
brand_context=request.brand_context,
additional_context=request.additional_context
)
# Generate video using InfiniteTalk
result = await self.infinitetalk_service.create_talking_avatar(
image_base64=request.avatar_image_base64,
audio_base64=audio_base64,
resolution=request.resolution,
prompt=avatar_prompt,
mask_image_base64=request.mask_image_base64,
seed=request.seed,
user_id=user_id,
)
# Extract video bytes and save to user directory
video_bytes = result.get("video_bytes")
if not video_bytes:
raise ValueError("Avatar generation returned no video bytes")
# Save video file
base_dir = Path(__file__).parent.parent.parent.parent
output_dir = base_dir / "product_avatars"
output_dir.mkdir(parents=True, exist_ok=True)
# Create user-specific directory
user_dir = output_dir / user_id
user_dir.mkdir(parents=True, exist_ok=True)
# Generate filename
safe_product_name = "".join(c for c in request.product_name if c.isalnum() or c in (' ', '-', '_')).strip()[:30]
filename = f"explainer_{safe_product_name}_{request.explainer_type}_{uuid.uuid4().hex[:8]}.mp4"
filename = filename.replace(" ", "_").replace("/", "_").replace("\\", "_")
# Save file
file_path = user_dir / filename
with open(file_path, 'wb') as f:
f.write(video_bytes)
# Check file size (500MB max)
file_size = os.path.getsize(file_path)
if file_size > 500 * 1024 * 1024:
os.remove(file_path)
raise RuntimeError(f"Video file too large: {file_size / (1024*1024):.2f}MB (max 500MB)")
file_url = f"/api/product-marketing/avatars/{user_id}/{filename}"
# Add product-specific metadata
result["product_name"] = request.product_name
result["explainer_type"] = request.explainer_type
result["source_module"] = "product_marketing"
result["filename"] = filename
result["file_path"] = str(file_path)
result["file_url"] = file_url
result["file_size"] = file_size
result["duration"] = result.get("duration", 0.0)
logger.info(
f"[Product Avatar] ✅ Product explainer video generated successfully: "
f"cost=${result.get('cost', 0):.2f}, duration={result.get('duration', 0):.1f}s, "
f"video_url={file_url}"
)
return result
except Exception as e:
logger.error(f"[Product Avatar] ❌ Error generating product explainer: {str(e)}", exc_info=True)
raise
async def create_product_overview(
self,
avatar_image_base64: str,
script_text: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product overview explainer video."""
request = ProductAvatarRequest(
avatar_image_base64=avatar_image_base64,
script_text=script_text,
audio_base64=audio_base64,
product_name=product_name,
product_description=product_description,
explainer_type="product_overview",
resolution=resolution,
brand_context=brand_context
)
return await self.generate_product_explainer(request, user_id)
async def create_feature_explainer(
self,
avatar_image_base64: str,
script_text: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product feature explainer video."""
request = ProductAvatarRequest(
avatar_image_base64=avatar_image_base64,
script_text=script_text,
audio_base64=audio_base64,
product_name=product_name,
product_description=product_description,
explainer_type="feature_explainer",
resolution=resolution,
brand_context=brand_context
)
return await self.generate_product_explainer(request, user_id)
async def create_tutorial(
self,
avatar_image_base64: str,
script_text: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product tutorial video."""
request = ProductAvatarRequest(
avatar_image_base64=avatar_image_base64,
script_text=script_text,
audio_base64=audio_base64,
product_name=product_name,
product_description=product_description,
explainer_type="tutorial",
resolution=resolution,
brand_context=brand_context
)
return await self.generate_product_explainer(request, user_id)
async def create_brand_message(
self,
avatar_image_base64: str,
script_text: str,
product_name: str,
product_description: Optional[str],
user_id: str,
resolution: str = "720p",
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create brand message video."""
request = ProductAvatarRequest(
avatar_image_base64=avatar_image_base64,
script_text=script_text,
audio_base64=audio_base64,
product_name=product_name,
product_description=product_description,
explainer_type="brand_message",
resolution=resolution,
brand_context=brand_context
)
return await self.generate_product_explainer(request, user_id)

View File

@@ -0,0 +1,312 @@
"""
Product Video Service
Handles product demo video generation using WAN 2.5 Text-to-Video via main_video_generation.
"""
from typing import Dict, Any, Optional
from loguru import logger
from dataclasses import dataclass
from services.llm_providers.main_video_generation import ai_video_generate
from utils.logger_utils import get_service_logger
logger = get_service_logger("product_marketing.video")
@dataclass
class ProductVideoRequest:
"""Request for product demo video generation."""
product_name: str
product_description: str
video_type: str # "demo", "storytelling", "feature_highlight", "launch"
resolution: str = "720p" # 480p, 720p, 1080p
duration: int = 10 # 5 or 10 seconds
audio_base64: Optional[str] = None
brand_context: Optional[Dict[str, Any]] = None
additional_context: Optional[str] = None
negative_prompt: Optional[str] = None
seed: Optional[int] = None
class ProductVideoService:
"""Service for product demo video generation using WAN 2.5 Text-to-Video."""
def __init__(self):
"""Initialize Product Video Service."""
logger.info("[Product Video Service] Initialized")
def _build_video_prompt(
self,
video_type: str,
product_name: str,
product_description: str,
brand_context: Optional[Dict[str, Any]],
additional_context: Optional[str]
) -> str:
"""
Build video prompt based on video type and product context.
Args:
video_type: Type of video (demo, storytelling, feature_highlight, launch)
product_name: Product name
product_description: Product description
brand_context: Brand DNA context
additional_context: Additional context
Returns:
Video generation prompt
"""
base_description = f"{product_name}"
if product_description:
base_description += f": {product_description}"
# Video type-specific prompts
video_prompts = {
"demo": (
f"{base_description} being demonstrated in use, showcasing key features and benefits, "
f"professional product demonstration, dynamic camera movement, engaging presentation, "
f"clear product visibility, modern and clean aesthetic"
),
"storytelling": (
f"Story of {base_description}, narrative-driven product showcase, emotional connection, "
f"cinematic storytelling, compelling visual narrative, professional cinematography, "
f"engaging product story"
),
"feature_highlight": (
f"{base_description} highlighting key features, close-up shots of important details, "
f"feature-focused presentation, professional product photography, clear feature visibility, "
f"modern and sleek aesthetic"
),
"launch": (
f"{base_description} product launch reveal, exciting unveiling, dynamic presentation, "
f"professional product showcase, launch event aesthetic, engaging and energetic, "
f"modern and premium feel"
),
}
prompt = video_prompts.get(video_type, base_description)
# Add brand context if available
if brand_context:
visual_identity = brand_context.get("visual_identity", {})
if visual_identity.get("color_palette"):
colors = ", ".join(visual_identity["color_palette"][:3]) # First 3 colors
prompt += f", {colors} color scheme"
if visual_identity.get("style_guidelines"):
style = visual_identity["style_guidelines"].get("aesthetic", "")
if style:
prompt += f", {style} style"
# Add brand values if available
if visual_identity.get("brand_values"):
values = ", ".join(visual_identity["brand_values"][:2]) # First 2 values
prompt += f", embodying {values}"
# Add additional context
if additional_context:
prompt += f", {additional_context}"
return prompt
async def generate_product_video(
self,
request: ProductVideoRequest,
user_id: str
) -> Dict[str, Any]:
"""
Generate product demo video using WAN 2.5 Text-to-Video.
This method uses the unified ai_video_generate() entry point which handles:
- Pre-flight validation
- Usage tracking
- Cost tracking
- Error handling
Args:
request: Product video request
user_id: User ID for tracking
Returns:
Video generation result with video URL and metadata
"""
try:
logger.info(
f"[Product Video] Generating {request.video_type} video for product '{request.product_name}' "
f"for user {user_id}"
)
# Build video prompt
video_prompt = self._build_video_prompt(
video_type=request.video_type,
product_name=request.product_name,
product_description=request.product_description,
brand_context=request.brand_context,
additional_context=request.additional_context
)
# Build negative prompt (default to avoid common issues)
negative_prompt = request.negative_prompt or (
"blurry, low quality, distorted, deformed, ugly, bad anatomy, "
"watermark, text overlay, logo, signature"
)
# Generate video using unified entry point
# This handles pre-flight validation, usage tracking, and cost tracking automatically
result = await ai_video_generate(
prompt=video_prompt,
operation_type="text-to-video",
provider="wavespeed",
user_id=user_id,
model="alibaba/wan-2.5/text-to-video", # WAN 2.5 Text-to-Video
duration=request.duration,
resolution=request.resolution,
audio_base64=request.audio_base64,
negative_prompt=negative_prompt,
seed=request.seed,
enable_prompt_expansion=True, # Enable prompt optimization
)
# Extract video bytes and save to user directory
video_bytes = result.get("video_bytes")
if not video_bytes:
raise ValueError("Video generation returned no video bytes")
# Save video file (similar to Transform Studio)
from pathlib import Path
import uuid
import os
base_dir = Path(__file__).parent.parent.parent.parent
output_dir = base_dir / "product_videos"
output_dir.mkdir(parents=True, exist_ok=True)
# Create user-specific directory
user_dir = output_dir / user_id
user_dir.mkdir(parents=True, exist_ok=True)
# Generate filename (sanitize to avoid issues)
safe_product_name = "".join(c for c in request.product_name if c.isalnum() or c in (' ', '-', '_')).strip()[:30]
filename = f"product_{safe_product_name}_{request.video_type}_{uuid.uuid4().hex[:8]}.mp4"
filename = filename.replace(" ", "_").replace("/", "_").replace("\\", "_")
# Save file
file_path = user_dir / filename
with open(file_path, 'wb') as f:
f.write(video_bytes)
# Check file size (500MB max)
file_size = os.path.getsize(file_path)
if file_size > 500 * 1024 * 1024:
os.remove(file_path)
raise RuntimeError(f"Video file too large: {file_size / (1024*1024):.2f}MB (max 500MB)")
file_url = f"/api/product-marketing/videos/{user_id}/{filename}"
# Add product-specific metadata
result["product_name"] = request.product_name
result["video_type"] = request.video_type
result["source_module"] = "product_marketing"
result["filename"] = filename
result["file_path"] = str(file_path)
result["file_url"] = file_url
result["file_size"] = len(video_bytes)
logger.info(
f"[Product Video] ✅ Product video generated successfully: "
f"cost=${result.get('cost', 0):.2f}, video_url={file_url}"
)
return result
except Exception as e:
logger.error(f"[Product Video] ❌ Error generating product video: {str(e)}", exc_info=True)
raise
async def create_product_demo(
self,
product_name: str,
product_description: str,
user_id: str,
resolution: str = "720p",
duration: int = 10,
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product demo video (product in use, demonstrating features)."""
request = ProductVideoRequest(
product_name=product_name,
product_description=product_description,
video_type="demo",
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
brand_context=brand_context
)
return await self.generate_product_video(request, user_id)
async def create_product_storytelling(
self,
product_name: str,
product_description: str,
user_id: str,
resolution: str = "720p",
duration: int = 10,
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product storytelling video (narrative-driven product showcase)."""
request = ProductVideoRequest(
product_name=product_name,
product_description=product_description,
video_type="storytelling",
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
brand_context=brand_context
)
return await self.generate_product_video(request, user_id)
async def create_product_feature_highlight(
self,
product_name: str,
product_description: str,
user_id: str,
resolution: str = "720p",
duration: int = 10,
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product feature highlight video (close-up shots of key features)."""
request = ProductVideoRequest(
product_name=product_name,
product_description=product_description,
video_type="feature_highlight",
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
brand_context=brand_context
)
return await self.generate_product_video(request, user_id)
async def create_product_launch(
self,
product_name: str,
product_description: str,
user_id: str,
resolution: str = "1080p", # Higher quality for launch
duration: int = 10,
audio_base64: Optional[str] = None,
brand_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create product launch video (exciting unveiling, launch event aesthetic)."""
request = ProductVideoRequest(
product_name=product_name,
product_description=product_description,
video_type="launch",
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
brand_context=brand_context
)
return await self.generate_product_video(request, user_id)

View File

@@ -50,6 +50,7 @@ class IntentAwareAnalyzer:
raw_results: Dict[str, Any],
intent: ResearchIntent,
research_persona: Optional[ResearchPersona] = None,
user_id: Optional[str] = None,
) -> IntentDrivenResearchResult:
"""
Analyze raw research results based on user intent.
@@ -84,7 +85,7 @@ class IntentAwareAnalyzer:
result = llm_text_gen(
prompt=prompt,
json_struct=analysis_schema,
user_id=None
user_id=user_id # Required for subscription checking
)
if isinstance(result, dict) and "error" in result:

View File

@@ -151,6 +151,8 @@ Analyze the user's input and infer their research intent. Determine:
11. **CONFIDENCE**: How confident are you in this inference? (0.0-1.0)
- If < 0.7, set needs_clarification to true and provide clarifying_questions
- Provide a brief reason for your confidence level
- If confidence is low, provide an example of what a great input would look like
## OUTPUT FORMAT
@@ -168,6 +170,8 @@ Return a JSON object:
"perspective": "target perspective or null",
"time_sensitivity": "real_time|recent|historical|evergreen",
"confidence": 0.85,
"confidence_reason": "Brief explanation of why this confidence level (e.g., 'User provided clear keywords and context' or 'Input is vague, missing specific goals')",
"great_example": "Example of what a great input would look like for this research (only if confidence < 0.8)",
"needs_clarification": false,
"clarifying_questions": [],
"analysis_summary": "Brief summary of what the user wants"

View File

@@ -39,6 +39,7 @@ class IntentQueryGenerator:
self,
intent: ResearchIntent,
research_persona: Optional[ResearchPersona] = None,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Generate targeted research queries based on intent.
@@ -89,7 +90,7 @@ class IntentQueryGenerator:
result = llm_text_gen(
prompt=prompt,
json_struct=query_schema,
user_id=None
user_id=user_id
)
if isinstance(result, dict) and "error" in result:

View File

@@ -51,6 +51,7 @@ class ResearchIntentInference:
competitor_data: Optional[List[Dict]] = None,
industry: Optional[str] = None,
target_audience: Optional[str] = None,
user_id: Optional[str] = None,
) -> IntentInferenceResponse:
"""
Analyze user input and infer their research intent.
@@ -96,13 +97,15 @@ class ResearchIntentInference:
"perspective": {"type": "string"},
"time_sensitivity": {"type": "string"},
"confidence": {"type": "number"},
"confidence_reason": {"type": "string"},
"great_example": {"type": "string"},
"needs_clarification": {"type": "boolean"},
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
"analysis_summary": {"type": "string"}
},
"required": [
"input_type", "primary_question", "purpose", "content_output",
"expected_deliverables", "depth", "confidence", "analysis_summary"
"expected_deliverables", "depth", "confidence", "confidence_reason", "analysis_summary"
]
}
@@ -112,7 +115,7 @@ class ResearchIntentInference:
result = llm_text_gen(
prompt=prompt,
json_struct=intent_schema,
user_id=None
user_id=user_id
)
if isinstance(result, dict) and "error" in result:
@@ -134,6 +137,8 @@ class ResearchIntentInference:
suggested_keywords=self._extract_keywords_from_input(user_input, keywords),
suggested_angles=result.get("focus_areas", []),
quick_options=quick_options,
confidence_reason=result.get("confidence_reason", ""),
great_example=result.get("great_example", ""),
)
logger.info(f"Intent inferred: purpose={intent.purpose}, confidence={intent.confidence}")
@@ -166,7 +171,7 @@ class ResearchIntentInference:
if not expected_deliverables:
expected_deliverables = self._infer_deliverables_from_purpose(purpose)
return ResearchIntent(
intent = ResearchIntent(
primary_question=result.get("primary_question", user_input),
secondary_questions=result.get("secondary_questions", []),
purpose=purpose.value,
@@ -179,9 +184,13 @@ class ResearchIntentInference:
input_type=input_type.value,
original_input=user_input,
confidence=float(result.get("confidence", 0.7)),
confidence_reason=result.get("confidence_reason"),
great_example=result.get("great_example"),
needs_clarification=result.get("needs_clarification", False),
clarifying_questions=result.get("clarifying_questions", []),
)
return intent
def _safe_enum(self, enum_class, value: str, default):
"""Safely convert string to enum, returning default if invalid."""

View File

@@ -0,0 +1,559 @@
"""
Unified Research Analyzer
Combines intent inference, query generation, and parameter optimization
into a single AI call with justifications for each decision.
This reduces 2 LLM calls to 1, improves coherence, and provides
user-friendly justifications for all settings.
Author: ALwrity Team
Version: 1.0
"""
import json
from typing import Dict, Any, List, Optional, Tuple
from loguru import logger
from models.research_intent_models import (
ResearchIntent,
ResearchQuery,
IntentInferenceResponse,
ResearchPurpose,
ContentOutput,
ExpectedDeliverable,
ResearchDepthLevel,
InputType,
)
from models.research_persona_models import ResearchPersona
class UnifiedResearchAnalyzer:
"""
Unified AI-driven analyzer that performs:
1. Intent inference (what user wants)
2. Query generation (how to search)
3. Parameter optimization (Exa/Tavily settings)
All in a single LLM call with justifications.
"""
def __init__(self):
"""Initialize the unified analyzer."""
logger.info("UnifiedResearchAnalyzer initialized")
async def analyze(
self,
user_input: str,
keywords: Optional[List[str]] = None,
research_persona: Optional[ResearchPersona] = None,
competitor_data: Optional[List[Dict]] = None,
industry: Optional[str] = None,
target_audience: Optional[str] = None,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Perform unified analysis of user research request.
Returns:
Dict containing:
- intent: ResearchIntent
- queries: List[ResearchQuery]
- exa_config: Dict with settings and justifications
- tavily_config: Dict with settings and justifications
- recommended_provider: str
- provider_justification: str
"""
try:
logger.info(f"Unified analysis for: {user_input[:100]}...")
keywords = keywords or []
# Build the unified prompt
prompt = self._build_unified_prompt(
user_input=user_input,
keywords=keywords,
research_persona=research_persona,
competitor_data=competitor_data,
industry=industry,
target_audience=target_audience,
)
# Define the comprehensive JSON schema
unified_schema = self._build_unified_schema()
# Call LLM (single call for everything)
from services.llm_providers.main_text_generation import llm_text_gen
result = llm_text_gen(
prompt=prompt,
json_struct=unified_schema,
user_id=user_id
)
if isinstance(result, dict) and "error" in result:
logger.error(f"Unified analysis failed: {result.get('error')}")
return self._create_fallback_response(user_input, keywords)
# Parse the unified result
return self._parse_unified_result(result, user_input)
except Exception as e:
logger.error(f"Error in unified analysis: {e}")
return self._create_fallback_response(user_input, keywords or [])
def _build_unified_prompt(
self,
user_input: str,
keywords: List[str],
research_persona: Optional[ResearchPersona] = None,
competitor_data: Optional[List[Dict]] = None,
industry: Optional[str] = None,
target_audience: Optional[str] = None,
) -> str:
"""Build the unified prompt for intent + queries + parameters."""
# Build persona context
persona_context = self._build_persona_context(research_persona, industry, target_audience)
# Build competitor context
competitor_context = self._build_competitor_context(competitor_data)
prompt = f'''You are an expert AI research strategist. Analyze the user's research request and provide a complete research plan including intent understanding, search queries, and optimal API settings.
## USER INPUT
"{user_input}"
{f"KEYWORDS: {', '.join(keywords)}" if keywords else ""}
## USER CONTEXT
{persona_context}
{competitor_context}
## YOUR TASK: Provide a Complete Research Plan
### PART 1: INTENT ANALYSIS
Understand what the user really wants from their research.
### PART 2: SEARCH QUERIES
Generate 4-8 targeted search queries optimized for semantic search.
### PART 3: PROVIDER SETTINGS
Configure Exa and Tavily API parameters with justifications.
### PART 4: GOOGLE TRENDS KEYWORDS (if trends in deliverables)
If "trends" is in expected_deliverables OR purpose is "explore_trends":
- Suggest 1-3 optimized keywords for Google Trends analysis
- These may differ from research queries (trends need broader, searchable terms)
- Consider: What keywords will show meaningful trends over time?
- Consider: What timeframe will show relevant trends? (1 year, 12 months, etc.)
- Consider: What geographic region is most relevant for the user?
- Explain what insights trends will uncover for content generation:
* Search interest trends over time (optimal publication timing)
* Regional interest distribution (audience targeting)
* Related topics for content expansion
* Related queries for FAQ sections
* Rising topics for timely content opportunities
---
## AVAILABLE PROVIDER OPTIONS
### EXA API OPTIONS (Semantic Search Engine)
| Parameter | Options | Description |
|-----------|---------|-------------|
| type | "auto", "neural", "fast", "deep" | "neural" = semantic understanding, "deep" = comprehensive with query expansion |
| category | "company", "research paper", "news", "github", "tweet", "personal site", "pdf", "financial report", "people" | Focus on specific content types |
| numResults | 5-25 | Number of results (10 recommended) |
| includeDomains | string[] | Domains to include (e.g., ["arxiv.org", "nature.com"]) |
| excludeDomains | string[] | Domains to exclude |
| startPublishedDate | ISO date | Filter by publish date (e.g., "2024-01-01T00:00:00.000Z") |
| text | boolean | Include full text content |
| highlights | boolean | Extract key highlights |
| context | boolean | Return as single context string for RAG |
**WHEN TO USE EXA:**
- Semantic understanding needed (finding similar content)
- Academic/research papers
- Company/competitor research
- Deep, comprehensive results
- Historical content
### TAVILY API OPTIONS (AI-Powered Search)
| Parameter | Options | Description |
|-----------|---------|-------------|
| topic | "general", "news", "finance" | Search topic category |
| search_depth | "basic", "advanced" | "advanced" = multiple semantic snippets per URL |
| include_answer | false, true, "basic", "advanced" | AI-generated answer from results |
| include_raw_content | false, true, "markdown", "text" | Raw page content format |
| time_range | "day", "week", "month", "year" | Filter by recency |
| max_results | 5-20 | Number of results |
| include_domains | string[] | Domains to include |
| exclude_domains | string[] | Domains to exclude |
**WHEN TO USE TAVILY:**
- Real-time/current events
- News and trending topics
- Quick facts with AI answers
- Financial data
- Recent time-sensitive content
---
## OUTPUT FORMAT
Return a JSON object with this exact structure:
```json
{{
"intent": {{
"input_type": "keywords|question|goal|mixed",
"primary_question": "The main question to answer",
"secondary_questions": ["question 1", "question 2"],
"purpose": "learn|create_content|make_decision|compare|solve_problem|find_data|explore_trends|validate|generate_ideas",
"content_output": "blog|podcast|video|social_post|newsletter|presentation|report|whitepaper|email|general",
"expected_deliverables": ["key_statistics", "expert_quotes", "case_studies", "trends", "best_practices"],
"depth": "overview|detailed|expert",
"focus_areas": ["area1", "area2"],
"perspective": "target perspective or null",
"time_sensitivity": "real_time|recent|historical|evergreen",
"confidence": 0.85,
"confidence_reason": "Why this confidence level",
"great_example": "Example of better input if confidence < 0.8",
"needs_clarification": false,
"clarifying_questions": [],
"analysis_summary": "Brief summary of research plan"
}},
"queries": [
{{
"query": "Optimized search query string",
"purpose": "key_statistics|expert_quotes|case_studies|trends|etc",
"provider": "exa|tavily",
"priority": 5,
"expected_results": "What we expect to find",
"justification": "Why this query and provider"
}}
],
"enhanced_keywords": ["expanded", "related", "keywords"],
"research_angles": ["Angle 1: ...", "Angle 2: ..."],
"recommended_provider": "exa|tavily",
"provider_justification": "Why this provider is best for this research",
"exa_config": {{
"enabled": true,
"type": "auto|neural|fast|deep",
"type_justification": "Why this search type",
"category": "news|research paper|company|etc or null",
"category_justification": "Why this category or null",
"numResults": 10,
"numResults_justification": "Why this number",
"includeDomains": [],
"includeDomains_justification": "Why these domains or empty",
"startPublishedDate": "2024-01-01T00:00:00.000Z or null",
"date_justification": "Why this date filter or null",
"highlights": true,
"highlights_justification": "Why enable/disable highlights",
"context": true,
"context_justification": "Why enable/disable context string"
}},
"tavily_config": {{
"enabled": true,
"topic": "general|news|finance",
"topic_justification": "Why this topic",
"search_depth": "basic|advanced",
"search_depth_justification": "Why this depth",
"include_answer": "true|false|basic|advanced",
"include_answer_justification": "Why this answer mode",
"time_range": "day|week|month|year|null",
"time_range_justification": "Why this time range or null",
"max_results": 10,
"max_results_justification": "Why this number",
"include_raw_content": "false|true|markdown|text",
"include_raw_content_justification": "Why this content mode"
}},
"trends_config": {{
"enabled": true|false,
"keywords": ["keyword1", "keyword2"],
"keywords_justification": "Why these keywords for trends analysis",
"timeframe": "today 1-y|today 12-m|all",
"timeframe_justification": "Why this timeframe",
"geo": "US|GB|IN|etc",
"geo_justification": "Why this geographic region",
"expected_insights": [
"Search interest trends over the past year",
"Regional interest distribution",
"Related topics for content expansion",
"Related queries for FAQ sections",
"Optimal publication timing based on interest peaks"
]
}}
}}
```
## DECISION RULES
1. **Provider Selection:**
- Use EXA for: academic research, competitor analysis, deep understanding, finding similar content
- Use TAVILY for: news, current events, quick facts, financial data, real-time info
2. **Query Optimization:**
- Include relevant keywords for semantic matching
- Add context words based on deliverables (e.g., "statistics 2024" for key_statistics)
- Match query style to provider (natural language for Exa, keyword-rich for Tavily)
3. **Parameter Selection:**
- ALWAYS provide justification for each parameter choice
- Consider time sensitivity when setting date filters
- Match category/topic to content type
- Use "advanced" depth when quality matters more than speed
4. **Google Trends Keywords (if trends enabled):**
- Suggest 1-3 keywords optimized for trends analysis
- Keywords should be broader than research queries (e.g., "AI marketing" vs "AI marketing tools for small businesses")
- Consider what will show meaningful search interest trends
- Choose timeframe based on content type (12 months for blogs, 1 year for comprehensive)
- Select geo based on user's target audience or industry
- List specific insights trends will uncover
5. **Justifications:**
- Keep justifications concise (1 sentence)
- Explain the "why" not the "what"
- Reference user's intent when relevant
'''
return prompt
def _build_unified_schema(self) -> Dict[str, Any]:
"""Build the JSON schema for unified response."""
return {
"type": "object",
"properties": {
"intent": {
"type": "object",
"properties": {
"input_type": {"type": "string", "enum": ["keywords", "question", "goal", "mixed"]},
"primary_question": {"type": "string"},
"secondary_questions": {"type": "array", "items": {"type": "string"}},
"purpose": {"type": "string"},
"content_output": {"type": "string"},
"expected_deliverables": {"type": "array", "items": {"type": "string"}},
"depth": {"type": "string", "enum": ["overview", "detailed", "expert"]},
"focus_areas": {"type": "array", "items": {"type": "string"}},
"perspective": {"type": "string"},
"time_sensitivity": {"type": "string"},
"confidence": {"type": "number"},
"confidence_reason": {"type": "string"},
"great_example": {"type": "string"},
"needs_clarification": {"type": "boolean"},
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
"analysis_summary": {"type": "string"}
},
"required": ["primary_question", "purpose", "expected_deliverables", "confidence"]
},
"queries": {
"type": "array",
"items": {
"type": "object",
"properties": {
"query": {"type": "string"},
"purpose": {"type": "string"},
"provider": {"type": "string"},
"priority": {"type": "integer"},
"expected_results": {"type": "string"},
"justification": {"type": "string"}
},
"required": ["query", "purpose", "provider", "priority"]
}
},
"enhanced_keywords": {"type": "array", "items": {"type": "string"}},
"research_angles": {"type": "array", "items": {"type": "string"}},
"recommended_provider": {"type": "string"},
"provider_justification": {"type": "string"},
"exa_config": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"type": {"type": "string"},
"type_justification": {"type": "string"},
"category": {"type": "string"},
"category_justification": {"type": "string"},
"numResults": {"type": "integer"},
"numResults_justification": {"type": "string"},
"includeDomains": {"type": "array", "items": {"type": "string"}},
"includeDomains_justification": {"type": "string"},
"startPublishedDate": {"type": "string"},
"date_justification": {"type": "string"},
"highlights": {"type": "boolean"},
"highlights_justification": {"type": "string"},
"context": {"type": "boolean"},
"context_justification": {"type": "string"}
}
},
"tavily_config": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"topic": {"type": "string"},
"topic_justification": {"type": "string"},
"search_depth": {"type": "string"},
"search_depth_justification": {"type": "string"},
"include_answer": {"type": "string"},
"include_answer_justification": {"type": "string"},
"time_range": {"type": "string"},
"time_range_justification": {"type": "string"},
"max_results": {"type": "integer"},
"max_results_justification": {"type": "string"},
"include_raw_content": {"type": "string"},
"include_raw_content_justification": {"type": "string"}
}
},
"trends_config": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"keywords": {"type": "array", "items": {"type": "string"}},
"keywords_justification": {"type": "string"},
"timeframe": {"type": "string"},
"timeframe_justification": {"type": "string"},
"geo": {"type": "string"},
"geo_justification": {"type": "string"},
"expected_insights": {"type": "array", "items": {"type": "string"}}
}
}
},
"required": ["intent", "queries", "recommended_provider", "exa_config", "tavily_config"]
}
def _build_persona_context(
self,
research_persona: Optional[ResearchPersona],
industry: Optional[str],
target_audience: Optional[str],
) -> str:
"""Build persona context section."""
parts = []
if research_persona:
if research_persona.default_industry:
parts.append(f"Industry: {research_persona.default_industry}")
if research_persona.default_target_audience:
parts.append(f"Target Audience: {research_persona.default_target_audience}")
if research_persona.research_angles:
parts.append(f"Preferred Research Angles: {', '.join(research_persona.research_angles[:3])}")
if research_persona.suggested_keywords:
parts.append(f"Relevant Keywords: {', '.join(research_persona.suggested_keywords[:5])}")
else:
if industry:
parts.append(f"Industry: {industry}")
if target_audience:
parts.append(f"Target Audience: {target_audience}")
if not parts:
return "No specific user context available. Use general best practices."
return "\n".join(parts)
def _build_competitor_context(self, competitor_data: Optional[List[Dict]]) -> str:
"""Build competitor context section."""
if not competitor_data:
return ""
competitor_names = [c.get("name", c.get("url", "")) for c in competitor_data[:5]]
if competitor_names:
return f"\nKnown Competitors: {', '.join(competitor_names)}"
return ""
def _parse_unified_result(self, result: Dict[str, Any], user_input: str) -> Dict[str, Any]:
"""Parse the unified LLM result into structured response."""
intent_data = result.get("intent", {})
# Build ResearchIntent
intent = ResearchIntent(
primary_question=intent_data.get("primary_question", user_input),
secondary_questions=intent_data.get("secondary_questions", []),
purpose=intent_data.get("purpose", "learn"),
content_output=intent_data.get("content_output", "general"),
expected_deliverables=intent_data.get("expected_deliverables", ["key_statistics"]),
depth=intent_data.get("depth", "detailed"),
focus_areas=intent_data.get("focus_areas", []),
perspective=intent_data.get("perspective"),
time_sensitivity=intent_data.get("time_sensitivity"),
input_type=intent_data.get("input_type", "keywords"),
original_input=user_input,
confidence=float(intent_data.get("confidence", 0.7)),
confidence_reason=intent_data.get("confidence_reason"),
great_example=intent_data.get("great_example"),
needs_clarification=intent_data.get("needs_clarification", False),
clarifying_questions=intent_data.get("clarifying_questions", []),
)
# Build queries
queries = []
for q in result.get("queries", []):
try:
queries.append(ResearchQuery(
query=q.get("query", ""),
purpose=q.get("purpose", "key_statistics"),
provider=q.get("provider", "exa"),
priority=int(q.get("priority", 3)),
expected_results=q.get("expected_results", ""),
))
except Exception as e:
logger.warning(f"Failed to parse query: {e}")
return {
"success": True,
"intent": intent,
"queries": queries,
"enhanced_keywords": result.get("enhanced_keywords", []),
"research_angles": result.get("research_angles", []),
"recommended_provider": result.get("recommended_provider", "exa"),
"provider_justification": result.get("provider_justification", ""),
"exa_config": result.get("exa_config", {}),
"tavily_config": result.get("tavily_config", {}),
"trends_config": result.get("trends_config", {}), # NEW: Google Trends configuration
"analysis_summary": intent_data.get("analysis_summary", ""),
}
def _create_fallback_response(self, user_input: str, keywords: List[str]) -> Dict[str, Any]:
"""Create fallback response when analysis fails."""
return {
"success": False,
"intent": ResearchIntent(
primary_question=f"What are the key insights about: {user_input}?",
purpose="learn",
content_output="general",
expected_deliverables=["key_statistics", "best_practices"],
depth="detailed",
original_input=user_input,
confidence=0.5,
),
"queries": [
ResearchQuery(
query=user_input,
purpose="key_statistics",
provider="exa",
priority=5,
expected_results="General research results",
)
],
"enhanced_keywords": keywords,
"research_angles": [],
"recommended_provider": "exa",
"provider_justification": "Default fallback to Exa for semantic search",
"exa_config": {
"enabled": True,
"type": "auto",
"type_justification": "Auto mode for balanced results",
"numResults": 10,
"highlights": True,
},
"tavily_config": {
"enabled": True,
"topic": "general",
"search_depth": "advanced",
"include_answer": True,
},
"trends_config": {
"enabled": False, # Disabled in fallback
},
}

View File

@@ -34,39 +34,81 @@ class ResearchPersonaService:
user_id: str
) -> Optional[ResearchPersona]:
"""
Get research persona for user ONLY if it exists in cache.
This method NEVER generates - it only returns cached personas.
Get research persona for user if it exists in database (regardless of cache validity).
This method NEVER generates - it only returns existing personas.
Use this for config endpoints to avoid triggering rate limit checks.
Note: Returns persona even if cache is expired - cache validity only matters for regeneration.
Args:
user_id: User ID (Clerk string)
Returns:
ResearchPersona if cached and valid, None otherwise
ResearchPersona if exists in database, None otherwise
"""
try:
# Get persona data record
persona_data = self._get_persona_data_record(user_id)
if not persona_data:
logger.debug(f"No persona data found for user {user_id}")
logger.debug(f"[get_cached_only] No persona data record found for user {user_id}")
return None
# Only return if cache is valid and persona exists
if self.is_cache_valid(persona_data) and persona_data.research_persona:
# Check if research_persona field exists and is not None/empty
# Handle cases where it might be None, empty dict {}, or empty string ""
research_persona_raw = persona_data.research_persona
has_persona = (
research_persona_raw is not None
and research_persona_raw != {}
and research_persona_raw != ""
and (isinstance(research_persona_raw, dict) and len(research_persona_raw) > 0)
)
logger.info(
f"[get_cached_only] Checking research persona for user {user_id}: "
f"persona_data exists=True, research_persona_raw={research_persona_raw is not None}, "
f"research_persona type={type(research_persona_raw)}, "
f"has_persona={has_persona}, "
f"generated_at={persona_data.research_persona_generated_at}"
)
# Return persona if it exists, regardless of cache validity
# Cache validity only matters when deciding whether to regenerate
if has_persona:
try:
logger.debug(f"Returning cached research persona for user {user_id}")
return ResearchPersona(**persona_data.research_persona)
cache_valid = self.is_cache_valid(persona_data)
cache_status = "valid" if cache_valid else "expired"
logger.info(
f"[get_cached_only] ✅ Returning research persona for user {user_id} "
f"(cache: {cache_status}, generated_at: {persona_data.research_persona_generated_at})"
)
# Ensure we're passing a dict to ResearchPersona
if not isinstance(research_persona_raw, dict):
logger.error(f"[get_cached_only] research_persona_raw is not a dict: {type(research_persona_raw)}")
return None
parsed_persona = ResearchPersona(**research_persona_raw)
logger.info(
f"[get_cached_only] ✅ Successfully parsed persona for user {user_id}: "
f"industry={parsed_persona.default_industry}, "
f"target_audience={parsed_persona.default_target_audience}"
)
return parsed_persona
except Exception as e:
logger.warning(f"Failed to parse cached research persona: {e}")
logger.error(f"[get_cached_only] ❌ Failed to parse research persona for user {user_id}: {e}", exc_info=True)
logger.debug(
f"[get_cached_only] Persona data details: "
f"type={type(research_persona_raw)}, "
f"is_dict={isinstance(research_persona_raw, dict)}, "
f"value sample: {str(research_persona_raw)[:500] if research_persona_raw else 'None'}"
)
return None
# Cache invalid or persona missing - return None (don't generate)
logger.debug(f"No valid cached research persona for user {user_id}")
# Persona doesn't exist in database
logger.info(f"[get_cached_only] ⚠️ No research persona found in database for user {user_id}")
return None
except Exception as e:
logger.error(f"Error getting cached research persona for user {user_id}: {e}")
logger.error(f"[get_cached_only] ❌ Error getting research persona for user {user_id}: {e}", exc_info=True)
return None
def get_or_generate(
@@ -92,25 +134,40 @@ class ResearchPersonaService:
logger.warning(f"No persona data found for user {user_id}, cannot generate research persona")
return None
# Check cache if not forcing refresh
if not force_refresh and self.is_cache_valid(persona_data):
if persona_data.research_persona:
# Check if persona exists in database
if persona_data.research_persona:
# Persona exists - check if we should return it or regenerate
cache_valid = self.is_cache_valid(persona_data)
if not force_refresh and cache_valid:
# Cache is valid - return existing persona
logger.info(f"Using cached research persona for user {user_id}")
try:
return ResearchPersona(**persona_data.research_persona)
except Exception as e:
logger.warning(f"Failed to parse cached research persona: {e}, regenerating...")
# Fall through to regeneration
# Fall through to regeneration if parsing fails
elif not force_refresh:
# Persona exists but cache expired - return it anyway (don't regenerate unless forced)
logger.info(f"Research persona exists for user {user_id} but cache expired - returning existing persona (use force_refresh=true to regenerate)")
try:
return ResearchPersona(**persona_data.research_persona)
except Exception as e:
logger.warning(f"Failed to parse existing research persona: {e}, regenerating...")
# Fall through to regeneration if parsing fails
else:
logger.info(f"Research persona missing for user {user_id}, generating...")
else:
if force_refresh:
# force_refresh=True - regenerate even though persona exists
logger.info(f"Forcing refresh of research persona for user {user_id}")
else:
logger.info(f"Cache expired for user {user_id}, regenerating...")
else:
# Persona doesn't exist - generate new one
logger.info(f"Research persona missing for user {user_id}, generating...")
# Generate new research persona
# Generate new research persona (only reaches here if:
# 1. Persona doesn't exist, OR
# 2. force_refresh=True, OR
# 3. Parsing of existing persona failed
try:
logger.info(f"Generating research persona for user {user_id}")
research_persona = self.generate_research_persona(user_id)
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) so they propagate to API

View File

@@ -0,0 +1,9 @@
"""
Google Trends Research Service
Provides Google Trends data integration for the Research Engine.
"""
from .google_trends_service import GoogleTrendsService
__all__ = ['GoogleTrendsService']

View File

@@ -0,0 +1,380 @@
"""
Google Trends Service
Provides Google Trends data integration for the Research Engine.
Handles rate limiting, caching, error handling, and data serialization.
Author: ALwrity Team
Version: 1.0
"""
import asyncio
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from loguru import logger
import pandas as pd
try:
from pytrends.request import TrendReq
PYTrends_AVAILABLE = True
except ImportError:
PYTrends_AVAILABLE = False
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
from .rate_limiter import RateLimiter
class GoogleTrendsService:
"""
Service for fetching and analyzing Google Trends data.
Features:
- Interest over time
- Interest by region
- Related topics
- Related queries
- Rate limiting (1 req/sec)
- Caching (24-hour TTL)
- Async support
- Error handling with retry logic
"""
def __init__(self):
"""Initialize the Google Trends service."""
if not PYTrends_AVAILABLE:
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
self.rate_limiter = RateLimiter(max_calls=1, period=1.0) # 1 request per second
self.cache: Dict[str, Dict[str, Any]] = {} # Simple in-memory cache
self.cache_ttl = timedelta(hours=24) # 24-hour cache
logger.info("GoogleTrendsService initialized")
async def analyze_trends(
self,
keywords: List[str],
timeframe: str = "today 12-m",
geo: str = "US",
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Comprehensive trends analysis.
Fetches all trends data in a single optimized call:
- Interest over time
- Interest by region
- Related topics (top & rising)
- Related queries (top & rising)
Args:
keywords: List of keywords to analyze (1-5 keywords recommended)
timeframe: Timeframe string (e.g., "today 12-m", "today 1-y", "all")
geo: Country code (e.g., "US", "GB", "IN")
user_id: User ID for subscription checks (optional for now)
Returns:
Dict containing all trends data in serializable format
Raises:
ValueError: If keywords list is empty or too long
RuntimeError: If pytrends is not available or API fails
"""
if not keywords:
raise ValueError("Keywords list cannot be empty")
if len(keywords) > 5:
logger.warning(f"Too many keywords ({len(keywords)}), using first 5")
keywords = keywords[:5]
# Check cache first
cache_key = self._build_cache_key(keywords, timeframe, geo)
cached_data = self._get_from_cache(cache_key)
if cached_data:
logger.info(f"Returning cached trends data for: {keywords}")
return {**cached_data, "cached": True}
# Rate limit
await self.rate_limiter.acquire()
try:
logger.info(f"Fetching Google Trends data for: {keywords} (timeframe: {timeframe}, geo: {geo})")
# Initialize pytrends (sync operation, run in thread)
pytrends = await asyncio.to_thread(
self._initialize_pytrends,
keywords,
timeframe,
geo
)
# Fetch all data in parallel (pytrends methods are sync, so use to_thread)
interest_over_time_task = asyncio.to_thread(
lambda: self._safe_interest_over_time(pytrends)
)
interest_by_region_task = asyncio.to_thread(
lambda: self._safe_interest_by_region(pytrends)
)
related_topics_task = asyncio.to_thread(
lambda: self._safe_related_topics(pytrends, keywords)
)
related_queries_task = asyncio.to_thread(
lambda: self._safe_related_queries(pytrends, keywords)
)
# Wait for all tasks
interest_over_time, interest_by_region, related_topics, related_queries = await asyncio.gather(
interest_over_time_task,
interest_by_region_task,
related_topics_task,
related_queries_task,
return_exceptions=True
)
# Handle exceptions
if isinstance(interest_over_time, Exception):
logger.error(f"Interest over time failed: {interest_over_time}")
interest_over_time = []
if isinstance(interest_by_region, Exception):
logger.error(f"Interest by region failed: {interest_by_region}")
interest_by_region = []
if isinstance(related_topics, Exception):
logger.error(f"Related topics failed: {related_topics}")
related_topics = {"top": [], "rising": []}
if isinstance(related_queries, Exception):
logger.error(f"Related queries failed: {related_queries}")
related_queries = {"top": [], "rising": []}
# Build result
result = {
"interest_over_time": interest_over_time,
"interest_by_region": interest_by_region,
"related_topics": related_topics,
"related_queries": related_queries,
"timeframe": timeframe,
"geo": geo,
"keywords": keywords,
"timestamp": datetime.utcnow().isoformat(),
"cached": False
}
# Cache result
self._save_to_cache(cache_key, result)
logger.info(f"Google Trends data fetched successfully: {len(interest_over_time)} time points, {len(interest_by_region)} regions")
return result
except Exception as e:
logger.error(f"Google Trends analysis failed: {e}")
# Return fallback response
return self._create_fallback_response(keywords, timeframe, geo, str(e))
def _initialize_pytrends(
self,
keywords: List[str],
timeframe: str,
geo: str
) -> TrendReq:
"""Initialize pytrends and build payload (sync operation)."""
pytrends = TrendReq(hl='en-US', tz=360)
pytrends.build_payload(kw_list=keywords, timeframe=timeframe, geo=geo)
return pytrends
def _safe_interest_over_time(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
"""Safely fetch interest over time data."""
try:
df = pytrends.interest_over_time()
if df.empty:
return []
return self._format_dataframe(df.reset_index())
except Exception as e:
logger.error(f"Error fetching interest over time: {e}")
return []
def _safe_interest_by_region(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
"""Safely fetch interest by region data."""
try:
df = pytrends.interest_by_region(resolution='COUNTRY', inc_low_vol=True, inc_geo_code=False)
if df.empty:
return []
return self._format_dataframe(df.reset_index())
except Exception as e:
logger.error(f"Error fetching interest by region: {e}")
return []
def _safe_related_topics(
self,
pytrends: TrendReq,
keywords: List[str]
) -> Dict[str, List[Dict[str, Any]]]:
"""Safely fetch related topics."""
try:
topics_data = pytrends.related_topics()
result = {"top": [], "rising": []}
for keyword in keywords:
if keyword in topics_data and isinstance(topics_data[keyword], dict):
keyword_topics = topics_data[keyword]
if "top" in keyword_topics and not keyword_topics["top"].empty:
top_df = keyword_topics["top"]
# Select relevant columns
if "topic_title" in top_df.columns and "value" in top_df.columns:
top_data = top_df[["topic_title", "value"]].to_dict('records')
result["top"].extend(top_data)
if "rising" in keyword_topics and not keyword_topics["rising"].empty:
rising_df = keyword_topics["rising"]
if "topic_title" in rising_df.columns and "value" in rising_df.columns:
rising_data = rising_df[["topic_title", "value"]].to_dict('records')
result["rising"].extend(rising_data)
return result
except Exception as e:
logger.error(f"Error fetching related topics: {e}")
return {"top": [], "rising": []}
def _safe_related_queries(
self,
pytrends: TrendReq,
keywords: List[str]
) -> Dict[str, List[Dict[str, Any]]]:
"""Safely fetch related queries."""
try:
queries_data = pytrends.related_queries()
result = {"top": [], "rising": []}
for keyword in keywords:
if keyword in queries_data and isinstance(queries_data[keyword], dict):
keyword_queries = queries_data[keyword]
if "top" in keyword_queries and not keyword_queries["top"].empty:
top_df = keyword_queries["top"]
result["top"].extend(top_df.to_dict('records'))
if "rising" in keyword_queries and not keyword_queries["rising"].empty:
rising_df = keyword_queries["rising"]
result["rising"].extend(rising_df.to_dict('records'))
return result
except Exception as e:
logger.error(f"Error fetching related queries: {e}")
return {"top": [], "rising": []}
def _format_dataframe(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
"""Convert DataFrame to list of dicts (serializable format)."""
if df.empty:
return []
# Convert datetime columns to strings
for col in df.columns:
if pd.api.types.is_datetime64_any_dtype(df[col]):
df[col] = df[col].astype(str)
# Convert to dict records
return df.to_dict('records')
def _build_cache_key(self, keywords: List[str], timeframe: str, geo: str) -> str:
"""Build cache key from parameters."""
keywords_str = ":".join(sorted(keywords))
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Get data from cache if not expired."""
if cache_key not in self.cache:
return None
cached_entry = self.cache[cache_key]
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
if datetime.utcnow() - cached_time > self.cache_ttl:
# Expired, remove from cache
del self.cache[cache_key]
return None
# Return cached data (without cached flag)
result = {**cached_entry}
result.pop("cached", None)
return result
def _save_to_cache(self, cache_key: str, data: Dict[str, Any]):
"""Save data to cache."""
# Store with timestamp
cache_entry = {
**data,
"cached_at": datetime.utcnow().isoformat()
}
self.cache[cache_key] = cache_entry
# Clean up old cache entries periodically
if len(self.cache) > 100: # Limit cache size
self._cleanup_cache()
def _cleanup_cache(self):
"""Remove expired cache entries."""
now = datetime.utcnow()
expired_keys = []
for key, entry in self.cache.items():
cached_time = datetime.fromisoformat(entry.get("cached_at", entry.get("timestamp", "")))
if now - cached_time > self.cache_ttl:
expired_keys.append(key)
for key in expired_keys:
del self.cache[key]
logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
def _create_fallback_response(
self,
keywords: List[str],
timeframe: str,
geo: str,
error_message: str
) -> Dict[str, Any]:
"""Create fallback response when trends analysis fails."""
return {
"interest_over_time": [],
"interest_by_region": [],
"related_topics": {"top": [], "rising": []},
"related_queries": {"top": [], "rising": []},
"timeframe": timeframe,
"geo": geo,
"keywords": keywords,
"timestamp": datetime.utcnow().isoformat(),
"cached": False,
"error": error_message
}
async def get_trending_searches(
self,
country: str = "united_states",
user_id: Optional[str] = None
) -> List[str]:
"""
Get current trending searches for a country.
Args:
country: Country name (e.g., "united_states", "united_kingdom")
user_id: User ID for subscription checks
Returns:
List of trending search terms
"""
await self.rate_limiter.acquire()
try:
pytrends = TrendReq(hl='en-US', tz=360)
trending_df = await asyncio.to_thread(
lambda: pytrends.trending_searches(pn=country)
)
if trending_df.empty:
return []
# Return as list of strings
return trending_df[0].tolist() if len(trending_df.columns) > 0 else []
except Exception as e:
logger.error(f"Error fetching trending searches: {e}")
return []

View File

@@ -0,0 +1,57 @@
"""
Rate Limiter for Google Trends API
Ensures we don't exceed Google Trends rate limits (1 request per second).
"""
import asyncio
from time import time
from collections import deque
from loguru import logger
class RateLimiter:
"""
Simple rate limiter for Google Trends API.
Limits requests to max_calls per period (in seconds).
"""
def __init__(self, max_calls: int = 1, period: float = 1.0):
"""
Initialize rate limiter.
Args:
max_calls: Maximum number of calls allowed
period: Time period in seconds
"""
self.max_calls = max_calls
self.period = period
self.calls = deque()
self._lock = asyncio.Lock()
async def acquire(self):
"""
Acquire permission to make a request.
Will wait if rate limit would be exceeded.
"""
async with self._lock:
now = time()
# Remove old calls outside the period
while self.calls and self.calls[0] < now - self.period:
self.calls.popleft()
# If at limit, wait until oldest call expires
if len(self.calls) >= self.max_calls:
sleep_time = self.period - (now - self.calls[0])
if sleep_time > 0:
logger.debug(f"Rate limit reached, waiting {sleep_time:.2f}s")
await asyncio.sleep(sleep_time)
# Recursively try again after waiting
return await self.acquire()
# Record this call
self.calls.append(time())
logger.debug(f"Rate limit check passed, {len(self.calls)}/{self.max_calls} calls in period")

View File

@@ -0,0 +1,557 @@
"""
Edit Studio Service - Video editing operations.
Phase 1: Basic FFmpeg operations (Trim/Cut, Speed Control, Stabilization)
Phase 2: Text Overlay & Captions, Audio Enhancement, Noise Reduction
Phase 3: AI Features (Background Replacement, Object Removal, Color Grading)
"""
import asyncio
import logging
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from fastapi import HTTPException
from backend.services.video_studio.video_processors import (
trim_video,
adjust_speed,
)
logger = logging.getLogger(__name__)
class EditService:
"""Service for video editing operations."""
def __init__(self):
logger.info("[EditService] Service initialized")
def calculate_cost(self, edit_type: str, duration: float = 10.0) -> float:
"""Calculate cost for video editing operation. FFmpeg operations are free."""
return 0.0
async def trim_video(
self,
video_data: bytes,
start_time: float = 0.0,
end_time: Optional[float] = None,
max_duration: Optional[float] = None,
trim_mode: str = "beginning",
user_id: str = None,
) -> Dict[str, Any]:
"""Trim video to specified duration or time range."""
try:
logger.info(f"[EditService] Video trim: user={user_id}, start={start_time}, end={end_time}")
processed_video_bytes = await asyncio.to_thread(
trim_video,
video_bytes=video_data,
start_time=start_time,
end_time=end_time,
max_duration=max_duration,
trim_mode=trim_mode,
)
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_trim_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "trim", "start_time": start_time, "end_time": end_time},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "trim",
"metadata": {"start_time": start_time, "end_time": end_time},
}
finally:
db.close()
except Exception as e:
logger.error(f"[EditService] Video trim failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Video trimming failed: {str(e)}")
async def adjust_speed(
self,
video_data: bytes,
speed_factor: float,
user_id: str = None,
) -> Dict[str, Any]:
"""Adjust video playback speed."""
try:
logger.info(f"[EditService] Speed adjustment: user={user_id}, factor={speed_factor}")
if speed_factor <= 0:
raise HTTPException(status_code=400, detail="Speed factor must be greater than 0")
if speed_factor > 4.0:
raise HTTPException(status_code=400, detail="Speed factor cannot exceed 4.0")
processed_video_bytes = await asyncio.to_thread(
adjust_speed,
video_bytes=video_data,
speed_factor=speed_factor,
)
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_speed_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "speed", "speed_factor": speed_factor},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "speed",
"metadata": {"speed_factor": speed_factor},
}
finally:
db.close()
except Exception as e:
logger.error(f"[EditService] Speed adjustment failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Speed adjustment failed: {str(e)}")
async def stabilize_video(
self,
video_data: bytes,
smoothing: int = 10,
user_id: str = None,
) -> Dict[str, Any]:
"""Stabilize video using FFmpeg vidstab."""
try:
logger.info(f"[EditService] Stabilization: user={user_id}, smoothing={smoothing}")
smoothing = max(1, min(100, smoothing))
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
input_file.write(video_data)
input_path = input_file.name
transforms_file = tempfile.NamedTemporaryFile(suffix=".trf", delete=False, delete_on_close=False)
transforms_path = transforms_file.name
transforms_file.close()
output_path = None
try:
detect_cmd = [
"ffmpeg", "-i", input_path,
"-vf", f"vidstabdetect=stepsize=6:shakiness=10:accuracy=15:result={transforms_path}",
"-f", "null", "-"
]
subprocess.run(detect_cmd, capture_output=True, text=True, timeout=300)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
output_path = output_file.name
transform_cmd = [
"ffmpeg", "-i", input_path,
"-vf", f"vidstabtransform=input={transforms_path}:smoothing={smoothing}:zoom=1:optzoom=1",
"-c:v", "libx264", "-preset", "medium", "-crf", "23",
"-c:a", "copy", "-y", output_path
]
result = subprocess.run(transform_cmd, capture_output=True, text=True, timeout=600)
if result.returncode != 0:
raise HTTPException(status_code=500, detail=f"Stabilization failed: {result.stderr}")
with open(output_path, "rb") as f:
processed_video_bytes = f.read()
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_stabilized_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "stabilize", "smoothing": smoothing},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "stabilize",
"metadata": {"smoothing": smoothing},
}
finally:
db.close()
finally:
Path(input_path).unlink(missing_ok=True)
Path(transforms_path).unlink(missing_ok=True)
if output_path:
Path(output_path).unlink(missing_ok=True)
except subprocess.TimeoutExpired:
raise HTTPException(status_code=504, detail="Stabilization timed out")
except Exception as e:
logger.error(f"[EditService] Stabilization failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Stabilization failed: {str(e)}")
# Phase 2: Text and Audio operations
async def add_text_overlay(
self,
video_data: bytes,
text: str,
position: str = "center",
font_size: int = 48,
font_color: str = "white",
background_color: str = "black@0.5",
start_time: float = 0.0,
end_time: Optional[float] = None,
user_id: str = None,
) -> Dict[str, Any]:
"""Add text overlay to video using FFmpeg drawtext filter."""
try:
logger.info(f"[EditService] Text overlay: user={user_id}, text='{text[:30]}...'")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
input_file.write(video_data)
input_path = input_file.name
output_path = None
try:
position_map = {
"top": "(w-text_w)/2:50",
"center": "(w-text_w)/2:(h-text_h)/2",
"bottom": "(w-text_w)/2:h-text_h-50",
"top-left": "50:50",
"top-right": "w-text_w-50:50",
"bottom-left": "50:h-text_h-50",
"bottom-right": "w-text_w-50:h-text_h-50",
}
pos_expr = position_map.get(position, position_map["center"])
escaped_text = text.replace("'", "'\\''").replace(":", "\\:")
drawtext_filter = (
f"drawtext=text='{escaped_text}':"
f"fontsize={font_size}:fontcolor={font_color}:"
f"x={pos_expr.split(':')[0]}:y={pos_expr.split(':')[1]}:"
f"box=1:boxcolor={background_color}:boxborderw=10"
)
if start_time > 0 or end_time is not None:
enable_expr = f"between(t,{start_time},{end_time if end_time else 9999})"
drawtext_filter += f":enable='{enable_expr}'"
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
output_path = output_file.name
cmd = [
"ffmpeg", "-i", input_path, "-vf", drawtext_filter,
"-c:v", "libx264", "-preset", "medium", "-crf", "23",
"-c:a", "copy", "-y", output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise HTTPException(status_code=500, detail=f"Text overlay failed: {result.stderr}")
with open(output_path, "rb") as f:
processed_video_bytes = f.read()
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_text_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "text_overlay", "text": text[:100], "position": position},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "text_overlay",
"metadata": {"text": text[:100], "position": position, "font_size": font_size},
}
finally:
db.close()
finally:
Path(input_path).unlink(missing_ok=True)
if output_path:
Path(output_path).unlink(missing_ok=True)
except Exception as e:
logger.error(f"[EditService] Text overlay failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Text overlay failed: {str(e)}")
async def adjust_volume(
self,
video_data: bytes,
volume_factor: float,
user_id: str = None,
) -> Dict[str, Any]:
"""Adjust video audio volume using FFmpeg."""
try:
logger.info(f"[EditService] Volume adjustment: user={user_id}, factor={volume_factor}")
if volume_factor < 0:
raise HTTPException(status_code=400, detail="Volume factor must be non-negative")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
input_file.write(video_data)
input_path = input_file.name
output_path = None
try:
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
output_path = output_file.name
cmd = [
"ffmpeg", "-i", input_path,
"-af", f"volume={volume_factor}",
"-c:v", "copy", "-c:a", "aac", "-y", output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise HTTPException(status_code=500, detail=f"Volume adjustment failed: {result.stderr}")
with open(output_path, "rb") as f:
processed_video_bytes = f.read()
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_volume_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "volume", "volume_factor": volume_factor},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "volume",
"metadata": {"volume_factor": volume_factor},
}
finally:
db.close()
finally:
Path(input_path).unlink(missing_ok=True)
if output_path:
Path(output_path).unlink(missing_ok=True)
except Exception as e:
logger.error(f"[EditService] Volume adjustment failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Volume adjustment failed: {str(e)}")
async def normalize_audio(
self,
video_data: bytes,
target_level: float = -14.0,
user_id: str = None,
) -> Dict[str, Any]:
"""Normalize audio levels using FFmpeg loudnorm filter (EBU R128)."""
try:
logger.info(f"[EditService] Audio normalization: user={user_id}, level={target_level} LUFS")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
input_file.write(video_data)
input_path = input_file.name
output_path = None
try:
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
output_path = output_file.name
cmd = [
"ffmpeg", "-i", input_path,
"-af", f"loudnorm=I={target_level}:TP=-1.5:LRA=11",
"-c:v", "copy", "-c:a", "aac", "-y", output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise HTTPException(status_code=500, detail=f"Audio normalization failed: {result.stderr}")
with open(output_path, "rb") as f:
processed_video_bytes = f.read()
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_normalized_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "normalize", "target_level": target_level},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "normalize",
"metadata": {"target_level": target_level},
}
finally:
db.close()
finally:
Path(input_path).unlink(missing_ok=True)
if output_path:
Path(output_path).unlink(missing_ok=True)
except Exception as e:
logger.error(f"[EditService] Audio normalization failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Audio normalization failed: {str(e)}")
async def reduce_noise(
self,
video_data: bytes,
noise_reduction_strength: float = 0.5,
user_id: str = None,
) -> Dict[str, Any]:
"""Reduce audio noise using FFmpeg's anlmdn filter."""
try:
logger.info(f"[EditService] Noise reduction: user={user_id}, strength={noise_reduction_strength}")
strength = max(0.0, min(1.0, noise_reduction_strength))
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
input_file.write(video_data)
input_path = input_file.name
output_path = None
try:
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
output_path = output_file.name
sigma = 0.0001 + (strength * 0.005)
cmd = [
"ffmpeg", "-i", input_path,
"-af", f"anlmdn=s={sigma}:p=0.002:r=0.002",
"-c:v", "copy", "-c:a", "aac", "-y", output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
if result.returncode != 0:
# Fallback to highpass/lowpass
cmd = [
"ffmpeg", "-i", input_path,
"-af", "highpass=f=80,lowpass=f=12000",
"-c:v", "copy", "-c:a", "aac", "-y", output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise HTTPException(status_code=500, detail=f"Noise reduction failed: {result.stderr}")
with open(output_path, "rb") as f:
processed_video_bytes = f.read()
from backend.services.content_assets.content_asset_service import ContentAssetService
from backend.database.database import get_db
db_gen = get_db()
db = next(db_gen)
try:
asset_service = ContentAssetService(db)
filename = f"edited_denoised_{uuid.uuid4().hex[:8]}.mp4"
asset_result = asset_service.save_video_asset(
user_id=user_id,
video_data=processed_video_bytes,
filename=filename,
asset_type="video_edit",
metadata={"edit_type": "noise_reduction", "strength": strength},
)
return {
"success": True,
"video_url": asset_result.get("url"),
"asset_id": asset_result.get("asset_id"),
"cost": 0.0,
"edit_type": "noise_reduction",
"metadata": {"strength": strength},
}
finally:
db.close()
finally:
Path(input_path).unlink(missing_ok=True)
if output_path:
Path(output_path).unlink(missing_ok=True)
except Exception as e:
logger.error(f"[EditService] Noise reduction failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Noise reduction failed: {str(e)}")

View File

@@ -0,0 +1,9 @@
"""
Video generation generator for WaveSpeed API.
Modular implementation with separate modules for different video operations.
"""
from .generator import VideoGenerator
__all__ = ["VideoGenerator"]

View File

@@ -0,0 +1,244 @@
"""
Video audio generation operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.audio")
class VideoAudio(VideoBase):
"""Video audio generation operations."""
def hunyuan_video_foley(
self,
video: str, # Base64-encoded video or URL
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
seed: int = -1, # Random seed (-1 for random)
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Generate realistic Foley and ambient audio from video using Hunyuan Video Foley.
Args:
video: Base64-encoded video data URI or public URL (source video)
prompt: Optional text prompt describing desired sounds (e.g., "ocean waves, seagulls")
seed: Random seed for reproducibility (-1 for random)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Video with generated audio
Raises:
HTTPException: If the audio generation fails
"""
model_path = "wavespeed-ai/hunyuan-video-foley"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"video": video,
"seed": seed,
}
if prompt:
payload["prompt"] = prompt
logger.info(
f"[WaveSpeed] Hunyuan Video Foley request via {url} "
f"(has_prompt={prompt is not None}, seed={seed})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Hunyuan Video Foley submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed Hunyuan Video Foley submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in Hunyuan Video Foley response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed Hunyuan Video Foley response missing prediction id",
)
logger.info(f"[WaveSpeed] Hunyuan Video Foley task submitted: {prediction_id}")
if enable_sync_mode:
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley returned no outputs")
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley output format not recognized")
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to download video with audio from WaveSpeed",
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Hunyuan Video Foley completed successfully (size: {len(video_bytes)} bytes)")
return video_bytes
else:
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for Hunyuan Video Foley",
"prediction_id": prediction_id,
},
)
def think_sound(
self,
video: str, # Base64-encoded video or URL
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
seed: int = -1, # Random seed (-1 for random)
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Generate realistic sound effects and audio tracks from video using Think Sound.
Args:
video: Base64-encoded video data URI or public URL (source video)
prompt: Optional text prompt describing desired sounds (e.g., "engine roaring, footsteps on gravel")
seed: Random seed for reproducibility (-1 for random)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Video with generated audio
Raises:
HTTPException: If the audio generation fails
"""
model_path = "wavespeed-ai/think-sound"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"video": video,
"seed": seed,
}
if prompt:
payload["prompt"] = prompt
logger.info(
f"[WaveSpeed] Think Sound request via {url} "
f"(has_prompt={prompt is not None}, seed={seed})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Think Sound submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed Think Sound submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in Think Sound response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed Think Sound response missing prediction id",
)
logger.info(f"[WaveSpeed] Think Sound task submitted: {prediction_id}")
if enable_sync_mode:
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound returned no outputs")
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound output format not recognized")
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to download video with audio from WaveSpeed",
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Think Sound completed successfully (size: {len(video_bytes)} bytes)")
return video_bytes
else:
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for Think Sound",
"prediction_id": prediction_id,
},
)

View File

@@ -0,0 +1,127 @@
"""
Video background removal operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.background")
class VideoBackground(VideoBase):
"""Video background removal operations."""
def remove_background(
self,
video: str, # Base64-encoded video or URL
background_image: Optional[str] = None, # Base64-encoded image or URL (optional)
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Remove or replace video background using Video Background Remover.
Args:
video: Base64-encoded video data URI or public URL (source video)
background_image: Optional base64-encoded image data URI or public URL (replacement background)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Video with background removed/replaced
Raises:
HTTPException: If the background removal fails
"""
model_path = "wavespeed-ai/video-background-remover"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"video": video,
}
if background_image:
payload["background_image"] = background_image
logger.info(
f"[WaveSpeed] Video background removal request via {url} "
f"(has_background={background_image is not None})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Video background removal submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed video background removal submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in video background removal response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed video background removal response missing prediction id",
)
logger.info(f"[WaveSpeed] Video background removal task submitted: {prediction_id}")
if enable_sync_mode:
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed video background removal returned no outputs")
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(status_code=502, detail="WaveSpeed video background removal output format not recognized")
logger.info(f"[WaveSpeed] Downloading processed video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
logger.error(f"[WaveSpeed] Failed to download processed video: {video_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to download processed video from WaveSpeed",
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Video background removal completed successfully (size: {len(video_bytes)} bytes)")
return video_bytes
else:
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for video background removal",
"prediction_id": prediction_id,
},
)

View File

@@ -0,0 +1,84 @@
"""
Base functionality for video operations.
Shared utilities for HTTP requests, video download, and common operations.
"""
import requests
from typing import Optional
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.generators.video.base")
class VideoBase:
"""Base class for video operations with shared functionality."""
def __init__(self, api_key: str, base_url: str, polling):
"""Initialize video base.
Args:
api_key: WaveSpeed API key
base_url: WaveSpeed API base URL
polling: WaveSpeedPolling instance for async operations
"""
self.api_key = api_key
self.base_url = base_url
self.polling = polling
def _get_headers(self) -> dict:
"""Get HTTP headers for API requests."""
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
def _download_video(self, video_url: str, timeout: int = 180) -> bytes:
"""Download video from URL.
Args:
video_url: URL to download video from
timeout: Request timeout in seconds
Returns:
bytes: Video bytes
Raises:
HTTPException: If download fails
"""
logger.info(f"[WaveSpeed] Downloading video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "Failed to download video",
"status_code": video_response.status_code,
"response": video_response.text[:200],
}
)
return video_response.content
def _extract_video_url(self, outputs: list) -> Optional[str]:
"""Extract video URL from outputs array.
Args:
outputs: Array of outputs (can be strings or dicts)
Returns:
Optional[str]: Video URL if found, None otherwise
"""
if not outputs:
return None
output = outputs[0]
if isinstance(output, str):
return output if output.startswith("http") else None
elif isinstance(output, dict):
return output.get("url") or output.get("video_url")
return None

View File

@@ -0,0 +1,109 @@
"""
Video enhancement operations (upscaling).
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.enhancement")
class VideoEnhancement(VideoBase):
"""Video enhancement operations."""
def upscale_video(
self,
video: str, # Base64-encoded video or URL
target_resolution: str = "1080p",
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Upscale video using FlashVSR.
Args:
video: Base64-encoded video data URI or public URL
target_resolution: Target resolution ("720p", "1080p", "2k", "4k")
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300 for long videos)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Upscaled video bytes
Raises:
HTTPException: If the upscaling fails
"""
model_path = "wavespeed-ai/flashvsr"
url = f"{self.base_url}/{model_path}"
payload = {
"video": video,
"target_resolution": target_resolution,
}
logger.info(f"[WaveSpeed] Upscaling video via {url} (target={target_resolution})")
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] FlashVSR submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed FlashVSR submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in FlashVSR response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed FlashVSR response missing prediction id",
)
logger.info(f"[WaveSpeed] FlashVSR task submitted: {prediction_id}")
# Poll for result
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0, # Longer interval for upscaling (slower process)
progress_callback=progress_callback,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR returned no outputs")
video_url = outputs[0] if isinstance(outputs[0], str) else outputs[0].get("url")
if not video_url:
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR output format not recognized")
# Download the upscaled video
logger.info(f"[WaveSpeed] Downloading upscaled video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
logger.error(f"[WaveSpeed] Failed to download upscaled video: {video_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to download upscaled video from WaveSpeed",
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Video upscaling completed successfully (size: {len(video_bytes)} bytes)")
return video_bytes

View File

@@ -0,0 +1,161 @@
"""
Video extension operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.extension")
class VideoExtension(VideoBase):
"""Video extension operations."""
def extend_video(
self,
video: str, # Base64-encoded video or URL
prompt: str,
model: str = "wan-2.5", # "wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro"
audio: Optional[str] = None, # Optional audio URL (WAN 2.5 only)
negative_prompt: Optional[str] = None, # WAN 2.5 only
resolution: str = "720p",
duration: int = 5,
enable_prompt_expansion: bool = False, # WAN 2.5 only
generate_audio: bool = True, # Seedance 1.5 Pro only
camera_fixed: bool = False, # Seedance 1.5 Pro only
seed: Optional[int] = None,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend.
Args:
video: Base64-encoded video data URI or public URL
prompt: Text prompt describing how to extend the video
model: Model to use ("wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro")
audio: Optional audio URL to guide generation (WAN 2.5 only)
negative_prompt: Optional negative prompt (WAN 2.5 only)
resolution: Output resolution (varies by model)
duration: Duration of extended video in seconds (varies by model)
enable_prompt_expansion: Enable prompt optimizer (WAN 2.5 only)
generate_audio: Generate audio for extended video (Seedance 1.5 Pro only)
camera_fixed: Fix camera position (Seedance 1.5 Pro only)
seed: Random seed for reproducibility (-1 for random)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Extended video bytes
Raises:
HTTPException: If the extension fails
"""
# Determine model path
if model in ("wan-2.2-spicy", "wavespeed-ai/wan-2.2-spicy/video-extend"):
model_path = "wavespeed-ai/wan-2.2-spicy/video-extend"
elif model in ("seedance-1.5-pro", "bytedance/seedance-v1.5-pro/video-extend"):
model_path = "bytedance/seedance-v1.5-pro/video-extend"
else:
# Default to WAN 2.5
model_path = "alibaba/wan-2.5/video-extend"
url = f"{self.base_url}/{model_path}"
# Base payload (common to all models)
payload = {
"video": video,
"prompt": prompt,
"resolution": resolution,
"duration": duration,
}
# Model-specific parameters
if model_path == "alibaba/wan-2.5/video-extend":
# WAN 2.5 specific
payload["enable_prompt_expansion"] = enable_prompt_expansion
if audio:
payload["audio"] = audio
if negative_prompt:
payload["negative_prompt"] = negative_prompt
elif model_path == "bytedance/seedance-v1.5-pro/video-extend":
# Seedance 1.5 Pro specific
payload["generate_audio"] = generate_audio
payload["camera_fixed"] = camera_fixed
# Seed (all models support it)
if seed is not None:
payload["seed"] = seed
logger.info(f"[WaveSpeed] Extending video via {url} (duration={duration}s, resolution={resolution})")
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Video extend submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed video extend submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in video extend response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed video extend response missing prediction id",
)
logger.info(f"[WaveSpeed] Video extend task submitted: {prediction_id}")
# Poll for result
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed video extend returned no outputs")
# Handle outputs - can be array of strings or array of objects
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(status_code=502, detail="WaveSpeed video extend output format not recognized")
# Download the extended video
logger.info(f"[WaveSpeed] Downloading extended video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
logger.error(f"[WaveSpeed] Failed to download extended video: {video_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to download extended video from WaveSpeed",
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Video extension completed successfully (size: {len(video_bytes)} bytes)")
return video_bytes

View File

@@ -0,0 +1,283 @@
"""
Face swap operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.face_swap")
class VideoFaceSwap(VideoBase):
"""Face swap operations."""
def face_swap(
self,
image: str, # Base64-encoded image or URL
video: str, # Base64-encoded video or URL
prompt: Optional[str] = None,
resolution: str = "480p",
seed: Optional[int] = None,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha).
Args:
image: Base64-encoded image data URI or public URL (reference character)
video: Base64-encoded video data URI or public URL (source video)
prompt: Optional prompt to guide the swap
resolution: Output resolution ("480p" or "720p")
seed: Random seed for reproducibility (-1 for random)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Face-swapped video bytes
Raises:
HTTPException: If the face swap fails
"""
model_path = "wavespeed-ai/wan-2.1/mocha"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"image": image,
"video": video,
}
if prompt:
payload["prompt"] = prompt
if resolution in ("480p", "720p"):
payload["resolution"] = resolution
else:
payload["resolution"] = "480p" # Default
if seed is not None:
payload["seed"] = seed
else:
payload["seed"] = -1 # Random seed
logger.info(
f"[WaveSpeed] Face swap request via {url} "
f"(resolution={payload['resolution']}, seed={payload['seed']})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Face swap submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed face swap submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected face swap response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Face swap submitted: {prediction_id}")
if enable_sync_mode:
# Poll until complete
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
# Extract video URL from result
outputs = result.get("outputs", [])
if not outputs:
raise HTTPException(
status_code=502,
detail={"error": "Face swap completed but no output video found"},
)
# Handle outputs - can be array of strings or array of objects
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(
status_code=502,
detail={"error": "Face swap output format not recognized"},
)
# Download video
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Face swap completed: {len(video_bytes)} bytes")
return video_bytes
else:
# Return prediction ID for async polling
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for face swap",
"prediction_id": prediction_id,
},
)
def video_face_swap(
self,
video: str, # Base64-encoded video or URL
face_image: str, # Base64-encoded image or URL
target_gender: str = "all",
target_index: int = 0,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap).
Args:
video: Base64-encoded video data URI or public URL (source video)
face_image: Base64-encoded image data URI or public URL (reference face)
target_gender: Filter which faces to swap ("all", "female", "male")
target_index: Select which face to swap (0 = largest, 1 = second largest, etc.)
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 300)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Face-swapped video bytes
Raises:
HTTPException: If the face swap fails
"""
model_path = "wavespeed-ai/video-face-swap"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"video": video,
"face_image": face_image,
}
if target_gender in ("all", "female", "male"):
payload["target_gender"] = target_gender
else:
payload["target_gender"] = "all" # Default
if 0 <= target_index <= 10:
payload["target_index"] = target_index
else:
payload["target_index"] = 0 # Default
logger.info(
f"[WaveSpeed] Video face swap request via {url} "
f"(target_gender={payload['target_gender']}, target_index={payload['target_index']})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Video face swap submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed video face swap submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected video face swap response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Video face swap submitted: {prediction_id}")
if enable_sync_mode:
# Poll until complete
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
# Extract video URL from result
outputs = result.get("outputs", [])
if not outputs:
raise HTTPException(
status_code=502,
detail={"error": "Video face swap completed but no output video found"},
)
# Handle outputs - can be array of strings or array of objects
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(
status_code=502,
detail={"error": "Video face swap output format not recognized"},
)
# Download video
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Video face swap completed: {len(video_bytes)} bytes")
return video_bytes
else:
# Return prediction ID for async polling
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for video face swap",
"prediction_id": prediction_id,
},
)

View File

@@ -0,0 +1,333 @@
"""
Video generation operations (text-to-video and image-to-video).
"""
import requests
from typing import Any, Dict, Optional
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.generation")
class VideoGeneration(VideoBase):
"""Video generation operations."""
def submit_image_to_video(
self,
model_path: str,
payload: Dict[str, Any],
timeout: int = 30,
) -> str:
"""
Submit an image-to-video generation request.
Returns the prediction ID for polling.
"""
url = f"{self.base_url}/{model_path}"
logger.info(f"[WaveSpeed] Submitting request to {url}")
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed image-to-video submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
data = response.json().get("data")
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
return prediction_id
def submit_text_to_video(
self,
model_path: str,
payload: Dict[str, Any],
timeout: int = 60,
) -> str:
"""
Submit a text-to-video generation request to WaveSpeed.
Args:
model_path: Model path (e.g., "alibaba/wan-2.5/text-to-video")
payload: Request payload with prompt, resolution, duration, optional audio
timeout: Request timeout in seconds
Returns:
Prediction ID for polling
"""
url = f"{self.base_url}/{model_path}"
logger.info(f"[WaveSpeed] Submitting text-to-video request to {url}")
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed text-to-video submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
data = response.json().get("data")
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected text-to-video response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Submitted text-to-video request: {prediction_id}")
return prediction_id
def generate_text_video(
self,
prompt: str,
resolution: str = "720p", # 480p, 720p, 1080p
duration: int = 5, # 5 or 10 seconds
audio_base64: Optional[str] = None, # Optional audio for lip-sync
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
enable_prompt_expansion: bool = True,
enable_sync_mode: bool = False,
timeout: int = 180,
) -> Dict[str, Any]:
"""
Generate video from text prompt using WAN 2.5 text-to-video.
Args:
prompt: Text prompt describing the video
resolution: Output resolution (480p, 720p, 1080p)
duration: Video duration in seconds (5 or 10)
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB) for lip-sync
negative_prompt: Optional negative prompt
seed: Optional random seed for reproducibility
enable_prompt_expansion: Enable prompt optimizer
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds
Returns:
Dictionary with video bytes, metadata, and cost
"""
model_path = "alibaba/wan-2.5/text-to-video"
# Validate resolution
valid_resolutions = ["480p", "720p", "1080p"]
if resolution not in valid_resolutions:
raise HTTPException(
status_code=400,
detail=f"Invalid resolution: {resolution}. Must be one of: {valid_resolutions}"
)
# Validate duration
if duration not in [5, 10]:
raise HTTPException(
status_code=400,
detail="Duration must be 5 or 10 seconds"
)
# Build payload
payload = {
"prompt": prompt,
"resolution": resolution,
"duration": duration,
"enable_prompt_expansion": enable_prompt_expansion,
"enable_sync_mode": enable_sync_mode,
}
# Add optional audio
if audio_base64:
payload["audio"] = audio_base64
# Add optional parameters
if negative_prompt:
payload["negative_prompt"] = negative_prompt
if seed is not None:
payload["seed"] = seed
# Submit request
logger.info(
f"[WaveSpeed] Generating text-to-video: resolution={resolution}, "
f"duration={duration}s, prompt_length={len(prompt)}, sync_mode={enable_sync_mode}"
)
# For sync mode, submit and get result directly
if enable_sync_mode:
url = f"{self.base_url}/{model_path}"
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed text-to-video submission failed",
"status_code": response.status_code,
"response": response.text[:500],
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Check status - if "created" or "processing", we need to poll even in sync mode
status = data.get("status", "").lower()
outputs = data.get("outputs") or []
prediction_id = data.get("id")
logger.debug(
f"[WaveSpeed] Sync mode response: status='{status}', outputs_count={len(outputs)}, "
f"prediction_id={prediction_id}"
)
# Handle sync mode - result should be directly in outputs
if status == "completed" and outputs:
# Sync mode returned completed result - use it directly
logger.info(f"[WaveSpeed] Got immediate video results from sync mode (status: {status})")
video_url = outputs[0]
if not isinstance(video_url, str) or not video_url.startswith("http"):
logger.error(f"[WaveSpeed] Invalid video URL format in sync mode: {video_url}")
raise HTTPException(
status_code=502,
detail=f"Invalid video URL format: {video_url}",
)
video_bytes = self._download_video(video_url)
metadata = data.get("metadata") or {}
else:
# Sync mode returned "created", "processing", or incomplete status - need to poll
if not prediction_id:
logger.error(
f"[WaveSpeed] Sync mode returned status '{status}' but no prediction ID. "
f"Response: {response.text[:500]}"
)
raise HTTPException(
status_code=502,
detail="WaveSpeed text-to-video sync mode returned async response without prediction ID",
)
logger.info(
f"[WaveSpeed] Sync mode returned status '{status}' with {len(outputs)} output(s). "
f"Falling back to polling (prediction_id: {prediction_id})"
)
# Poll for completion
try:
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
)
except HTTPException as e:
detail = e.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
raise HTTPException(status_code=e.status_code, detail=detail)
outputs = result.get("outputs") or []
if not outputs:
logger.error(f"[WaveSpeed] Polling completed but no outputs: {result}")
raise HTTPException(
status_code=502,
detail="WaveSpeed text-to-video completed but returned no outputs",
)
video_url = outputs[0]
if not isinstance(video_url, str) or not video_url.startswith("http"):
logger.error(f"[WaveSpeed] Invalid video URL format after polling: {video_url}")
raise HTTPException(
status_code=502,
detail=f"Invalid video URL format: {video_url}",
)
video_bytes = self._download_video(video_url)
metadata = result.get("metadata") or {}
else:
# Async mode - submit and poll
prediction_id = self.submit_text_to_video(model_path, payload, timeout=timeout)
# Poll for completion
try:
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0
)
except HTTPException as e:
detail = e.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
raise HTTPException(status_code=e.status_code, detail=detail)
# Extract video URL
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(
status_code=502,
detail="WAN 2.5 text-to-video completed but returned no outputs"
)
video_url = outputs[0]
if not isinstance(video_url, str) or not video_url.startswith("http"):
raise HTTPException(
status_code=502,
detail=f"Invalid video URL format: {video_url}"
)
video_bytes = self._download_video(video_url)
metadata = result.get("metadata") or {}
# prediction_id is already set from earlier in the function
# Calculate cost (same pricing as image-to-video)
pricing = {
"480p": 0.05,
"720p": 0.10,
"1080p": 0.15,
}
cost = pricing.get(resolution, 0.10) * duration
# Get video dimensions
resolution_dims = {
"480p": (854, 480),
"720p": (1280, 720),
"1080p": (1920, 1080),
}
width, height = resolution_dims.get(resolution, (1280, 720))
logger.info(
f"[WaveSpeed] ✅ Generated text-to-video: {len(video_bytes)} bytes, "
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
)
return {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": float(duration),
"model_name": "alibaba/wan-2.5/text-to-video",
"cost": cost,
"provider": "wavespeed",
"source_video_url": video_url,
"prediction_id": prediction_id,
"resolution": resolution,
"width": width,
"height": height,
"metadata": metadata,
}

View File

@@ -0,0 +1,263 @@
"""
Main VideoGenerator class that composes all video operation modules.
This class maintains backward compatibility with the original monolithic VideoGenerator
by delegating to specialized modules for different video operations.
"""
from typing import Any, Dict, Optional, Callable
from .base import VideoBase
from .generation import VideoGeneration
from .enhancement import VideoEnhancement
from .extension import VideoExtension
from .face_swap import VideoFaceSwap
from .translation import VideoTranslation
from .background import VideoBackground
from .audio import VideoAudio
class VideoGenerator(VideoBase):
"""
Video generation generator for WaveSpeed API.
This class composes multiple specialized modules to provide all video operations
while maintaining a single unified interface for backward compatibility.
"""
def __init__(self, api_key: str, base_url: str, polling):
"""Initialize video generator.
Args:
api_key: WaveSpeed API key
base_url: WaveSpeed API base URL
polling: WaveSpeedPolling instance for async operations
"""
super().__init__(api_key, base_url, polling)
# Initialize specialized modules
self._generation = VideoGeneration(api_key, base_url, polling)
self._enhancement = VideoEnhancement(api_key, base_url, polling)
self._extension = VideoExtension(api_key, base_url, polling)
self._face_swap = VideoFaceSwap(api_key, base_url, polling)
self._translation = VideoTranslation(api_key, base_url, polling)
self._background = VideoBackground(api_key, base_url, polling)
self._audio = VideoAudio(api_key, base_url, polling)
# Generation methods (delegated to VideoGeneration)
def submit_image_to_video(
self,
model_path: str,
payload: Dict[str, Any],
timeout: int = 30,
) -> str:
"""Submit an image-to-video generation request."""
return self._generation.submit_image_to_video(model_path, payload, timeout)
def submit_text_to_video(
self,
model_path: str,
payload: Dict[str, Any],
timeout: int = 60,
) -> str:
"""Submit a text-to-video generation request to WaveSpeed."""
return self._generation.submit_text_to_video(model_path, payload, timeout)
def generate_text_video(
self,
prompt: str,
resolution: str = "720p",
duration: int = 5,
audio_base64: Optional[str] = None,
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
enable_prompt_expansion: bool = True,
enable_sync_mode: bool = False,
timeout: int = 180,
) -> Dict[str, Any]:
"""Generate video from text prompt using WAN 2.5 text-to-video."""
return self._generation.generate_text_video(
prompt=prompt,
resolution=resolution,
duration=duration,
audio_base64=audio_base64,
negative_prompt=negative_prompt,
seed=seed,
enable_prompt_expansion=enable_prompt_expansion,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
)
# Enhancement methods (delegated to VideoEnhancement)
def upscale_video(
self,
video: str,
target_resolution: str = "1080p",
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Upscale video using FlashVSR."""
return self._enhancement.upscale_video(
video=video,
target_resolution=target_resolution,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
# Extension methods (delegated to VideoExtension)
def extend_video(
self,
video: str,
prompt: str,
model: str = "wan-2.5",
audio: Optional[str] = None,
negative_prompt: Optional[str] = None,
resolution: str = "720p",
duration: int = 5,
enable_prompt_expansion: bool = False,
generate_audio: bool = True,
camera_fixed: bool = False,
seed: Optional[int] = None,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend."""
return self._extension.extend_video(
video=video,
prompt=prompt,
model=model,
audio=audio,
negative_prompt=negative_prompt,
resolution=resolution,
duration=duration,
enable_prompt_expansion=enable_prompt_expansion,
generate_audio=generate_audio,
camera_fixed=camera_fixed,
seed=seed,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
# Face swap methods (delegated to VideoFaceSwap)
def face_swap(
self,
image: str,
video: str,
prompt: Optional[str] = None,
resolution: str = "480p",
seed: Optional[int] = None,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha)."""
return self._face_swap.face_swap(
image=image,
video=video,
prompt=prompt,
resolution=resolution,
seed=seed,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
def video_face_swap(
self,
video: str,
face_image: str,
target_gender: str = "all",
target_index: int = 0,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap)."""
return self._face_swap.video_face_swap(
video=video,
face_image=face_image,
target_gender=target_gender,
target_index=target_index,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
# Translation methods (delegated to VideoTranslation)
def video_translate(
self,
video: str,
output_language: str = "English",
enable_sync_mode: bool = False,
timeout: int = 600,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Translate video to target language using HeyGen Video Translate."""
return self._translation.video_translate(
video=video,
output_language=output_language,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
# Background methods (delegated to VideoBackground)
def remove_background(
self,
video: str,
background_image: Optional[str] = None,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Remove or replace video background using Video Background Remover."""
return self._background.remove_background(
video=video,
background_image=background_image,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
# Audio methods (delegated to VideoAudio)
def hunyuan_video_foley(
self,
video: str,
prompt: Optional[str] = None,
seed: int = -1,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Generate realistic Foley and ambient audio from video using Hunyuan Video Foley."""
return self._audio.hunyuan_video_foley(
video=video,
prompt=prompt,
seed=seed,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)
def think_sound(
self,
video: str,
prompt: Optional[str] = None,
seed: int = -1,
enable_sync_mode: bool = False,
timeout: int = 300,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""Generate realistic sound effects and audio tracks from video using Think Sound."""
return self._audio.think_sound(
video=video,
prompt=prompt,
seed=seed,
enable_sync_mode=enable_sync_mode,
timeout=timeout,
progress_callback=progress_callback,
)

View File

@@ -0,0 +1,133 @@
"""
Video translation operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
logger = get_service_logger("wavespeed.generators.video.translation")
class VideoTranslation(VideoBase):
"""Video translation operations."""
def video_translate(
self,
video: str, # Base64-encoded video or URL
output_language: str = "English",
enable_sync_mode: bool = False,
timeout: int = 600,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> bytes:
"""
Translate video to target language using HeyGen Video Translate.
Args:
video: Base64-encoded video data URI or public URL (source video)
output_language: Target language for translation (default: "English")
enable_sync_mode: If True, wait for result and return it directly
timeout: Request timeout in seconds (default: 600)
progress_callback: Optional callback function(progress: float, message: str) for progress updates
Returns:
bytes: Translated video bytes
Raises:
HTTPException: If the video translation fails
"""
model_path = "heygen/video-translate"
url = f"{self.base_url}/{model_path}"
# Build payload
payload = {
"video": video,
"output_language": output_language,
}
logger.info(
f"[WaveSpeed] Video translate request via {url} "
f"(output_language={output_language})"
)
# Submit the task
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Video translate submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed video translate submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected video translate response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Video translate submitted: {prediction_id}")
if enable_sync_mode:
# Poll until complete
result = self.polling.poll_until_complete(
prediction_id,
timeout_seconds=timeout,
interval_seconds=2.0,
progress_callback=progress_callback,
)
# Extract video URL from result
outputs = result.get("outputs", [])
if not outputs:
raise HTTPException(
status_code=502,
detail={"error": "Video translate completed but no output video found"},
)
# Handle outputs - can be array of strings or array of objects
video_url = None
if isinstance(outputs[0], str):
video_url = outputs[0]
elif isinstance(outputs[0], dict):
video_url = outputs[0].get("url") or outputs[0].get("video_url")
if not video_url:
raise HTTPException(
status_code=502,
detail={"error": "Video translate output format not recognized"},
)
# Download video
logger.info(f"[WaveSpeed] Downloading translated video from: {video_url}")
video_response = requests.get(video_url, timeout=timeout)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={"error": f"Failed to download translated video: {video_response.status_code}"},
)
video_bytes = video_response.content
logger.info(f"[WaveSpeed] Video translate completed: {len(video_bytes)} bytes")
return video_bytes
else:
# Return prediction ID for async polling
raise HTTPException(
status_code=501,
detail={
"error": "Async mode not yet implemented for video translate",
"prediction_id": prediction_id,
},
)