Files
ALwrity/backend/services/video_studio/avatar_service.py

168 lines
6.6 KiB
Python

"""
Avatar Studio Service
Service for creating talking avatars using InfiniteTalk and Hunyuan Avatar.
Supports both models: the caller chooses one via the `model` argument, and InfiniteTalk is used by default (including for unrecognized model names).
"""
import time
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException
from loguru import logger

from services.image_studio.infinitetalk_adapter import InfiniteTalkService
from services.llm_providers.main_video_generation import _track_video_operation_usage
from services.video_studio.hunyuan_avatar_adapter import HunyuanAvatarService
from utils.logger_utils import get_service_logger
logger = get_service_logger("video_studio.avatar")
class AvatarStudioService:
    """Service for Avatar Studio operations using InfiniteTalk and Hunyuan Avatar.

    Dispatches talking-avatar generation and cost estimation to either the
    InfiniteTalk adapter or the Hunyuan Avatar adapter based on a model name.
    Unrecognized model names fall back to InfiniteTalk.
    """

    # Model identifier routed to the Hunyuan Avatar adapter.
    HUNYUAN_AVATAR_MODEL = "hunyuan-avatar"
    # Model identifier routed to the InfiniteTalk adapter (also the fallback).
    INFINITETALK_MODEL = "infinitetalk"

    def __init__(self):
        """Initialize Avatar Studio service and both model adapters."""
        self.infinitetalk_service = InfiniteTalkService()
        self.hunyuan_avatar_service = HunyuanAvatarService()
        logger.info("[AvatarStudio] Service initialized with InfiniteTalk and Hunyuan Avatar")

    @classmethod
    def _resolve_model(cls, model: Optional[str]) -> str:
        """Return the canonical name of the model that will actually run.

        Matching ignores case and surrounding whitespace. Anything that is not
        Hunyuan Avatar (including ``None`` or an empty string) resolves to
        InfiniteTalk, preserving the historical default-to-InfiniteTalk routing.
        """
        normalized = (model or "").strip().lower()
        if normalized == cls.HUNYUAN_AVATAR_MODEL:
            return cls.HUNYUAN_AVATAR_MODEL
        return cls.INFINITETALK_MODEL

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Validate that ``user_id`` may generate video, before any paid API call.

        Raises:
            HTTPException: re-raised from ``validate_video_generation_operations``
                when validation fails, so the frontend gets an immediate response.
        """
        # Imported lazily as in the original code; presumably to avoid import
        # cycles or DB setup at module import time — TODO confirm.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_video_generation_operations

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            validate_video_generation_operations(
                pricing_service=pricing_service,
                user_id=user_id,
            )
        except HTTPException:
            # Re-raise immediately - don't proceed with API call
            logger.error("[AvatarStudio] ❌ Pre-flight validation failed - blocking API call")
            raise
        finally:
            db.close()

    async def create_talking_avatar(
        self,
        image_base64: str,
        audio_base64: str,
        resolution: str = "720p",
        prompt: Optional[str] = None,
        mask_image_base64: Optional[str] = None,
        seed: Optional[int] = None,
        user_id: str = "video_studio",
        model: str = "infinitetalk",
        progress_callback: Optional[Callable[..., Any]] = None,
    ) -> Dict[str, Any]:
        """
        Create talking avatar video using InfiniteTalk or Hunyuan Avatar.

        Args:
            image_base64: Person image in base64 or data URI
            audio_base64: Audio file in base64 or data URI
            resolution: Output resolution (480p or 720p)
            prompt: Optional prompt for expression/style
            mask_image_base64: Optional mask for animatable regions (InfiniteTalk only;
                ignored with a warning for Hunyuan Avatar)
            seed: Optional random seed
            user_id: User ID for tracking
            model: Model to use - "infinitetalk" (default) or "hunyuan-avatar"
                (case-insensitive; unrecognized names fall back to InfiniteTalk)
            progress_callback: Optional progress callback function (currently only
                forwarded to Hunyuan Avatar)

        Returns:
            Dictionary with video_bytes, metadata, cost, and file info

        Raises:
            HTTPException: from pre-flight validation or the adapters, or a 500
                wrapping any other error raised during generation.
        """
        effective_model = self._resolve_model(model)
        logger.info(
            f"[AvatarStudio] Creating talking avatar: user={user_id}, resolution={resolution}, model={effective_model}"
        )
        if effective_model != model:
            # Logged so a typo'd/unknown model name doesn't silently run (and bill) another model.
            logger.warning(
                f"[AvatarStudio] Requested model '{model}' resolved to '{effective_model}'"
            )

        # PRE-FLIGHT VALIDATION: MUST happen BEFORE any API call; raises on failure.
        self._run_preflight_validation(user_id)

        start_time = time.time()
        try:
            if effective_model == self.HUNYUAN_AVATAR_MODEL:
                if mask_image_base64:
                    logger.warning(
                        "[AvatarStudio] mask_image_base64 is not supported by Hunyuan Avatar and will be ignored"
                    )
                result = await self.hunyuan_avatar_service.create_talking_avatar(
                    image_base64=image_base64,
                    audio_base64=audio_base64,
                    resolution=resolution,
                    prompt=prompt,
                    seed=seed,
                    user_id=user_id,
                    progress_callback=progress_callback,
                )
            else:
                # NOTE(review): progress_callback is not forwarded here, matching the
                # original call; confirm whether the InfiniteTalk adapter accepts one.
                result = await self.infinitetalk_service.create_talking_avatar(
                    image_base64=image_base64,
                    audio_base64=audio_base64,
                    resolution=resolution,
                    prompt=prompt,
                    mask_image_base64=mask_image_base64,
                    seed=seed,
                    user_id=user_id,
                )
            response_time = time.time() - start_time

            # `or` guards against adapters returning None, which would make the
            # `:.2f` format raise after the (already paid) generation succeeded.
            logged_duration = result.get("duration") or 0
            logged_cost = result.get("cost") or 0
            logger.info(
                f"[AvatarStudio] ✅ Talking avatar created: "
                f"model={effective_model}, resolution={resolution}, duration={logged_duration}s, "
                f"cost=${logged_cost:.2f}"
            )

            # TRACK USAGE after successful API call, attributed to the model that actually ran.
            video_bytes = result.get("video_bytes")
            if user_id and video_bytes:
                _track_video_operation_usage(
                    user_id=user_id,
                    provider=effective_model,  # Model name doubles as provider for now
                    model=effective_model,
                    operation_type="talking-avatar",
                    result_bytes=video_bytes,
                    cost=result.get("cost", 0.0),
                    prompt=prompt,
                    endpoint="/avatar-generation",
                    metadata=result,
                    log_prefix="[Avatar Generation]",
                    response_time=response_time,
                )
            return result
        except HTTPException:
            raise
        except Exception as e:
            # logger.exception records the traceback on both stdlib logging and loguru
            # without passing kwargs that loguru would feed into str.format().
            logger.exception(f"[AvatarStudio] ❌ Error creating talking avatar: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to create talking avatar: {str(e)}",
            ) from e

    def calculate_cost_estimate(
        self,
        resolution: str,
        estimated_duration: float,
        model: str = "infinitetalk",
    ) -> float:
        """
        Calculate estimated cost for talking avatar generation.

        Uses the same model resolution as ``create_talking_avatar`` so the
        estimate matches the adapter that would actually run.

        Args:
            resolution: Output resolution (480p or 720p)
            estimated_duration: Estimated video duration in seconds
            model: Model to use - "infinitetalk" (default) or "hunyuan-avatar"

        Returns:
            Estimated cost in USD
        """
        if self._resolve_model(model) == self.HUNYUAN_AVATAR_MODEL:
            return self.hunyuan_avatar_service.calculate_cost(resolution, estimated_duration)
        return self.infinitetalk_service.calculate_cost(resolution, estimated_duration)