AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -110,6 +110,11 @@ class ContentAssetService:
|
||||
search_query: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
favorites_only: bool = False,
|
||||
collection_id: Optional[int] = None,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> Tuple[List[ContentAsset], int]:
|
||||
@@ -157,11 +162,37 @@ class ContentAssetService:
|
||||
tag_filters = [ContentAsset.tags.contains([tag]) for tag in tags]
|
||||
query = query.filter(or_(*tag_filters))
|
||||
|
||||
if collection_id:
|
||||
query = query.filter(ContentAsset.collection_id == collection_id)
|
||||
|
||||
if date_from:
|
||||
query = query.filter(ContentAsset.created_at >= date_from)
|
||||
|
||||
if date_to:
|
||||
query = query.filter(ContentAsset.created_at <= date_to)
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(desc(ContentAsset.created_at))
|
||||
# Apply ordering
|
||||
order_column = ContentAsset.created_at
|
||||
if sort_by == "created_at":
|
||||
order_column = ContentAsset.created_at
|
||||
elif sort_by == "updated_at":
|
||||
order_column = ContentAsset.updated_at
|
||||
elif sort_by == "cost":
|
||||
order_column = ContentAsset.cost
|
||||
elif sort_by == "file_size":
|
||||
order_column = ContentAsset.file_size
|
||||
elif sort_by == "title":
|
||||
order_column = ContentAsset.title
|
||||
|
||||
if sort_order.lower() == "asc":
|
||||
query = query.order_by(order_column)
|
||||
else:
|
||||
query = query.order_by(desc(order_column))
|
||||
|
||||
# Apply pagination
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
return query.all(), total_count
|
||||
@@ -319,4 +350,231 @@ class ContentAssetService:
|
||||
"total_cost": 0.0,
|
||||
"favorites_count": 0,
|
||||
}
|
||||
|
||||
# ==================== Collection Management ====================
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
is_public: bool = False,
|
||||
) -> AssetCollection:
|
||||
"""Create a new asset collection."""
|
||||
try:
|
||||
collection = AssetCollection(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
self.db.add(collection)
|
||||
self.db.commit()
|
||||
self.db.refresh(collection)
|
||||
|
||||
logger.info(f"Created collection {collection.id} '{name}' for user {user_id}")
|
||||
return collection
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error creating collection: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_user_collections(
|
||||
self,
|
||||
user_id: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> Tuple[List[AssetCollection], int]:
|
||||
"""Get all collections for a user."""
|
||||
try:
|
||||
query = self.db.query(AssetCollection).filter(
|
||||
AssetCollection.user_id == user_id
|
||||
)
|
||||
|
||||
total_count = query.count()
|
||||
query = query.order_by(desc(AssetCollection.created_at))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
return query.all(), total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching collections: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_collection_by_id(self, collection_id: int, user_id: str) -> Optional[AssetCollection]:
|
||||
"""Get a specific collection by ID."""
|
||||
try:
|
||||
return self.db.query(AssetCollection).filter(
|
||||
and_(
|
||||
AssetCollection.id == collection_id,
|
||||
AssetCollection.user_id == user_id
|
||||
)
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching collection {collection_id}: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def update_collection(
|
||||
self,
|
||||
collection_id: int,
|
||||
user_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
is_public: Optional[bool] = None,
|
||||
cover_asset_id: Optional[int] = None,
|
||||
) -> Optional[AssetCollection]:
|
||||
"""Update collection metadata."""
|
||||
try:
|
||||
collection = self.get_collection_by_id(collection_id, user_id)
|
||||
if not collection:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
collection.name = name
|
||||
if description is not None:
|
||||
collection.description = description
|
||||
if is_public is not None:
|
||||
collection.is_public = is_public
|
||||
if cover_asset_id is not None:
|
||||
# Verify asset belongs to user
|
||||
asset = self.get_asset_by_id(cover_asset_id, user_id)
|
||||
if asset:
|
||||
collection.cover_asset_id = cover_asset_id
|
||||
else:
|
||||
collection.cover_asset_id = None
|
||||
|
||||
collection.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(collection)
|
||||
|
||||
logger.info(f"Updated collection {collection_id} for user {user_id}")
|
||||
return collection
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error updating collection: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def delete_collection(self, collection_id: int, user_id: str) -> bool:
|
||||
"""Delete a collection (assets are not deleted, just removed from collection)."""
|
||||
try:
|
||||
collection = self.get_collection_by_id(collection_id, user_id)
|
||||
if not collection:
|
||||
return False
|
||||
|
||||
# Remove assets from collection before deleting
|
||||
self.db.query(ContentAsset).filter(
|
||||
ContentAsset.collection_id == collection_id
|
||||
).update({ContentAsset.collection_id: None})
|
||||
|
||||
self.db.delete(collection)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Deleted collection {collection_id} for user {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error deleting collection: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def add_assets_to_collection(
|
||||
self,
|
||||
collection_id: int,
|
||||
user_id: str,
|
||||
asset_ids: List[int],
|
||||
) -> int:
|
||||
"""Add assets to a collection. Returns number of assets added."""
|
||||
try:
|
||||
collection = self.get_collection_by_id(collection_id, user_id)
|
||||
if not collection:
|
||||
return 0
|
||||
|
||||
# Verify all assets belong to user
|
||||
assets = self.db.query(ContentAsset).filter(
|
||||
and_(
|
||||
ContentAsset.id.in_(asset_ids),
|
||||
ContentAsset.user_id == user_id
|
||||
)
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for asset in assets:
|
||||
asset.collection_id = collection_id
|
||||
count += 1
|
||||
|
||||
collection.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Added {count} assets to collection {collection_id}")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error adding assets to collection: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def remove_assets_from_collection(
|
||||
self,
|
||||
collection_id: int,
|
||||
user_id: str,
|
||||
asset_ids: List[int],
|
||||
) -> int:
|
||||
"""Remove assets from a collection. Returns number of assets removed."""
|
||||
try:
|
||||
collection = self.get_collection_by_id(collection_id, user_id)
|
||||
if not collection:
|
||||
return 0
|
||||
|
||||
# Remove assets from collection
|
||||
count = self.db.query(ContentAsset).filter(
|
||||
and_(
|
||||
ContentAsset.id.in_(asset_ids),
|
||||
ContentAsset.collection_id == collection_id,
|
||||
ContentAsset.user_id == user_id
|
||||
)
|
||||
).update({ContentAsset.collection_id: None})
|
||||
|
||||
collection.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Removed {count} assets from collection {collection_id}")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error removing assets from collection: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def get_collection_assets(
|
||||
self,
|
||||
collection_id: int,
|
||||
user_id: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> Tuple[List[ContentAsset], int]:
|
||||
"""Get all assets in a collection."""
|
||||
try:
|
||||
collection = self.get_collection_by_id(collection_id, user_id)
|
||||
if not collection:
|
||||
return [], 0
|
||||
|
||||
query = self.db.query(ContentAsset).filter(
|
||||
and_(
|
||||
ContentAsset.collection_id == collection_id,
|
||||
ContentAsset.user_id == user_id
|
||||
)
|
||||
)
|
||||
|
||||
total_count = query.count()
|
||||
query = query.order_by(desc(ContentAsset.created_at))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
return query.all(), total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching collection assets: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .compression_service import ImageCompressionService, CompressionRequest, CompressionResult
|
||||
from .format_converter_service import ImageFormatConverterService, FormatConversionRequest, FormatConversionResult
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
@@ -25,6 +27,12 @@ __all__ = [
|
||||
"ControlStudioRequest",
|
||||
"SocialOptimizerService",
|
||||
"SocialOptimizerRequest",
|
||||
"ImageCompressionService",
|
||||
"CompressionRequest",
|
||||
"CompressionResult",
|
||||
"ImageFormatConverterService",
|
||||
"FormatConversionRequest",
|
||||
"FormatConversionResult",
|
||||
"TransformStudioService",
|
||||
"TransformImageToVideoRequest",
|
||||
"TalkingAvatarRequest",
|
||||
|
||||
367
backend/services/image_studio/compression_service.py
Normal file
367
backend/services/image_studio/compression_service.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""Image Compression Service for optimizing image file sizes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Literal
|
||||
|
||||
from PIL import Image, ExifTags
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.compression")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionRequest:
|
||||
"""Request model for image compression."""
|
||||
image_base64: str
|
||||
quality: int = 85 # 1-100, where 100 is best quality
|
||||
format: str = "jpeg" # jpeg, png, webp, avif
|
||||
target_size_kb: Optional[int] = None # Target file size in KB
|
||||
strip_metadata: bool = True
|
||||
progressive: bool = True # Progressive JPEG
|
||||
optimize: bool = True # Optimize encoding
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionResult:
|
||||
"""Result of compression operation."""
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_size_kb: float
|
||||
compressed_size_kb: float
|
||||
compression_ratio: float
|
||||
format: str
|
||||
width: int
|
||||
height: int
|
||||
quality_used: int
|
||||
metadata_stripped: bool
|
||||
|
||||
|
||||
class ImageCompressionService:
|
||||
"""Service for image compression and optimization."""
|
||||
|
||||
SUPPORTED_FORMATS = ["jpeg", "jpg", "png", "webp"]
|
||||
|
||||
# Format-specific options
|
||||
FORMAT_OPTIONS = {
|
||||
"jpeg": {"quality": (1, 100), "progressive": True, "optimize": True},
|
||||
"jpg": {"quality": (1, 100), "progressive": True, "optimize": True},
|
||||
"png": {"compress_level": (0, 9), "optimize": True},
|
||||
"webp": {"quality": (1, 100), "lossless": False},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Compression] ImageCompressionService initialized")
|
||||
|
||||
def _decode_image(self, image_base64: str) -> tuple[Image.Image, int]:
|
||||
"""Decode base64 image and return PIL Image and original size."""
|
||||
# Handle data URL format
|
||||
if "," in image_base64:
|
||||
image_base64 = image_base64.split(",", 1)[1]
|
||||
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
original_size = len(image_bytes)
|
||||
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
return image, original_size
|
||||
|
||||
def _strip_exif(self, image: Image.Image) -> Image.Image:
|
||||
"""Remove EXIF metadata from image."""
|
||||
# Create a new image without EXIF data
|
||||
data = list(image.getdata())
|
||||
image_without_exif = Image.new(image.mode, image.size)
|
||||
image_without_exif.putdata(data)
|
||||
return image_without_exif
|
||||
|
||||
def _compress_to_target_size(
|
||||
self,
|
||||
image: Image.Image,
|
||||
target_size_kb: int,
|
||||
format: str,
|
||||
min_quality: int = 10,
|
||||
max_quality: int = 95,
|
||||
) -> tuple[bytes, int]:
|
||||
"""Compress image to target file size using binary search."""
|
||||
target_bytes = target_size_kb * 1024
|
||||
|
||||
low, high = min_quality, max_quality
|
||||
best_result = None
|
||||
best_quality = max_quality
|
||||
|
||||
while low <= high:
|
||||
mid = (low + high) // 2
|
||||
compressed = self._compress_image(image, format, mid, True, True)
|
||||
|
||||
if len(compressed) <= target_bytes:
|
||||
best_result = compressed
|
||||
best_quality = mid
|
||||
low = mid + 1 # Try higher quality
|
||||
else:
|
||||
high = mid - 1 # Try lower quality
|
||||
|
||||
if best_result is None:
|
||||
# Even minimum quality exceeds target, return min quality result
|
||||
best_result = self._compress_image(image, format, min_quality, True, True)
|
||||
best_quality = min_quality
|
||||
|
||||
return best_result, best_quality
|
||||
|
||||
def _compress_image(
|
||||
self,
|
||||
image: Image.Image,
|
||||
format: str,
|
||||
quality: int,
|
||||
progressive: bool,
|
||||
optimize: bool,
|
||||
) -> bytes:
|
||||
"""Compress image with given settings."""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
# Handle format-specific options
|
||||
save_kwargs: Dict[str, Any] = {}
|
||||
|
||||
format_lower = format.lower()
|
||||
if format_lower in ["jpeg", "jpg"]:
|
||||
# Convert to RGB if necessary (JPEG doesn't support alpha)
|
||||
if image.mode in ("RGBA", "P"):
|
||||
image = image.convert("RGB")
|
||||
save_kwargs["format"] = "JPEG"
|
||||
save_kwargs["quality"] = quality
|
||||
save_kwargs["optimize"] = optimize
|
||||
if progressive:
|
||||
save_kwargs["progressive"] = True
|
||||
elif format_lower == "png":
|
||||
save_kwargs["format"] = "PNG"
|
||||
save_kwargs["optimize"] = optimize
|
||||
# PNG uses compress_level (0-9) instead of quality
|
||||
compress_level = max(0, min(9, (100 - quality) // 11))
|
||||
save_kwargs["compress_level"] = compress_level
|
||||
elif format_lower == "webp":
|
||||
save_kwargs["format"] = "WEBP"
|
||||
save_kwargs["quality"] = quality
|
||||
save_kwargs["method"] = 6 # Best compression
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
image.save(buffer, **save_kwargs)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def compress(
|
||||
self,
|
||||
request: CompressionRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""Compress an image with specified settings."""
|
||||
logger.info(f"[Compression] Processing compression request for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Decode image
|
||||
image, original_size = self._decode_image(request.image_base64)
|
||||
original_size_kb = original_size / 1024
|
||||
|
||||
logger.info(f"[Compression] Original size: {original_size_kb:.2f} KB, dimensions: {image.size}")
|
||||
|
||||
# Strip metadata if requested
|
||||
if request.strip_metadata:
|
||||
image = self._strip_exif(image)
|
||||
|
||||
# Validate format
|
||||
format_lower = request.format.lower()
|
||||
if format_lower not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(f"Unsupported format: {request.format}. Supported: {self.SUPPORTED_FORMATS}")
|
||||
|
||||
# Compress to target size or with quality setting
|
||||
if request.target_size_kb:
|
||||
compressed_bytes, quality_used = self._compress_to_target_size(
|
||||
image,
|
||||
request.target_size_kb,
|
||||
format_lower,
|
||||
)
|
||||
else:
|
||||
compressed_bytes = self._compress_image(
|
||||
image,
|
||||
format_lower,
|
||||
request.quality,
|
||||
request.progressive,
|
||||
request.optimize,
|
||||
)
|
||||
quality_used = request.quality
|
||||
|
||||
compressed_size_kb = len(compressed_bytes) / 1024
|
||||
compression_ratio = (1 - compressed_size_kb / original_size_kb) * 100 if original_size_kb > 0 else 0
|
||||
|
||||
# Encode result
|
||||
mime_type = "image/jpeg" if format_lower in ["jpeg", "jpg"] else f"image/{format_lower}"
|
||||
result_base64 = f"data:{mime_type};base64,{base64.b64encode(compressed_bytes).decode()}"
|
||||
|
||||
logger.info(f"[Compression] Compressed: {original_size_kb:.2f}KB → {compressed_size_kb:.2f}KB ({compression_ratio:.1f}% reduction)")
|
||||
|
||||
return CompressionResult(
|
||||
success=True,
|
||||
image_base64=result_base64,
|
||||
original_size_kb=round(original_size_kb, 2),
|
||||
compressed_size_kb=round(compressed_size_kb, 2),
|
||||
compression_ratio=round(compression_ratio, 2),
|
||||
format=format_lower,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
quality_used=quality_used,
|
||||
metadata_stripped=request.strip_metadata,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] Failed to compress image: {e}")
|
||||
raise
|
||||
|
||||
async def compress_batch(
|
||||
self,
|
||||
requests: List[CompressionRequest],
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[CompressionResult]:
|
||||
"""Compress multiple images with same or individual settings."""
|
||||
logger.info(f"[Compression] Processing batch of {len(requests)} images for user: {user_id}")
|
||||
|
||||
results = []
|
||||
for i, request in enumerate(requests):
|
||||
try:
|
||||
result = await self.compress(request, user_id)
|
||||
results.append(result)
|
||||
logger.info(f"[Compression] Batch item {i+1}/{len(requests)} complete")
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] Batch item {i+1} failed: {e}")
|
||||
# Return partial success
|
||||
results.append(CompressionResult(
|
||||
success=False,
|
||||
image_base64="",
|
||||
original_size_kb=0,
|
||||
compressed_size_kb=0,
|
||||
compression_ratio=0,
|
||||
format="",
|
||||
width=0,
|
||||
height=0,
|
||||
quality_used=0,
|
||||
metadata_stripped=False,
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
async def estimate_compression(
|
||||
self,
|
||||
image_base64: str,
|
||||
format: str = "jpeg",
|
||||
quality: int = 85,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate compression results without actually compressing."""
|
||||
try:
|
||||
image, original_size = self._decode_image(image_base64)
|
||||
original_size_kb = original_size / 1024
|
||||
|
||||
# Quick estimation based on format and quality
|
||||
if format.lower() in ["jpeg", "jpg"]:
|
||||
# JPEG compression ratio estimate
|
||||
estimated_ratio = 0.1 + (quality / 100) * 0.4 # 10-50% of original
|
||||
elif format.lower() == "webp":
|
||||
# WebP is typically 25-34% smaller than JPEG
|
||||
estimated_ratio = 0.08 + (quality / 100) * 0.35
|
||||
else: # PNG
|
||||
estimated_ratio = 0.7 + (quality / 100) * 0.2 # PNG is less compressible
|
||||
|
||||
estimated_size_kb = original_size_kb * estimated_ratio
|
||||
|
||||
return {
|
||||
"original_size_kb": round(original_size_kb, 2),
|
||||
"estimated_size_kb": round(estimated_size_kb, 2),
|
||||
"estimated_reduction_percent": round((1 - estimated_ratio) * 100, 1),
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"format": format.lower(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] Estimation failed: {e}")
|
||||
raise
|
||||
|
||||
def get_supported_formats(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of supported compression formats with details."""
|
||||
return [
|
||||
{
|
||||
"id": "jpeg",
|
||||
"name": "JPEG",
|
||||
"extension": ".jpg",
|
||||
"description": "Best for photos. Lossy compression with excellent size reduction.",
|
||||
"supports_transparency": False,
|
||||
"quality_range": [1, 100],
|
||||
"recommended_quality": 85,
|
||||
"use_cases": ["Photos", "Blog images", "Email", "Social media"],
|
||||
},
|
||||
{
|
||||
"id": "png",
|
||||
"name": "PNG",
|
||||
"extension": ".png",
|
||||
"description": "Best for graphics with transparency. Lossless compression.",
|
||||
"supports_transparency": True,
|
||||
"quality_range": [1, 100],
|
||||
"recommended_quality": 90,
|
||||
"use_cases": ["Logos", "Icons", "Graphics", "Screenshots"],
|
||||
},
|
||||
{
|
||||
"id": "webp",
|
||||
"name": "WebP",
|
||||
"extension": ".webp",
|
||||
"description": "Modern format with excellent compression. 25-34% smaller than JPEG.",
|
||||
"supports_transparency": True,
|
||||
"quality_range": [1, 100],
|
||||
"recommended_quality": 80,
|
||||
"use_cases": ["Web images", "Fast loading", "Modern browsers"],
|
||||
},
|
||||
]
|
||||
|
||||
def get_presets(self) -> List[Dict[str, Any]]:
|
||||
"""Get compression presets for common use cases."""
|
||||
return [
|
||||
{
|
||||
"id": "web",
|
||||
"name": "Web Optimized",
|
||||
"description": "Balanced quality and size for web pages",
|
||||
"format": "webp",
|
||||
"quality": 80,
|
||||
"strip_metadata": True,
|
||||
},
|
||||
{
|
||||
"id": "email",
|
||||
"name": "Email Friendly",
|
||||
"description": "Small file size for email attachments (<200KB target)",
|
||||
"format": "jpeg",
|
||||
"quality": 70,
|
||||
"target_size_kb": 200,
|
||||
"strip_metadata": True,
|
||||
},
|
||||
{
|
||||
"id": "social",
|
||||
"name": "Social Media",
|
||||
"description": "Optimized for social platforms",
|
||||
"format": "jpeg",
|
||||
"quality": 85,
|
||||
"strip_metadata": True,
|
||||
},
|
||||
{
|
||||
"id": "high_quality",
|
||||
"name": "High Quality",
|
||||
"description": "Minimal compression for quality-critical images",
|
||||
"format": "png",
|
||||
"quality": 95,
|
||||
"strip_metadata": False,
|
||||
},
|
||||
{
|
||||
"id": "maximum",
|
||||
"name": "Maximum Compression",
|
||||
"description": "Smallest possible file size",
|
||||
"format": "webp",
|
||||
"quality": 60,
|
||||
"strip_metadata": True,
|
||||
},
|
||||
]
|
||||
@@ -1,17 +1,10 @@
|
||||
"""Create Studio service for AI-powered image generation."""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List, Literal
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm_providers.image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.image_generation import ImageGenerationResult
|
||||
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -75,29 +68,8 @@ class CreateStudioService:
|
||||
self.template_manager = TemplateManager()
|
||||
logger.info("[Create Studio] Initialized with template manager")
|
||||
|
||||
def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get provider instance by name.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider
|
||||
api_key: Optional API key (uses env vars if not provided)
|
||||
|
||||
Returns:
|
||||
Provider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
|
||||
elif provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
|
||||
elif provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
|
||||
elif provider_name == "gemini":
|
||||
return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider_name}")
|
||||
# Removed _get_provider_instance() - now using unified entry point
|
||||
# Provider selection is handled by main_image_generation.generate_image()
|
||||
|
||||
def _select_provider_and_model(
|
||||
self,
|
||||
@@ -289,30 +261,17 @@ class CreateStudioService:
|
||||
logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
|
||||
request.prompt[:100], request.template_id)
|
||||
|
||||
# Pre-flight validation: Check subscription and usage limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=request.num_variations
|
||||
)
|
||||
logger.info(f"[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
|
||||
except HTTPException as http_ex:
|
||||
logger.error(f"[Create Studio] ❌ Pre-flight validation failed - blocking generation")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
# Pre-flight validation: Reuse unified helper
|
||||
# Note: Validation for num_variations will be done per-image in generate_image()
|
||||
# We validate once upfront to fail fast if user has no credits
|
||||
if user_id and request.num_variations > 0:
|
||||
from services.llm_providers.main_image_generation import _validate_image_operation
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="create-studio-generation",
|
||||
num_operations=request.num_variations,
|
||||
log_prefix="[Create Studio]"
|
||||
)
|
||||
|
||||
# Load template if specified
|
||||
template = None
|
||||
@@ -337,36 +296,37 @@ class CreateStudioService:
|
||||
# Select provider and model
|
||||
provider_name, model = self._select_provider_and_model(request, template)
|
||||
|
||||
# Get provider instance
|
||||
try:
|
||||
provider = self._get_provider_instance(provider_name)
|
||||
except Exception as e:
|
||||
logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
|
||||
provider_name, str(e))
|
||||
raise RuntimeError(f"Provider initialization failed: {str(e)}")
|
||||
|
||||
# Generate images
|
||||
# Generate images using unified entry point
|
||||
# This ensures consistent validation, tracking, and error handling
|
||||
results = []
|
||||
for i in range(request.num_variations):
|
||||
logger.info("[Create Studio] Generating variation %d/%d",
|
||||
i + 1, request.num_variations)
|
||||
|
||||
try:
|
||||
# Prepare options
|
||||
options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed + i if request.seed else None,
|
||||
model=model,
|
||||
extra={"style_preset": request.style_preset} if request.style_preset else {}
|
||||
)
|
||||
# Prepare options for unified entry point
|
||||
options = {
|
||||
"provider": provider_name,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed + i if request.seed else None,
|
||||
}
|
||||
|
||||
# Generate image
|
||||
result: ImageGenerationResult = provider.generate(options)
|
||||
# Add style preset to extra if specified
|
||||
if request.style_preset:
|
||||
options["extra"] = {"style_preset": request.style_preset}
|
||||
|
||||
# Generate image using unified entry point
|
||||
# This handles validation, provider selection, generation, and tracking automatically
|
||||
result: ImageGenerationResult = generate_image(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
results.append({
|
||||
"image_bytes": result.image_bytes,
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any, Dict, Literal, Optional
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
|
||||
from services.llm_providers.main_image_generation import generate_image_edit
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -213,6 +214,249 @@ class EditStudioService:
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Expose supported operations for UI rendering."""
|
||||
return self.SUPPORTED_OPERATIONS
|
||||
|
||||
def get_available_models(
|
||||
self,
|
||||
operation: Optional[str] = None,
|
||||
tier: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get available WaveSpeed editing models.
|
||||
|
||||
Args:
|
||||
operation: Filter by operation type (e.g., "general_edit")
|
||||
tier: Filter by tier ("budget", "mid", "premium")
|
||||
|
||||
Returns:
|
||||
Dictionary with models and metadata
|
||||
"""
|
||||
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
|
||||
provider = WaveSpeedEditProvider()
|
||||
all_models = provider.get_available_models()
|
||||
|
||||
# Filter by operation if specified
|
||||
if operation:
|
||||
filtered = provider.get_models_by_operation(operation)
|
||||
all_models = {k: v for k, v in all_models.items() if k in filtered}
|
||||
|
||||
# Filter by tier if specified
|
||||
if tier:
|
||||
filtered = provider.get_models_by_tier(tier)
|
||||
all_models = {k: v for k, v in all_models.items() if k in filtered}
|
||||
|
||||
# Format for API response
|
||||
models_list = []
|
||||
for model_id, model_info in all_models.items():
|
||||
models_list.append({
|
||||
"id": model_id,
|
||||
"name": model_info.get("name", model_id),
|
||||
"description": model_info.get("description", ""),
|
||||
"cost": model_info.get("cost", 0.02),
|
||||
"cost_8k": model_info.get("cost_8k"), # Optional
|
||||
"tier": model_info.get("tier", "mid"),
|
||||
"max_resolution": model_info.get("max_resolution", [2048, 2048]),
|
||||
"capabilities": model_info.get("capabilities", []),
|
||||
"use_cases": self._get_use_cases_for_model(model_id, model_info),
|
||||
"features": self._get_features_for_model(model_info),
|
||||
"supports_multi_image": model_info.get("supports_multi_image", False),
|
||||
"supports_controlnet": model_info.get("supports_controlnet", False),
|
||||
"languages": model_info.get("languages", ["en"]),
|
||||
})
|
||||
|
||||
return {
|
||||
"models": models_list,
|
||||
"total": len(models_list),
|
||||
}
|
||||
|
||||
def recommend_model(
|
||||
self,
|
||||
operation: str,
|
||||
image_resolution: Optional[Dict[str, int]] = None,
|
||||
user_tier: Optional[str] = None,
|
||||
preferences: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Recommend best model for given operation and context.
|
||||
|
||||
Args:
|
||||
operation: Operation type (e.g., "general_edit")
|
||||
image_resolution: Dict with "width" and "height"
|
||||
user_tier: User subscription tier ("free", "pro", "enterprise")
|
||||
preferences: Dict with "prioritize_cost" or "prioritize_quality"
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended model and alternatives
|
||||
"""
|
||||
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
|
||||
provider = WaveSpeedEditProvider()
|
||||
available_models = provider.get_models_by_operation(operation)
|
||||
|
||||
if not available_models:
|
||||
# Fallback to all models if operation doesn't match
|
||||
available_models = provider.get_available_models()
|
||||
|
||||
# Filter by resolution if provided
|
||||
if image_resolution:
|
||||
width = image_resolution.get("width", 0)
|
||||
height = image_resolution.get("height", 0)
|
||||
max_dimension = max(width, height)
|
||||
|
||||
# Filter models that support this resolution
|
||||
filtered = {}
|
||||
for model_id, model_info in available_models.items():
|
||||
max_res = model_info.get("max_resolution", (2048, 2048))
|
||||
max_supported = max(max_res[0], max_res[1])
|
||||
if max_dimension <= max_supported:
|
||||
filtered[model_id] = model_info
|
||||
available_models = filtered
|
||||
|
||||
if not available_models:
|
||||
# No models match, return first available
|
||||
all_models = provider.get_available_models()
|
||||
if all_models:
|
||||
first_model_id = list(all_models.keys())[0]
|
||||
return {
|
||||
"recommended_model": first_model_id,
|
||||
"reason": "No specific match found, using default model",
|
||||
"alternatives": [],
|
||||
}
|
||||
else:
|
||||
raise ValueError("No models available")
|
||||
|
||||
# Apply preferences
|
||||
prioritize_cost = preferences and preferences.get("prioritize_cost", False)
|
||||
prioritize_quality = preferences and preferences.get("prioritize_quality", False)
|
||||
|
||||
# Score models
|
||||
scored_models = []
|
||||
for model_id, model_info in available_models.items():
|
||||
score = 0
|
||||
cost = model_info.get("cost", 0.02)
|
||||
tier = model_info.get("tier", "mid")
|
||||
max_res = model_info.get("max_resolution", (2048, 2048))
|
||||
max_resolution = max(max_res[0], max_res[1])
|
||||
|
||||
# Cost scoring (lower is better)
|
||||
if prioritize_cost:
|
||||
score += (1.0 / cost) * 100 # Invert cost for scoring
|
||||
else:
|
||||
score += (1.0 / cost) * 50 # Less weight if not prioritizing
|
||||
|
||||
# Quality scoring (higher resolution = better)
|
||||
if prioritize_quality:
|
||||
score += max_resolution / 10 # Higher weight for quality
|
||||
else:
|
||||
score += max_resolution / 20 # Lower weight
|
||||
|
||||
# Tier preference based on user tier
|
||||
if user_tier == "free":
|
||||
if tier == "budget":
|
||||
score += 50
|
||||
elif tier == "mid":
|
||||
score += 20
|
||||
elif user_tier in ["pro", "enterprise"]:
|
||||
if tier == "premium":
|
||||
score += 50
|
||||
elif tier == "mid":
|
||||
score += 30
|
||||
|
||||
scored_models.append((model_id, model_info, score))
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_models.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Get recommended model
|
||||
recommended_id, recommended_info, recommended_score = scored_models[0]
|
||||
|
||||
# Build reason
|
||||
reasons = []
|
||||
if prioritize_cost:
|
||||
reasons.append("Lowest cost option")
|
||||
if prioritize_quality:
|
||||
reasons.append("Best quality")
|
||||
if image_resolution:
|
||||
reasons.append(f"Supports {image_resolution.get('width')}×{image_resolution.get('height')} resolution")
|
||||
if user_tier == "free" and recommended_info.get("tier") == "budget":
|
||||
reasons.append("Budget-friendly for free tier")
|
||||
|
||||
reason = ", ".join(reasons) if reasons else "Best match for your requirements"
|
||||
|
||||
# Get alternatives (top 2-3)
|
||||
alternatives = []
|
||||
for model_id, model_info, score in scored_models[1:4]:
|
||||
alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
|
||||
if model_info.get("cost", 0) < recommended_info.get("cost", 0):
|
||||
alt_reason += ", lower cost"
|
||||
elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
|
||||
alt_reason += ", higher quality"
|
||||
alternatives.append({
|
||||
"model_id": model_id,
|
||||
"name": model_info.get("name", model_id),
|
||||
"cost": model_info.get("cost", 0.02),
|
||||
"reason": alt_reason,
|
||||
})
|
||||
|
||||
return {
|
||||
"recommended_model": recommended_id,
|
||||
"reason": reason,
|
||||
"alternatives": alternatives,
|
||||
}
|
||||
|
||||
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
|
||||
"""Get use cases for a model based on its capabilities."""
|
||||
use_cases_map = {
|
||||
"general_edit": ["Quick edits", "Style changes", "Background replacement"],
|
||||
"style_transfer": ["Apply artistic styles", "Style transformations"],
|
||||
"text_edit": ["Add text to images", "Edit text in images"],
|
||||
"multi_image": ["Batch editing", "Consistent character work"],
|
||||
"high_res": ["Professional work", "Print materials", "4K/8K editing"],
|
||||
"professional": ["Marketing campaigns", "Brand assets"],
|
||||
"typography": ["Text-heavy edits", "Typography generation"],
|
||||
"portrait_retouching": ["Portrait edits", "Beauty retouching"],
|
||||
"fashion_edit": ["Fashion photography", "Outfit changes"],
|
||||
"product_edit": ["E-commerce", "Product photography"],
|
||||
}
|
||||
|
||||
capabilities = model_info.get("capabilities", [])
|
||||
use_cases = []
|
||||
for cap in capabilities:
|
||||
if cap in use_cases_map:
|
||||
use_cases.extend(use_cases_map[cap])
|
||||
|
||||
# Remove duplicates
|
||||
return list(set(use_cases)) if use_cases else ["General image editing"]
|
||||
|
||||
def _get_features_for_model(self, model_info: Dict[str, Any]) -> list:
|
||||
"""Get feature list for a model."""
|
||||
features = []
|
||||
|
||||
if model_info.get("supports_multi_image"):
|
||||
max_images = model_info.get("api_params", {}).get("max_images", 0)
|
||||
if max_images:
|
||||
features.append(f"Multi-image ({max_images} images)")
|
||||
else:
|
||||
features.append("Multi-image support")
|
||||
|
||||
if model_info.get("supports_controlnet"):
|
||||
features.append("ControlNet support")
|
||||
|
||||
languages = model_info.get("languages", [])
|
||||
if len(languages) > 1:
|
||||
features.append(f"Multilingual ({', '.join(languages)})")
|
||||
elif "multilingual" in languages:
|
||||
features.append("Multilingual support")
|
||||
|
||||
max_res = model_info.get("max_resolution", (2048, 2048))
|
||||
if max(max_res) >= 4096:
|
||||
features.append("4K/8K support")
|
||||
elif max(max_res) >= 2048:
|
||||
features.append("2K support")
|
||||
|
||||
api_params = model_info.get("api_params", {})
|
||||
if api_params.get("supports_guidance_scale"):
|
||||
features.append("Guidance scale control")
|
||||
|
||||
return features if features else ["Standard editing"]
|
||||
|
||||
async def process_edit(
|
||||
self,
|
||||
@@ -221,6 +465,9 @@ class EditStudioService:
|
||||
) -> Dict[str, Any]:
|
||||
"""Process edit request and return normalized response."""
|
||||
|
||||
# Pre-flight validation: Use specific validator for editing operations
|
||||
# Note: Editing uses validate_image_editing_operations (different from generation)
|
||||
# This is intentional as editing may have different subscription limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
@@ -386,29 +633,109 @@ class EditStudioService:
|
||||
mask_bytes: Optional[bytes],
|
||||
user_id: Optional[str],
|
||||
) -> bytes:
|
||||
"""Execute Hugging Face powered general editing (synchronous API)."""
|
||||
"""Execute general editing - routes to WaveSpeed (unified entry) or HuggingFace (legacy).
|
||||
|
||||
If model is a WaveSpeed model (qwen-edit-plus, nano-banana-pro-edit-ultra, seedream-v4.5-edit),
|
||||
uses unified entry point. Otherwise falls back to HuggingFace for backward compatibility.
|
||||
"""
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for general edits")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
mask_bytes, # Optional mask for selective editing
|
||||
# Check if model is a WaveSpeed editing model
|
||||
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
provider = WaveSpeedEditProvider()
|
||||
wavespeed_models = set(provider.get_available_models().keys())
|
||||
|
||||
# Also check if provider is explicitly set to "wavespeed"
|
||||
is_wavespeed = (
|
||||
request.provider == "wavespeed" or
|
||||
(request.model and request.model in wavespeed_models)
|
||||
)
|
||||
|
||||
# Auto-detect: If no model specified and operation is general_edit, recommend one
|
||||
if not request.model and not is_wavespeed and request.operation == "general_edit":
|
||||
# Auto-select recommended model
|
||||
try:
|
||||
# Get image dimensions for recommendation
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
image_resolution = {"width": img.width, "height": img.height}
|
||||
|
||||
recommendation = self.recommend_model(
|
||||
operation=request.operation,
|
||||
image_resolution=image_resolution,
|
||||
preferences={"prioritize_cost": True}, # Default to cost-optimized
|
||||
)
|
||||
recommended_model = recommendation.get("recommended_model")
|
||||
if recommended_model and recommended_model in wavespeed_models:
|
||||
logger.info(f"[Edit Studio] Auto-selected model: {recommended_model} (reason: {recommendation.get('reason')})")
|
||||
request.model = recommended_model
|
||||
is_wavespeed = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[Edit Studio] Auto-detection failed: {e}, falling back to HuggingFace")
|
||||
|
||||
if is_wavespeed:
|
||||
# Use unified entry point for WaveSpeed models
|
||||
logger.info(f"[Edit Studio] Using WaveSpeed unified entry for model={request.model}")
|
||||
|
||||
# Convert image bytes to base64
|
||||
import base64
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
# Prepare options for unified entry point
|
||||
edit_options = {
|
||||
"mask_base64": None,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"width": None, # Will be determined from image if needed
|
||||
"height": None,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# Add mask if provided
|
||||
if mask_bytes:
|
||||
edit_options["mask_base64"] = base64.b64encode(mask_bytes).decode("utf-8")
|
||||
|
||||
# Extract dimensions from image if needed
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
edit_options["width"] = img.width
|
||||
edit_options["height"] = img.height
|
||||
|
||||
# Call unified entry point (synchronous, so run in thread)
|
||||
result = await asyncio.to_thread(
|
||||
generate_image_edit,
|
||||
image_base64=image_base64,
|
||||
prompt=request.prompt,
|
||||
operation=request.operation or "general_edit",
|
||||
model=request.model, # Will auto-select if None
|
||||
options=edit_options,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
else:
|
||||
# Fall back to HuggingFace for backward compatibility
|
||||
logger.info("[Edit Studio] Using HuggingFace (legacy) for general edit")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
return result.image_bytes
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
mask_bytes, # Optional mask for selective editing
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
|
||||
266
backend/services/image_studio/face_swap_service.py
Normal file
266
backend/services/image_studio/face_swap_service.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Face Swap Studio service for AI-powered face swapping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_face_swap
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.face_swap")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceSwapStudioRequest:
|
||||
"""Request model for face swap operations."""
|
||||
base_image_base64: str
|
||||
face_image_base64: str
|
||||
model: Optional[str] = None
|
||||
target_face_index: Optional[int] = None
|
||||
target_gender: Optional[str] = None
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class FaceSwapService:
|
||||
"""Service for face swap operations."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_available_models(
|
||||
self,
|
||||
tier: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get available WaveSpeed face swap models.
|
||||
|
||||
Args:
|
||||
tier: Filter by tier ("budget", "mid", "premium")
|
||||
|
||||
Returns:
|
||||
Dictionary with models and metadata
|
||||
"""
|
||||
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
|
||||
provider = WaveSpeedFaceSwapProvider()
|
||||
all_models = provider.get_available_models()
|
||||
|
||||
# Filter by tier if specified
|
||||
if tier:
|
||||
filtered = provider.get_models_by_tier(tier)
|
||||
all_models = {k: v for k, v in all_models.items() if k in filtered}
|
||||
|
||||
# Format for API response
|
||||
models_list = []
|
||||
for model_id, model_info in all_models.items():
|
||||
models_list.append({
|
||||
"id": model_id,
|
||||
"name": model_info.get("name", model_id),
|
||||
"description": model_info.get("description", ""),
|
||||
"cost": model_info.get("cost", 0.025),
|
||||
"tier": model_info.get("tier", "mid"),
|
||||
"capabilities": model_info.get("capabilities", []),
|
||||
"use_cases": self._get_use_cases_for_model(model_id, model_info),
|
||||
"features": model_info.get("features", []),
|
||||
"max_faces": model_info.get("max_faces", 1),
|
||||
})
|
||||
|
||||
return {
|
||||
"models": models_list,
|
||||
"total": len(models_list),
|
||||
}
|
||||
|
||||
def recommend_model(
|
||||
self,
|
||||
base_image_resolution: Optional[Dict[str, int]] = None,
|
||||
face_image_resolution: Optional[Dict[str, int]] = None,
|
||||
user_tier: Optional[str] = None,
|
||||
preferences: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Recommend best model for face swap.
|
||||
|
||||
Args:
|
||||
base_image_resolution: Dict with "width" and "height" of base image
|
||||
face_image_resolution: Dict with "width" and "height" of face image
|
||||
user_tier: User subscription tier ("free", "pro", "enterprise")
|
||||
preferences: Dict with "prioritize_cost" or "prioritize_quality"
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended model and alternatives
|
||||
"""
|
||||
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
|
||||
provider = WaveSpeedFaceSwapProvider()
|
||||
available_models = provider.get_available_models()
|
||||
|
||||
if not available_models:
|
||||
raise ValueError("No models available")
|
||||
|
||||
# Apply preferences
|
||||
prioritize_cost = preferences and preferences.get("prioritize_cost", False)
|
||||
prioritize_quality = preferences and preferences.get("prioritize_quality", False)
|
||||
|
||||
# Score models
|
||||
scored_models = []
|
||||
for model_id, model_info in available_models.items():
|
||||
score = 0
|
||||
cost = model_info.get("cost", 0.025)
|
||||
tier = model_info.get("tier", "mid")
|
||||
|
||||
# Cost scoring (lower is better)
|
||||
if prioritize_cost:
|
||||
score += (1.0 / cost) * 100
|
||||
else:
|
||||
score += (1.0 / cost) * 50
|
||||
|
||||
# Quality scoring (higher cost = better quality for face swap)
|
||||
if prioritize_quality:
|
||||
score += cost * 20
|
||||
else:
|
||||
score += cost * 10
|
||||
|
||||
# Tier preference based on user tier
|
||||
if user_tier == "free":
|
||||
if tier == "budget":
|
||||
score += 50
|
||||
elif tier == "mid":
|
||||
score += 20
|
||||
elif user_tier in ["pro", "enterprise"]:
|
||||
if tier == "premium":
|
||||
score += 50
|
||||
elif tier == "mid":
|
||||
score += 30
|
||||
|
||||
scored_models.append((model_id, model_info, score))
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_models.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Get recommended model
|
||||
recommended_id, recommended_info, recommended_score = scored_models[0]
|
||||
|
||||
# Build reason
|
||||
reasons = []
|
||||
if prioritize_cost:
|
||||
reasons.append("Lowest cost option")
|
||||
if prioritize_quality:
|
||||
reasons.append("Best quality")
|
||||
if user_tier == "free" and recommended_info.get("tier") == "budget":
|
||||
reasons.append("Budget-friendly for free tier")
|
||||
|
||||
reason = ", ".join(reasons) if reasons else "Best match for your requirements"
|
||||
|
||||
# Get alternatives (top 2-3)
|
||||
alternatives = []
|
||||
for model_id, model_info, score in scored_models[1:4]:
|
||||
alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
|
||||
if model_info.get("cost", 0) < recommended_info.get("cost", 0):
|
||||
alt_reason += ", lower cost"
|
||||
elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
|
||||
alt_reason += ", higher quality"
|
||||
alternatives.append({
|
||||
"model_id": model_id,
|
||||
"name": model_info.get("name", model_id),
|
||||
"cost": model_info.get("cost", 0.025),
|
||||
"reason": alt_reason,
|
||||
})
|
||||
|
||||
return {
|
||||
"recommended_model": recommended_id,
|
||||
"reason": reason,
|
||||
"alternatives": alternatives,
|
||||
}
|
||||
|
||||
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
|
||||
"""Get use cases for a model based on its capabilities."""
|
||||
use_cases_map = {
|
||||
"face_swap": ["Portrait editing", "Fun swaps", "Social media"],
|
||||
"head_swap": ["Casting and concept design", "Privacy and anonymization", "Photo exploration"],
|
||||
"full_head_replacement": ["Full head replacement", "Hair included", "Casting mockups"],
|
||||
"realistic_blending": ["Professional work", "Marketing", "Entertainment"],
|
||||
"multi_face": ["Group photos", "Family photos", "Team photos", "Creative projects", "Content creation"],
|
||||
"face_enhancement": ["High-quality results", "Professional work", "Marketing campaigns"],
|
||||
"identity_preservation": ["Character consistency", "Brand identity"],
|
||||
}
|
||||
|
||||
capabilities = model_info.get("capabilities", [])
|
||||
use_cases = []
|
||||
for cap in capabilities:
|
||||
if cap in use_cases_map:
|
||||
use_cases.extend(use_cases_map[cap])
|
||||
|
||||
return list(set(use_cases)) if use_cases else ["General face swapping"]
|
||||
|
||||
async def process_face_swap(
|
||||
self,
|
||||
request: FaceSwapStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process face swap request.
|
||||
|
||||
Args:
|
||||
request: Face swap request
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with result image and metadata
|
||||
"""
|
||||
# Auto-detect model if not specified
|
||||
selected_model = request.model
|
||||
if not selected_model:
|
||||
try:
|
||||
# Get image dimensions for recommendation
|
||||
base_img = Image.open(io.BytesIO(base64.b64decode(request.base_image_base64.split(",", 1)[1] if "," in request.base_image_base64 else request.base_image_base64)))
|
||||
face_img = Image.open(io.BytesIO(base64.b64decode(request.face_image_base64.split(",", 1)[1] if "," in request.face_image_base64 else request.face_image_base64)))
|
||||
|
||||
base_resolution = {"width": base_img.width, "height": base_img.height}
|
||||
face_resolution = {"width": face_img.width, "height": face_img.height}
|
||||
|
||||
recommendation = self.recommend_model(
|
||||
base_image_resolution=base_resolution,
|
||||
face_image_resolution=face_resolution,
|
||||
preferences={"prioritize_cost": True},
|
||||
)
|
||||
selected_model = recommendation.get("recommended_model")
|
||||
logger.info(f"[Face Swap] Auto-selected model: {selected_model} (reason: {recommendation.get('reason')})")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Face Swap] Auto-detection failed: {e}, using default model")
|
||||
# Use first available model as fallback
|
||||
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
provider = WaveSpeedFaceSwapProvider()
|
||||
all_models = provider.get_available_models()
|
||||
if all_models:
|
||||
selected_model = list(all_models.keys())[0]
|
||||
|
||||
# Prepare options
|
||||
options = request.options or {}
|
||||
if request.target_face_index is not None:
|
||||
options["target_face_index"] = request.target_face_index
|
||||
if request.target_gender:
|
||||
options["target_gender"] = request.target_gender
|
||||
|
||||
# Call unified entry point
|
||||
result = generate_face_swap(
|
||||
base_image_base64=request.base_image_base64,
|
||||
face_image_base64=request.face_image_base64,
|
||||
model=selected_model,
|
||||
options=options,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Convert result to base64
|
||||
result_base64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
result_data_url = f"data:image/png;base64,{result_base64}"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_base64": result_data_url,
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"provider": result.provider,
|
||||
"model": result.model,
|
||||
"metadata": result.metadata or {},
|
||||
}
|
||||
403
backend/services/image_studio/format_converter_service.py
Normal file
403
backend/services/image_studio/format_converter_service.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Image Format Converter Service for converting between image formats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from PIL import Image, ImageCms
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.format_converter")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormatConversionRequest:
|
||||
"""Request model for format conversion."""
|
||||
image_base64: str
|
||||
target_format: str # png, jpeg, jpg, webp, gif, bmp, tiff
|
||||
preserve_transparency: bool = True
|
||||
quality: Optional[int] = None # For lossy formats (1-100)
|
||||
color_space: Optional[str] = None # sRGB, Adobe RGB, etc.
|
||||
strip_metadata: bool = False # Keep metadata by default for conversion
|
||||
optimize: bool = True
|
||||
progressive: bool = True # For JPEG
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormatConversionResult:
|
||||
"""Result of format conversion."""
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_format: str
|
||||
target_format: str
|
||||
original_size_kb: float
|
||||
converted_size_kb: float
|
||||
width: int
|
||||
height: int
|
||||
transparency_preserved: bool
|
||||
metadata_preserved: bool
|
||||
color_space: Optional[str] = None
|
||||
|
||||
|
||||
class ImageFormatConverterService:
|
||||
"""Service for converting images between formats."""
|
||||
|
||||
SUPPORTED_FORMATS = {
|
||||
"png": {
|
||||
"name": "PNG",
|
||||
"description": "Lossless format with transparency support",
|
||||
"supports_transparency": True,
|
||||
"supports_lossy": False,
|
||||
"mime_type": "image/png",
|
||||
},
|
||||
"jpeg": {
|
||||
"name": "JPEG",
|
||||
"description": "Lossy format, best for photos",
|
||||
"supports_transparency": False,
|
||||
"supports_lossy": True,
|
||||
"mime_type": "image/jpeg",
|
||||
},
|
||||
"jpg": {
|
||||
"name": "JPEG",
|
||||
"description": "Lossy format, best for photos",
|
||||
"supports_transparency": False,
|
||||
"supports_lossy": True,
|
||||
"mime_type": "image/jpeg",
|
||||
},
|
||||
"webp": {
|
||||
"name": "WebP",
|
||||
"description": "Modern format with excellent compression",
|
||||
"supports_transparency": True,
|
||||
"supports_lossy": True,
|
||||
"mime_type": "image/webp",
|
||||
},
|
||||
"gif": {
|
||||
"name": "GIF",
|
||||
"description": "Supports animation and transparency",
|
||||
"supports_transparency": True,
|
||||
"supports_lossy": False,
|
||||
"mime_type": "image/gif",
|
||||
},
|
||||
"bmp": {
|
||||
"name": "BMP",
|
||||
"description": "Uncompressed bitmap format",
|
||||
"supports_transparency": False,
|
||||
"supports_lossy": False,
|
||||
"mime_type": "image/bmp",
|
||||
},
|
||||
"tiff": {
|
||||
"name": "TIFF",
|
||||
"description": "High-quality format for print",
|
||||
"supports_transparency": True,
|
||||
"supports_lossy": False,
|
||||
"mime_type": "image/tiff",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Format Converter] ImageFormatConverterService initialized")
|
||||
|
||||
def _decode_image(self, image_base64: str) -> tuple[Image.Image, int, str]:
|
||||
"""Decode base64 image and return PIL Image, size, and format."""
|
||||
# Handle data URL format
|
||||
if "," in image_base64:
|
||||
image_base64 = image_base64.split(",", 1)[1]
|
||||
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
original_size = len(image_bytes)
|
||||
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
original_format = image.format.lower() if image.format else "unknown"
|
||||
|
||||
return image, original_size, original_format
|
||||
|
||||
def _strip_exif(self, image: Image.Image) -> Image.Image:
|
||||
"""Remove EXIF metadata from image."""
|
||||
data = list(image.getdata())
|
||||
image_without_exif = Image.new(image.mode, image.size)
|
||||
image_without_exif.putdata(data)
|
||||
return image_without_exif
|
||||
|
||||
def _convert_color_space(
|
||||
self,
|
||||
image: Image.Image,
|
||||
target_color_space: str,
|
||||
) -> Image.Image:
|
||||
"""Convert image color space."""
|
||||
try:
|
||||
# Get current color space
|
||||
if hasattr(image, 'info') and 'icc_profile' in image.info:
|
||||
# Image has ICC profile
|
||||
try:
|
||||
src_profile = ImageCms.ImageCmsProfile(io.BytesIO(image.info['icc_profile']))
|
||||
if target_color_space.lower() == "srgb":
|
||||
dst_profile = ImageCms.createProfile("sRGB")
|
||||
elif target_color_space.lower() == "adobe rgb":
|
||||
dst_profile = ImageCms.createProfile("Adobe RGB")
|
||||
else:
|
||||
return image # Unknown color space
|
||||
|
||||
transform = ImageCms.ImageCmsTransform(src_profile, dst_profile, image.mode, image.mode)
|
||||
image = ImageCms.applyTransform(image, transform)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Format Converter] Color space conversion failed: {e}")
|
||||
else:
|
||||
# No ICC profile, assume sRGB
|
||||
logger.info("[Format Converter] No ICC profile found, assuming sRGB")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Format Converter] Color space conversion error: {e}")
|
||||
|
||||
return image
|
||||
|
||||
def _convert_image(
|
||||
self,
|
||||
image: Image.Image,
|
||||
target_format: str,
|
||||
quality: Optional[int],
|
||||
preserve_transparency: bool,
|
||||
optimize: bool,
|
||||
progressive: bool,
|
||||
) -> bytes:
|
||||
"""Convert image to target format."""
|
||||
buffer = io.BytesIO()
|
||||
format_lower = target_format.lower()
|
||||
|
||||
# Handle format-specific conversions
|
||||
save_kwargs: Dict[str, Any] = {}
|
||||
|
||||
# Check if source has transparency and target doesn't support it
|
||||
has_transparency = image.mode in ("RGBA", "LA", "P") and (
|
||||
"transparency" in image.info or image.mode == "RGBA"
|
||||
)
|
||||
|
||||
if format_lower in ["jpeg", "jpg"]:
|
||||
# JPEG doesn't support transparency
|
||||
if has_transparency and preserve_transparency:
|
||||
# Convert to RGB, losing transparency
|
||||
if image.mode in ("RGBA", "LA"):
|
||||
# Create white background
|
||||
rgb_image = Image.new("RGB", image.size, (255, 255, 255))
|
||||
if image.mode == "RGBA":
|
||||
rgb_image.paste(image, mask=image.split()[3]) # Use alpha channel as mask
|
||||
else:
|
||||
rgb_image.paste(image)
|
||||
image = rgb_image
|
||||
elif image.mode == "P":
|
||||
image = image.convert("RGB")
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
|
||||
save_kwargs["format"] = "JPEG"
|
||||
if quality:
|
||||
save_kwargs["quality"] = quality
|
||||
else:
|
||||
save_kwargs["quality"] = 95 # Default high quality
|
||||
save_kwargs["optimize"] = optimize
|
||||
if progressive:
|
||||
save_kwargs["progressive"] = True
|
||||
|
||||
elif format_lower == "png":
|
||||
save_kwargs["format"] = "PNG"
|
||||
save_kwargs["optimize"] = optimize
|
||||
# PNG compression level (0-9)
|
||||
if quality:
|
||||
compress_level = max(0, min(9, (100 - quality) // 11))
|
||||
save_kwargs["compress_level"] = compress_level
|
||||
else:
|
||||
save_kwargs["compress_level"] = 6 # Default
|
||||
|
||||
elif format_lower == "webp":
|
||||
save_kwargs["format"] = "WEBP"
|
||||
if quality:
|
||||
save_kwargs["quality"] = quality
|
||||
else:
|
||||
save_kwargs["quality"] = 80 # Default
|
||||
save_kwargs["method"] = 6 # Best compression
|
||||
if preserve_transparency and has_transparency:
|
||||
# WebP supports transparency
|
||||
if image.mode not in ("RGBA", "LA"):
|
||||
image = image.convert("RGBA")
|
||||
|
||||
elif format_lower == "gif":
|
||||
save_kwargs["format"] = "GIF"
|
||||
# GIF conversion
|
||||
if image.mode != "P":
|
||||
# Convert to palette mode for GIF
|
||||
image = image.convert("P", palette=Image.ADAPTIVE)
|
||||
save_kwargs["optimize"] = optimize
|
||||
if preserve_transparency and has_transparency:
|
||||
save_kwargs["transparency"] = 255 # Preserve transparency
|
||||
|
||||
elif format_lower == "bmp":
|
||||
save_kwargs["format"] = "BMP"
|
||||
if image.mode in ("RGBA", "LA", "P") and has_transparency:
|
||||
# BMP doesn't support transparency, convert to RGB
|
||||
if image.mode == "RGBA":
|
||||
rgb_image = Image.new("RGB", image.size, (255, 255, 255))
|
||||
rgb_image.paste(image, mask=image.split()[3])
|
||||
image = rgb_image
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
|
||||
elif format_lower == "tiff":
|
||||
save_kwargs["format"] = "TIFF"
|
||||
save_kwargs["compression"] = "tiff_lzw" # Lossless compression
|
||||
if preserve_transparency and has_transparency:
|
||||
# TIFF supports transparency
|
||||
if image.mode not in ("RGBA", "LA"):
|
||||
image = image.convert("RGBA")
|
||||
else:
|
||||
raise ValueError(f"Unsupported target format: {target_format}")
|
||||
|
||||
image.save(buffer, **save_kwargs)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def convert(
|
||||
self,
|
||||
request: FormatConversionRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> FormatConversionResult:
|
||||
"""Convert an image to target format."""
|
||||
logger.info(f"[Format Converter] Processing conversion request for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Decode image
|
||||
image, original_size, original_format = self._decode_image(request.image_base64)
|
||||
original_size_kb = original_size / 1024
|
||||
|
||||
logger.info(f"[Format Converter] Original: {original_format}, Target: {request.target_format}, Size: {original_size_kb:.2f} KB")
|
||||
|
||||
# Validate target format
|
||||
format_lower = request.target_format.lower()
|
||||
if format_lower not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(f"Unsupported format: {request.target_format}. Supported: {list(self.SUPPORTED_FORMATS.keys())}")
|
||||
|
||||
# Check transparency preservation
|
||||
has_transparency = image.mode in ("RGBA", "LA", "P") and (
|
||||
"transparency" in image.info or image.mode == "RGBA"
|
||||
)
|
||||
target_supports_transparency = self.SUPPORTED_FORMATS[format_lower]["supports_transparency"]
|
||||
transparency_preserved = (
|
||||
has_transparency and
|
||||
target_supports_transparency and
|
||||
request.preserve_transparency
|
||||
)
|
||||
|
||||
# Color space conversion
|
||||
if request.color_space:
|
||||
image = self._convert_color_space(image, request.color_space)
|
||||
|
||||
# Strip metadata if requested
|
||||
metadata_preserved = not request.strip_metadata
|
||||
if request.strip_metadata:
|
||||
image = self._strip_exif(image)
|
||||
|
||||
# Convert format
|
||||
converted_bytes = self._convert_image(
|
||||
image,
|
||||
format_lower,
|
||||
request.quality,
|
||||
request.preserve_transparency,
|
||||
request.optimize,
|
||||
request.progressive,
|
||||
)
|
||||
|
||||
converted_size_kb = len(converted_bytes) / 1024
|
||||
|
||||
# Encode result
|
||||
mime_type = self.SUPPORTED_FORMATS[format_lower]["mime_type"]
|
||||
result_base64 = f"data:{mime_type};base64,{base64.b64encode(converted_bytes).decode()}"
|
||||
|
||||
logger.info(f"[Format Converter] Converted: {original_size_kb:.2f}KB → {converted_size_kb:.2f}KB")
|
||||
|
||||
return FormatConversionResult(
|
||||
success=True,
|
||||
image_base64=result_base64,
|
||||
original_format=original_format,
|
||||
target_format=format_lower,
|
||||
original_size_kb=round(original_size_kb, 2),
|
||||
converted_size_kb=round(converted_size_kb, 2),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
transparency_preserved=transparency_preserved,
|
||||
metadata_preserved=metadata_preserved,
|
||||
color_space=request.color_space,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Format Converter] Failed to convert image: {e}")
|
||||
raise
|
||||
|
||||
async def convert_batch(
|
||||
self,
|
||||
requests: List[FormatConversionRequest],
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[FormatConversionResult]:
|
||||
"""Convert multiple images."""
|
||||
logger.info(f"[Format Converter] Processing batch of {len(requests)} images for user: {user_id}")
|
||||
|
||||
results = []
|
||||
for i, request in enumerate(requests):
|
||||
try:
|
||||
result = await self.convert(request, user_id)
|
||||
results.append(result)
|
||||
logger.info(f"[Format Converter] Batch item {i+1}/{len(requests)} complete")
|
||||
except Exception as e:
|
||||
logger.error(f"[Format Converter] Batch item {i+1} failed: {e}")
|
||||
results.append(FormatConversionResult(
|
||||
success=False,
|
||||
image_base64="",
|
||||
original_format="",
|
||||
target_format="",
|
||||
original_size_kb=0,
|
||||
converted_size_kb=0,
|
||||
width=0,
|
||||
height=0,
|
||||
transparency_preserved=False,
|
||||
metadata_preserved=False,
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
def get_supported_formats(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of supported formats with details."""
|
||||
return [
|
||||
{
|
||||
"id": fmt_id,
|
||||
"name": fmt_info["name"],
|
||||
"description": fmt_info["description"],
|
||||
"supports_transparency": fmt_info["supports_transparency"],
|
||||
"supports_lossy": fmt_info["supports_lossy"],
|
||||
"mime_type": fmt_info["mime_type"],
|
||||
}
|
||||
for fmt_id, fmt_info in self.SUPPORTED_FORMATS.items()
|
||||
]
|
||||
|
||||
def get_format_recommendations(self, source_format: str) -> List[Dict[str, Any]]:
|
||||
"""Get format recommendations based on source format."""
|
||||
recommendations = {
|
||||
"png": [
|
||||
{"format": "webp", "reason": "60% smaller file size, maintains transparency"},
|
||||
{"format": "jpeg", "reason": "Best for photos, smaller file size"},
|
||||
],
|
||||
"jpeg": [
|
||||
{"format": "webp", "reason": "25-34% smaller with similar quality"},
|
||||
{"format": "png", "reason": "Lossless, supports transparency"},
|
||||
],
|
||||
"jpg": [
|
||||
{"format": "webp", "reason": "25-34% smaller with similar quality"},
|
||||
{"format": "png", "reason": "Lossless, supports transparency"},
|
||||
],
|
||||
"webp": [
|
||||
{"format": "png", "reason": "Better compatibility, lossless"},
|
||||
{"format": "jpeg", "reason": "Universal compatibility"},
|
||||
],
|
||||
}
|
||||
|
||||
source_lower = source_format.lower()
|
||||
return recommendations.get(source_lower, [])
|
||||
@@ -7,6 +7,9 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .face_swap_service import FaceSwapService, FaceSwapStudioRequest
|
||||
from .compression_service import ImageCompressionService, CompressionRequest, CompressionResult
|
||||
from .format_converter_service import ImageFormatConverterService, FormatConversionRequest, FormatConversionResult
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
@@ -29,6 +32,9 @@ class ImageStudioManager:
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
self.control_service = ControlStudioService()
|
||||
self.social_optimizer_service = SocialOptimizerService()
|
||||
self.face_swap_service = FaceSwapService()
|
||||
self.compression_service = ImageCompressionService()
|
||||
self.format_converter_service = ImageFormatConverterService()
|
||||
self.transform_service = TransformStudioService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
@@ -69,6 +75,99 @@ class ImageStudioManager:
|
||||
def get_edit_operations(self) -> Dict[str, Any]:
|
||||
"""Expose edit operations for UI."""
|
||||
return self.edit_service.list_operations()
|
||||
|
||||
def get_edit_models(
|
||||
self,
|
||||
operation: Optional[str] = None,
|
||||
tier: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get available editing models.
|
||||
|
||||
Args:
|
||||
operation: Filter by operation type
|
||||
tier: Filter by tier (budget, mid, premium)
|
||||
|
||||
Returns:
|
||||
Dictionary with models and metadata
|
||||
"""
|
||||
return self.edit_service.get_available_models(operation=operation, tier=tier)
|
||||
|
||||
def recommend_edit_model(
|
||||
self,
|
||||
operation: str,
|
||||
image_resolution: Optional[Dict[str, int]] = None,
|
||||
user_tier: Optional[str] = None,
|
||||
preferences: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Recommend best editing model for given context.
|
||||
|
||||
Args:
|
||||
operation: Operation type
|
||||
image_resolution: Image dimensions
|
||||
user_tier: User subscription tier
|
||||
preferences: User preferences (prioritize_cost, prioritize_quality)
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended model and alternatives
|
||||
"""
|
||||
return self.edit_service.recommend_model(
|
||||
operation=operation,
|
||||
image_resolution=image_resolution,
|
||||
user_tier=user_tier,
|
||||
preferences=preferences,
|
||||
)
|
||||
|
||||
# ====================
|
||||
# FACE SWAP STUDIO
|
||||
# ====================
|
||||
|
||||
async def face_swap(
|
||||
self,
|
||||
request: FaceSwapStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run Face Swap Studio operations."""
|
||||
logger.info("[Image Studio] Face swap request from user: %s", user_id)
|
||||
return await self.face_swap_service.process_face_swap(request, user_id=user_id)
|
||||
|
||||
def get_face_swap_models(
|
||||
self,
|
||||
tier: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get available face swap models.
|
||||
|
||||
Args:
|
||||
tier: Filter by tier (budget, mid, premium)
|
||||
|
||||
Returns:
|
||||
Dictionary with models and metadata
|
||||
"""
|
||||
return self.face_swap_service.get_available_models(tier=tier)
|
||||
|
||||
def recommend_face_swap_model(
|
||||
self,
|
||||
base_image_resolution: Optional[Dict[str, int]] = None,
|
||||
face_image_resolution: Optional[Dict[str, int]] = None,
|
||||
user_tier: Optional[str] = None,
|
||||
preferences: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Recommend best face swap model for given context.
|
||||
|
||||
Args:
|
||||
base_image_resolution: Base image dimensions
|
||||
face_image_resolution: Face image dimensions
|
||||
user_tier: User subscription tier
|
||||
preferences: User preferences (prioritize_cost, prioritize_quality)
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended model and alternatives
|
||||
"""
|
||||
return self.face_swap_service.recommend_model(
|
||||
base_image_resolution=base_image_resolution,
|
||||
face_image_resolution=face_image_resolution,
|
||||
user_tier=user_tier,
|
||||
preferences=preferences,
|
||||
)
|
||||
|
||||
# ====================
|
||||
# UPSCALE STUDIO
|
||||
@@ -377,3 +476,72 @@ class ImageStudioManager:
|
||||
"""Estimate cost for transform operation."""
|
||||
return self.transform_service.estimate_cost(operation, resolution, duration)
|
||||
|
||||
# ====================
|
||||
# COMPRESSION STUDIO
|
||||
# ====================
|
||||
|
||||
async def compress_image(
|
||||
self,
|
||||
request: CompressionRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""Compress an image with specified settings."""
|
||||
logger.info("[Image Studio] Compress image request from user: %s", user_id)
|
||||
return await self.compression_service.compress(request, user_id=user_id)
|
||||
|
||||
async def compress_batch(
|
||||
self,
|
||||
requests: List[CompressionRequest],
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[CompressionResult]:
|
||||
"""Compress multiple images."""
|
||||
logger.info("[Image Studio] Batch compress request (%d images) from user: %s", len(requests), user_id)
|
||||
return await self.compression_service.compress_batch(requests, user_id=user_id)
|
||||
|
||||
async def estimate_compression(
|
||||
self,
|
||||
image_base64: str,
|
||||
format: str = "jpeg",
|
||||
quality: int = 85,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate compression results without compressing."""
|
||||
return await self.compression_service.estimate_compression(image_base64, format, quality)
|
||||
|
||||
def get_compression_formats(self) -> List[Dict[str, Any]]:
|
||||
"""Get supported compression formats."""
|
||||
return self.compression_service.get_supported_formats()
|
||||
|
||||
def get_compression_presets(self) -> List[Dict[str, Any]]:
|
||||
"""Get compression presets for common use cases."""
|
||||
return self.compression_service.get_presets()
|
||||
|
||||
# ====================
|
||||
# FORMAT CONVERTER
|
||||
# ====================
|
||||
|
||||
async def convert_format(
|
||||
self,
|
||||
request: FormatConversionRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> FormatConversionResult:
|
||||
"""Convert an image to target format."""
|
||||
logger.info("[Image Studio] Convert format request from user: %s", user_id)
|
||||
return await self.format_converter_service.convert(request, user_id=user_id)
|
||||
|
||||
async def convert_format_batch(
|
||||
self,
|
||||
requests: List[FormatConversionRequest],
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[FormatConversionResult]:
|
||||
"""Convert multiple images."""
|
||||
logger.info("[Image Studio] Batch convert format request (%d images) from user: %s", len(requests), user_id)
|
||||
return await self.format_converter_service.convert_batch(requests, user_id=user_id)
|
||||
|
||||
def get_supported_formats(self) -> List[Dict[str, Any]]:
|
||||
"""Get supported conversion formats."""
|
||||
return self.format_converter_service.get_supported_formats()
|
||||
|
||||
def get_format_recommendations(self, source_format: str) -> List[Dict[str, Any]]:
|
||||
"""Get format recommendations based on source format."""
|
||||
return self.format_converter_service.get_format_recommendations(source_format)
|
||||
|
||||
|
||||
@@ -36,18 +36,16 @@ class UpscaleStudioService:
|
||||
request: UpscaleStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Pre-flight validation: Reuse unified helper
|
||||
# Note: Using image-generation validation since upscaling uses same subscription limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_upscale_operations
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
|
||||
validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
from services.llm_providers.main_image_generation import _validate_image_operation
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-upscale",
|
||||
num_operations=1,
|
||||
log_prefix="[Upscale Studio]"
|
||||
)
|
||||
|
||||
image_bytes = self._decode_base64(request.image_base64)
|
||||
if not image_bytes:
|
||||
|
||||
@@ -1,4 +1,12 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .base import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
ImageGenerationProvider,
|
||||
ImageEditOptions,
|
||||
ImageEditProvider,
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
)
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
@@ -8,6 +16,10 @@ __all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"ImageEditOptions",
|
||||
"ImageEditProvider",
|
||||
"FaceSwapOptions",
|
||||
"FaceSwapProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
|
||||
@@ -28,6 +28,50 @@ class ImageGenerationResult:
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageEditOptions:
|
||||
"""Options for image editing operations."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
operation: str # "general_edit", "inpaint", "outpaint", "remove_background", etc.
|
||||
mask_base64: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"image_base64": self.image_base64,
|
||||
"prompt": self.prompt,
|
||||
"operation": self.operation,
|
||||
}
|
||||
if self.mask_base64:
|
||||
result["mask_base64"] = self.mask_base64
|
||||
if self.negative_prompt:
|
||||
result["negative_prompt"] = self.negative_prompt
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.width:
|
||||
result["width"] = self.width
|
||||
if self.height:
|
||||
result["height"] = self.height
|
||||
if self.guidance_scale is not None:
|
||||
result["guidance_scale"] = self.guidance_scale
|
||||
if self.steps:
|
||||
result["steps"] = self.steps
|
||||
if self.seed is not None:
|
||||
result["seed"] = self.seed
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Protocol for image generation providers."""
|
||||
|
||||
@@ -35,3 +79,44 @@ class ImageGenerationProvider(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceSwapOptions:
|
||||
"""Options for face swap operations."""
|
||||
base_image_base64: str # Image to swap face into
|
||||
face_image_base64: str # Face to swap
|
||||
model: Optional[str] = None
|
||||
target_face_index: Optional[int] = None # For multi-face images (0 = largest)
|
||||
target_gender: Optional[str] = None # "all", "female", "male" (for some models)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"base_image_base64": self.base_image_base64,
|
||||
"face_image_base64": self.face_image_base64,
|
||||
}
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.target_face_index is not None:
|
||||
result["target_face_index"] = self.target_face_index
|
||||
if self.target_gender:
|
||||
result["target_gender"] = self.target_gender
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageEditProvider(Protocol):
|
||||
"""Protocol for image editing providers."""
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
class FaceSwapProvider(Protocol):
|
||||
"""Protocol for face swap providers."""
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
"""WaveSpeed AI image editing provider (14 editing models)."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import requests
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("wavespeed.edit_provider")
|
||||
|
||||
|
||||
class WaveSpeedEditProvider(ImageEditProvider):
|
||||
"""WaveSpeed AI image editing provider supporting 14 editing models.
|
||||
|
||||
REUSES: WaveSpeedClient, model registry pattern, result format
|
||||
"""
|
||||
|
||||
# Model registry - populated with WaveSpeed editing models
|
||||
SUPPORTED_MODELS = {
|
||||
"qwen-edit": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit",
|
||||
"name": "Qwen Image Edit",
|
||||
"description": "20B MMDiT image-to-image model offering precise bilingual (Chinese & English) text edits while preserving style. Single-image editing with style preservation.",
|
||||
"cost": 0.02, # Same as Plus version
|
||||
"max_resolution": (1536, 1536), # Based on docs: similar to Plus
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False, # Not mentioned in docs
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"default_output_format": "jpeg", # Per API docs: default is "jpeg"
|
||||
"supports_seed": True, # Per API docs: seed parameter supported
|
||||
}
|
||||
},
|
||||
"qwen-edit-plus": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit-plus",
|
||||
"name": "Qwen Image Edit Plus",
|
||||
"description": "20B MMDiT image editor with multi-image editing, single-image consistency and native ControlNet support. Bilingual (CN/EN) text editing, appearance-level and semantic-level edits.",
|
||||
"cost": 0.02,
|
||||
"max_resolution": (1536, 1536), # Based on docs: 256-1536 per dimension
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": True, # Up to 3 reference images
|
||||
"supports_controlnet": True,
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": False, # Uses "images" (array)
|
||||
"supports_seed": True, # Seed parameter supported (default for Qwen models)
|
||||
}
|
||||
},
|
||||
"nano-banana-pro-edit-ultra": {
|
||||
"model_path": "google/nano-banana-pro/edit-ultra",
|
||||
"name": "Google Nano Banana Pro Edit Ultra",
|
||||
"description": "High-resolution image editing with 4K/8K native output. Natural language instructions, multilingual text support. Premium quality editing for professional marketing and high-res work.",
|
||||
"cost": 0.15, # 4K - from enhancement proposal
|
||||
"cost_8k": 0.18, # 8K - from enhancement proposal
|
||||
"max_resolution": (8192, 8192), # 8K support
|
||||
"capabilities": ["general_edit", "high_res", "professional", "typography"],
|
||||
"tier": "premium",
|
||||
"supports_multi_image": True, # Up to 14 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en", "multilingual"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio and resolution instead
|
||||
"uses_aspect_ratio": True, # "1:1", "16:9", etc.
|
||||
"uses_resolution": True, # "4k" or "8k"
|
||||
"max_images": 14,
|
||||
"default_output_format": "png", # Per API docs: default is "png"
|
||||
"supports_seed": False, # Per API docs: no seed parameter
|
||||
}
|
||||
},
|
||||
"seedream-v4.5-edit": {
|
||||
"model_path": "bytedance/seedream-v4.5/edit",
|
||||
"name": "Bytedance Seedream V4.5 Edit",
|
||||
"description": "Preserves facial features, lighting, and color tone from reference images, delivering professional, high-fidelity edits up to 4K with strong prompt adherence. Reference-faithful editing with multi-image support.",
|
||||
"cost": 0.04, # Per generated image
|
||||
"max_resolution": (4096, 4096), # 4K support (1024-4096 per dimension)
|
||||
"capabilities": ["general_edit", "portrait_retouching", "fashion_edit", "product_edit", "multi_image"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": True, # Up to 10 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height format, 1024-4096 per dimension)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"max_images": 10,
|
||||
"default_output_format": "png",
|
||||
"supports_seed": False, # No seed parameter in API docs (Seedream V4.5)
|
||||
}
|
||||
},
|
||||
"flux-kontext-pro": {
|
||||
"model_path": "wavespeed-ai/flux-kontext-pro",
|
||||
"name": "FLUX Kontext Pro",
|
||||
"description": "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate typography generation for consistent, high-quality edits at speed. Typography-focused editing with improved prompt adherence.",
|
||||
"cost": 0.04, # From enhancement proposal
|
||||
"max_resolution": (2048, 2048), # Estimated, not specified in docs
|
||||
"capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio instead
|
||||
"uses_aspect_ratio": True, # Aspect ratio as string (e.g., "16:9", "1:1")
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"supports_guidance_scale": True, # Has guidance_scale parameter (default 3.5, range 1-20)
|
||||
"default_guidance_scale": 3.5, # Per API docs
|
||||
"supports_seed": False, # No seed parameter in API docs
|
||||
}
|
||||
},
|
||||
# TODO: Add remaining 9 models once docs are provided
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize WaveSpeed edit provider.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key (falls back to env var if not provided)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
||||
|
||||
# REUSE: Same client as generation provider
|
||||
self.client = WaveSpeedClient(api_key=self.api_key)
|
||||
logger.info("[WaveSpeed Edit Provider] Initialized with %d models",
|
||||
len(self.SUPPORTED_MODELS))
|
||||
|
||||
def _validate_options(self, options: ImageEditOptions) -> None:
|
||||
"""Validate editing options.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
|
||||
|
||||
if not model:
|
||||
raise ValueError("No model specified and no default model available")
|
||||
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
|
||||
|
||||
if options.width and options.width > max_width:
|
||||
raise ValueError(
|
||||
f"Width {options.width} exceeds maximum {max_width} for model {model}"
|
||||
)
|
||||
|
||||
if options.height and options.height > max_height:
|
||||
raise ValueError(
|
||||
f"Height {options.height} exceeds maximum {max_height} for model {model}"
|
||||
)
|
||||
|
||||
if not options.prompt or len(options.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if not options.image_base64:
|
||||
raise ValueError("Image base64 cannot be empty")
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
"""Edit image using WaveSpeed AI models.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
RuntimeError: If editing fails
|
||||
"""
|
||||
# Validate options
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model = options.model or (list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None)
|
||||
if not model:
|
||||
raise ValueError("No model available for editing")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
model_path = model_info["model_path"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
|
||||
model, options.operation, options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare extra parameters based on model capabilities
|
||||
extra_params = options.extra or {}
|
||||
|
||||
# Add model-specific parameters if needed
|
||||
api_params = model_info.get("api_params", {})
|
||||
if api_params.get("uses_resolution", False):
|
||||
# For Nano Banana: determine resolution from dimensions or use default
|
||||
if options.width and options.height:
|
||||
if options.width >= 4096 or options.height >= 4096:
|
||||
extra_params["resolution"] = "8k"
|
||||
else:
|
||||
extra_params["resolution"] = "4k"
|
||||
elif "resolution" not in extra_params:
|
||||
extra_params["resolution"] = "4k" # Default to 4K
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
|
||||
# Calculate aspect ratio if dimensions provided
|
||||
if options.width and options.height:
|
||||
aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
|
||||
if aspect_ratio:
|
||||
extra_params["aspect_ratio"] = aspect_ratio
|
||||
|
||||
# Call WaveSpeed API for editing
|
||||
result = self._call_wavespeed_edit_api(
|
||||
model_path=model_path,
|
||||
image_base64=options.image_base64,
|
||||
prompt=options.prompt,
|
||||
operation=options.operation,
|
||||
mask_base64=options.mask_base64,
|
||||
negative_prompt=options.negative_prompt,
|
||||
width=options.width,
|
||||
height=options.height,
|
||||
guidance_scale=options.guidance_scale,
|
||||
steps=options.steps,
|
||||
seed=options.seed,
|
||||
extra=extra_params
|
||||
)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
elif isinstance(result, dict) and "image_bytes" in result:
|
||||
image_bytes = result["image_bytes"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
# Load image to get dimensions
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
# Calculate estimated cost - handle resolution-based pricing
|
||||
estimated_cost = model_info.get("cost", 0.02)
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Check if 8K was requested
|
||||
resolution = extra_params.get("resolution", "4k")
|
||||
if resolution == "8k" and "cost_8k" in model_info:
|
||||
estimated_cost = model_info["cost_8k"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
|
||||
len(image_bytes), width, height)
|
||||
|
||||
# REUSE: Same result format as generation
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"model_name": model_info.get("name", model),
|
||||
"operation": options.operation,
|
||||
"prompt": options.prompt,
|
||||
"negative_prompt": options.negative_prompt,
|
||||
"estimated_cost": estimated_cost,
|
||||
"tier": model_info.get("tier", "mid"),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit failed: {str(e)}")
|
||||
|
||||
def _call_wavespeed_edit_api(
|
||||
self,
|
||||
model_path: str,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str,
|
||||
mask_base64: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
steps: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
extra: Optional[dict] = None
|
||||
) -> bytes:
|
||||
"""Call WaveSpeed API for image editing.
|
||||
|
||||
REUSES: Same pattern as ImageGenerator.generate_image()
|
||||
|
||||
Args:
|
||||
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
|
||||
image_base64: Base64-encoded input image
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of operation
|
||||
mask_base64: Optional mask for inpainting
|
||||
negative_prompt: Optional negative prompt
|
||||
width: Optional target width
|
||||
height: Optional target height
|
||||
guidance_scale: Optional guidance scale (not used by all models)
|
||||
steps: Optional number of steps (not used by all models)
|
||||
seed: Optional seed
|
||||
extra: Optional extra parameters
|
||||
|
||||
Returns:
|
||||
Edited image bytes
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
# Build URL - REUSES same pattern as ImageGenerator
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
|
||||
# Prepare images array - WaveSpeed expects array of image strings
|
||||
# Format: base64 strings or data URIs (data:image/png;base64,...)
|
||||
# For Qwen Image Edit Plus: supports up to 3 reference images
|
||||
images = []
|
||||
|
||||
# Add main image - check if it's already a data URI or just base64
|
||||
if image_base64.startswith("data:image"):
|
||||
# Already a data URI
|
||||
images.append(image_base64)
|
||||
else:
|
||||
# Assume it's base64, convert to data URI
|
||||
# Try to detect format from base64 or default to PNG
|
||||
images.append(f"data:image/png;base64,{image_base64}")
|
||||
|
||||
# If mask is provided, add it as second image
|
||||
# Note: Some models may need mask in different format - will adjust per model
|
||||
if mask_base64:
|
||||
if mask_base64.startswith("data:image"):
|
||||
images.append(mask_base64)
|
||||
else:
|
||||
images.append(f"data:image/png;base64,{mask_base64}")
|
||||
|
||||
# Get model info to determine API parameter structure
|
||||
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
|
||||
if not model_info:
|
||||
# Fallback: try to find model by matching path
|
||||
for model_id, info in self.SUPPORTED_MODELS.items():
|
||||
if info["model_path"] == model_path:
|
||||
model_info = info
|
||||
break
|
||||
|
||||
if not model_info:
|
||||
raise ValueError(f"Model info not found for: {model_path}")
|
||||
|
||||
api_params = model_info.get("api_params", {})
|
||||
|
||||
# Build payload - following WaveSpeed API structure
|
||||
# Note: output_format default varies by model (PNG for most, but can be JPEG)
|
||||
default_output_format = api_params.get("default_output_format", "png")
|
||||
|
||||
# Some models use "image" (singular) instead of "images" (array)
|
||||
uses_image_singular = api_params.get("uses_image_singular", False)
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
"enable_base64_output": False, # Get URL, then download
|
||||
"output_format": default_output_format,
|
||||
}
|
||||
|
||||
# Add image(s) based on model API format
|
||||
if uses_image_singular:
|
||||
# Models like Qwen Edit (basic) use "image" (singular)
|
||||
# Use first image only (single image editing)
|
||||
if images:
|
||||
payload["image"] = images[0]
|
||||
else:
|
||||
raise ValueError("At least one image is required")
|
||||
else:
|
||||
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
|
||||
payload["images"] = images
|
||||
|
||||
# Allow override of output_format from extra params
|
||||
if extra and "output_format" in extra:
|
||||
payload["output_format"] = extra["output_format"]
|
||||
|
||||
# Model-specific parameter handling
|
||||
if api_params.get("uses_size", True):
|
||||
# Models like Qwen Edit Plus use "size" parameter (width*height format)
|
||||
if width and height:
|
||||
payload["size"] = f"{width}*{height}"
|
||||
elif width:
|
||||
payload["size"] = f"{width}*{width}" # Square if only width provided
|
||||
elif height:
|
||||
payload["size"] = f"{height}*{height}" # Square if only height provided
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False):
|
||||
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
|
||||
if width and height:
|
||||
# Calculate aspect ratio from dimensions
|
||||
aspect_ratio = self._calculate_aspect_ratio(width, height)
|
||||
if aspect_ratio:
|
||||
payload["aspect_ratio"] = aspect_ratio
|
||||
elif extra and "aspect_ratio" in extra:
|
||||
payload["aspect_ratio"] = extra["aspect_ratio"]
|
||||
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
|
||||
if extra and "resolution" in extra:
|
||||
payload["resolution"] = extra["resolution"]
|
||||
else:
|
||||
# Default to 4K, or 8K if dimensions suggest high-res
|
||||
if width and height and (width >= 4096 or height >= 4096):
|
||||
payload["resolution"] = "8k"
|
||||
else:
|
||||
payload["resolution"] = "4k" # Default to 4K per API docs
|
||||
|
||||
# Add optional parameters (model-agnostic)
|
||||
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
|
||||
if api_params.get("supports_guidance_scale", False):
|
||||
default_guidance = api_params.get("default_guidance_scale", 3.5)
|
||||
if guidance_scale is not None:
|
||||
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
|
||||
payload["guidance_scale"] = max(1, min(20, guidance_scale))
|
||||
elif extra and "guidance_scale" in extra:
|
||||
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
|
||||
else:
|
||||
payload["guidance_scale"] = default_guidance
|
||||
|
||||
# Seed parameter: Only add if model supports it
|
||||
if api_params.get("supports_seed", True): # Default to True for backward compatibility
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
else:
|
||||
payload["seed"] = -1 # Random seed (per API docs default)
|
||||
|
||||
# Add any extra parameters
|
||||
if extra:
|
||||
# Filter out parameters we've already handled
|
||||
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
|
||||
for key, value in extra.items():
|
||||
if key not in handled_params:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
|
||||
|
||||
# Make API call - REUSES same pattern as ImageGenerator
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.client._headers(),
|
||||
json=payload,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image editing failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
logger.debug(
|
||||
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
|
||||
f"prediction_id={prediction_id}"
|
||||
)
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
# Sync mode returned "created" or "processing" - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
|
||||
f"Polling for result (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Poll for result - REUSES polling utility
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180,
|
||||
interval_seconds=2.0,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs after polling"
|
||||
)
|
||||
|
||||
# Extract image URL from outputs - REUSE helper method
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
|
||||
|
||||
def _extract_image_url(self, outputs: list) -> str:
|
||||
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
outputs: Array of output URLs or objects
|
||||
|
||||
Returns:
|
||||
Image URL string
|
||||
|
||||
Raises:
|
||||
HTTPException: If output format is invalid
|
||||
"""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit output format not recognized",
|
||||
)
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned invalid image URL",
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
|
||||
"""Download image from URL - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
image_url: URL to download from
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If download fails
|
||||
"""
|
||||
logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
|
||||
if image_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] Failed to download image: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to download edited image: {image_response.status_code}"
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(image_response.content)} bytes)")
|
||||
return image_response.content
|
||||
|
||||
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
|
||||
"""Calculate aspect ratio string from dimensions.
|
||||
|
||||
Args:
|
||||
width: Image width
|
||||
height: Image height
|
||||
|
||||
Returns:
|
||||
Aspect ratio string (e.g., "16:9") or None if not standard
|
||||
"""
|
||||
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
|
||||
ratios = {
|
||||
(1, 1): "1:1",
|
||||
(3, 2): "3:2",
|
||||
(2, 3): "2:3",
|
||||
(3, 4): "3:4",
|
||||
(4, 3): "4:3",
|
||||
(4, 5): "4:5",
|
||||
(5, 4): "5:4",
|
||||
(9, 16): "9:16",
|
||||
(16, 9): "16:9",
|
||||
(21, 9): "21:9",
|
||||
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
|
||||
}
|
||||
|
||||
# Calculate GCD to simplify ratio
|
||||
def gcd(a, b):
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
divisor = gcd(width, height)
|
||||
simplified = (width // divisor, height // divisor)
|
||||
|
||||
# Check if it matches a standard ratio (with some tolerance)
|
||||
for (w, h), ratio_str in ratios.items():
|
||||
# Allow small tolerance for rounding
|
||||
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
|
||||
return ratio_str
|
||||
|
||||
# If no match, return None (model may not support custom aspect ratios)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available editing models and their information.
|
||||
|
||||
Returns:
|
||||
Dictionary of available models
|
||||
"""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium).
|
||||
|
||||
Args:
|
||||
tier: Tier name ("budget", "mid", "premium")
|
||||
|
||||
Returns:
|
||||
Dictionary of models in the specified tier
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_operation(cls, operation: str) -> dict:
|
||||
"""Get models that support a specific operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type (e.g., "inpaint", "outpaint", "general_edit")
|
||||
|
||||
Returns:
|
||||
Dictionary of models supporting the operation
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if operation in model_info.get("capabilities", [])
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
"""WaveSpeed Face Swap Provider for Image Studio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.image_generation.base import (
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("llm_providers.wavespeed_face_swap")
|
||||
|
||||
|
||||
class WaveSpeedFaceSwapProvider:
|
||||
"""WaveSpeed provider for face swap operations."""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"image-face-swap-pro": {
|
||||
"model_path": "wavespeed-ai/image-face-swap-pro",
|
||||
"name": "Image Face Swap Pro",
|
||||
"description": "Instant online AI face swap for photos with no watermark, delivering realistic, shareable results in seconds.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "realistic_blending"],
|
||||
"features": ["Enhanced blending", "Realistic results", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"image-head-swap": {
|
||||
"model_path": "wavespeed-ai/image-head-swap",
|
||||
"name": "Image Head Swap",
|
||||
"description": "Instant online AI head & face swap for photos with no watermark. Replaces entire head (face + hair + outline) while preserving body, pose and background.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["head_swap", "full_head_replacement", "realistic_blending"],
|
||||
"features": ["Full head replacement", "Hair included", "Pose preservation", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"akool-face-swap": {
|
||||
"model_path": "akool/image-face-swap",
|
||||
"name": "Akool Image Face Swap",
|
||||
"description": "Powerful AI-powered face swapping with multi-face replacement for group photos. Seamlessly replaces faces with natural lighting and skin tone matching.",
|
||||
"cost": 0.16,
|
||||
"tier": "premium",
|
||||
"capabilities": ["face_swap", "multi_face", "realistic_blending", "face_enhancement"],
|
||||
"features": ["Multi-face swapping (up to 5)", "Face enhancement", "Group photos", "High-quality blending"],
|
||||
"max_faces": 5, # Supports 1-5 faces
|
||||
"api_params": {
|
||||
"uses_source_target_arrays": True, # Uses source_image and target_image arrays
|
||||
"supports_face_enhance": True,
|
||||
"supports_base64": True,
|
||||
"supports_sync": False, # May need polling
|
||||
},
|
||||
},
|
||||
"infinite-you": {
|
||||
"model_path": "wavespeed-ai/infinite-you",
|
||||
"name": "InfiniteYou",
|
||||
"description": "High-quality face swapping powered by ByteDance's zero-shot identity preservation technology. Maintains facial identity characteristics with exceptional realism.",
|
||||
"cost": 0.03,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "identity_preservation", "realistic_blending"],
|
||||
"features": ["Zero-shot learning", "Identity preservation", "High-quality results", "Fast processing"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"uses_source_target_names": True, # Uses source_image and target_image (not image/face_image)
|
||||
"target_is_base": True, # target_image is the base image (where face will be swapped)
|
||||
"source_is_face": True, # source_image is the face to swap in
|
||||
"supports_seed": True, # Supports seed parameter
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
# Placeholder for additional models (will be added as docs are provided)
|
||||
# "image-face-swap": {...}, # Basic version ($0.01)
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.client = WaveSpeedClient()
|
||||
|
||||
def _validate_options(self, options: FaceSwapOptions) -> None:
|
||||
"""Validate face swap options."""
|
||||
if not options.base_image_base64:
|
||||
raise ValueError("base_image_base64 is required")
|
||||
if not options.face_image_base64:
|
||||
raise ValueError("face_image_base64 is required")
|
||||
|
||||
# Validate model
|
||||
if options.model and options.model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {options.model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def _extract_image_url(self, data_url: str) -> str:
|
||||
"""Extract image URL from data URL or return as-is if already a URL."""
|
||||
if data_url.startswith("data:image"):
|
||||
# It's a data URL, we'll need to upload it
|
||||
return data_url
|
||||
return data_url
|
||||
|
||||
def _upload_image_if_needed(self, image_data: str) -> str:
|
||||
"""Upload image if it's a base64 data URL, otherwise return URL."""
|
||||
if image_data.startswith("data:image"):
|
||||
# Extract base64 data
|
||||
header, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
|
||||
# Upload to temporary storage (or use WaveSpeed upload endpoint if available)
|
||||
# For now, we'll return the data URL and let the API handle it
|
||||
# In production, you might want to upload to S3/CloudFlare first
|
||||
return image_data
|
||||
return image_data
|
||||
|
||||
def _call_wavespeed_face_swap_api(
|
||||
self, options: FaceSwapOptions, model_info: Dict[str, Any]
|
||||
) -> ImageGenerationResult:
|
||||
"""Call WaveSpeed face swap API."""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
model_path = model_info["model_path"]
|
||||
api_params = model_info.get("api_params", {})
|
||||
uses_source_target_arrays = api_params.get("uses_source_target_arrays", False)
|
||||
|
||||
# Prepare images - extract base64 if data URI
|
||||
base_image = options.base_image_base64
|
||||
if base_image.startswith("data:image"):
|
||||
# Keep as data URI - API should accept it
|
||||
pass
|
||||
elif not base_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
base_image = f"data:image/png;base64,{base_image}"
|
||||
|
||||
face_image = options.face_image_base64
|
||||
if face_image.startswith("data:image"):
|
||||
# Keep as data URI
|
||||
pass
|
||||
elif not face_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
face_image = f"data:image/png;base64,{face_image}"
|
||||
|
||||
# Build API payload - handle different API formats
|
||||
uses_source_target_names = api_params.get("uses_source_target_names", False)
|
||||
|
||||
if uses_source_target_arrays:
|
||||
# Akool format: uses source_image and target_image as arrays
|
||||
# For single face swap: source_image is the new face, target_image is reference from main image
|
||||
# Since we only have one face_image, we'll use it as source and the base_image as target reference
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"source_image": [face_image], # Array of source faces (1-5) - the new face to swap in
|
||||
"target_image": [base_image], # Array of target faces (1-5) - reference from main image
|
||||
"face_enhance": api_params.get("supports_face_enhance", True), # Default to True for Akool
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "face_enhance" in options.extra:
|
||||
payload["face_enhance"] = options.extra["face_enhance"]
|
||||
elif uses_source_target_names:
|
||||
# InfiniteYou format: uses source_image and target_image (single values, different names)
|
||||
# target_image = base image (where face will be swapped)
|
||||
# source_image = face image (face to swap in)
|
||||
payload = {
|
||||
"target_image": base_image, # Base image where face will be swapped
|
||||
"source_image": face_image, # Face to swap in
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Add seed if supported
|
||||
if api_params.get("supports_seed", False):
|
||||
seed = options.extra.get("seed") if options.extra else None
|
||||
payload["seed"] = seed if seed is not None else -1 # Default to -1 (random)
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "seed" in options.extra and api_params.get("supports_seed", False):
|
||||
payload["seed"] = options.extra["seed"]
|
||||
else:
|
||||
# Standard format: uses image and face_image (single values)
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"face_image": face_image,
|
||||
"output_format": api_params.get("output_format", "jpeg"),
|
||||
"enable_base64_output": True, # Always get base64 for our use case
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
# Add any extra parameters (filter out already handled ones)
|
||||
if options.extra:
|
||||
handled_keys = {"source_image", "target_image", "face_enhance", "output_format", "enable_sync_mode", "seed"}
|
||||
for key, value in options.extra.items():
|
||||
if key not in handled_keys:
|
||||
payload[key] = value
|
||||
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
headers = self.client._headers()
|
||||
|
||||
logger.info(f"[Face Swap] Calling WaveSpeed API: {url}")
|
||||
logger.debug(f"[Face Swap] Payload keys: {list(payload.keys())}")
|
||||
|
||||
try:
|
||||
# Call API
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[Face Swap] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed face swap failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - Akool uses different status values
|
||||
status = data.get("status", "").lower()
|
||||
# Akool uses "output" (singular), others use "outputs" (plural)
|
||||
outputs = data.get("outputs") or data.get("output") or []
|
||||
# Normalize to list if it's a single value
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
|
||||
prediction_id = data.get("id")
|
||||
|
||||
# Handle completed status - Akool uses "succeeded", others use "completed"
|
||||
is_completed = status in ["completed", "succeeded"]
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and is_completed:
|
||||
logger.info(f"[Face Swap] Got immediate results (status: {status})")
|
||||
# Extract image URL or base64
|
||||
output = outputs[0]
|
||||
if output.startswith("data:image") or output.startswith("http"):
|
||||
if output.startswith("http"):
|
||||
# Download from URL
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
else:
|
||||
# Extract base64 from data URI
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
# Assume it's base64 string
|
||||
image_bytes = base64.b64decode(output)
|
||||
elif prediction_id:
|
||||
# Need to poll
|
||||
logger.info(f"[Face Swap] Polling for result (prediction_id: {prediction_id}, status: {status})")
|
||||
result = self.client.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=1.0)
|
||||
# Check both outputs and output fields
|
||||
outputs = result.get("outputs") or result.get("output") or []
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap returned no outputs")
|
||||
output = outputs[0]
|
||||
if output.startswith("http"):
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
elif output.startswith("data:image"):
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
image_bytes = base64.b64decode(output)
|
||||
else:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap response missing outputs and prediction ID")
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = img.size
|
||||
|
||||
logger.info(f"[Face Swap] ✅ Successfully swapped face: {len(image_bytes)} bytes, {width}x{height}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=options.model or model_path,
|
||||
metadata={
|
||||
"model_path": model_path,
|
||||
"status": status,
|
||||
"created_at": data.get("created_at"),
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap] API call failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Face swap failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
"""Swap face in image using WaveSpeed models."""
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model_id = options.model
|
||||
if not model_id:
|
||||
# Default to first available model
|
||||
model_id = list(self.SUPPORTED_MODELS.keys())[0]
|
||||
logger.info(f"[Face Swap] No model specified, using default: {model_id}")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model_id]
|
||||
|
||||
# Call API
|
||||
return self._call_wavespeed_face_swap_api(options, model_info)
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available face swap models and their information."""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium)."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_capability(cls, capability: str) -> dict:
|
||||
"""Get models that support a specific capability."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if capability in model_info.get("capabilities", [])
|
||||
}
|
||||
@@ -8,11 +8,16 @@ from typing import Optional, Dict, Any
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
ImageEditOptions,
|
||||
ImageEditProvider,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
@@ -47,6 +52,249 @@ def _get_provider(provider_name: str):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance.
|
||||
|
||||
Args:
|
||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
||||
|
||||
Returns:
|
||||
ImageEditProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
# TODO: Add Stability edit provider if needed
|
||||
# elif provider_name == "stability":
|
||||
# return StabilityEditProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""
|
||||
Reusable pre-flight validation helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for subscription checking
|
||||
operation_type: Type of operation (for logging)
|
||||
num_operations: Number of operations to validate (default: 1)
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name (e.g., "wavespeed", "stability")
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated/processed image bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text (for request size calculation)
|
||||
endpoint: API endpoint path (for logging)
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
@@ -55,32 +303,13 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
options: Image generation options (provider, model, width, height, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
# PRE-FLIGHT VALIDATION: Reuse extracted helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-generation",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
@@ -114,151 +343,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
provider = _get_provider(provider_name)
|
||||
result = provider.generate(image_options)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(result.image_bytes) if result else False
|
||||
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
|
||||
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get cost from result metadata or calculate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
|
||||
# Calculate cost from result metadata or estimate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
else:
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint="/image-generation",
|
||||
method="POST",
|
||||
model_used=result.model or "unknown",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(result.image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider_name}
|
||||
├─ Actual Provider: {provider_name}
|
||||
├─ Model: {result.model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or "unknown",
|
||||
operation_type="image-generation",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
@@ -290,32 +407,13 @@ def generate_character_image(
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1,
|
||||
)
|
||||
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
# PRE-FLIGHT VALIDATION: Reuse extracted helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="character-image-generation",
|
||||
num_operations=1,
|
||||
log_prefix="[Character Image Generation]"
|
||||
)
|
||||
|
||||
# Generate character image via WaveSpeed
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
@@ -332,132 +430,26 @@ def generate_character_image(
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(image_bytes) if image_bytes else False
|
||||
image_bytes_len = len(image_bytes) if image_bytes else 0
|
||||
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and image_bytes:
|
||||
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Character image cost (same as ideogram-v3-turbo)
|
||||
estimated_cost = 0.10
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
|
||||
endpoint="/image-generation/character",
|
||||
method="POST",
|
||||
model_used="ideogram-character",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation (Character)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: wavespeed
|
||||
├─ Actual Provider: wavespeed
|
||||
├─ Model: ideogram-character
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
|
||||
# Character image cost (same as ideogram-v3-turbo)
|
||||
estimated_cost = 0.10
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model="ideogram-character",
|
||||
operation_type="character-image-generation",
|
||||
result_bytes=image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/character",
|
||||
metadata=None,
|
||||
log_prefix="[Character Image Generation]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
|
||||
|
||||
@@ -476,3 +468,210 @@ def generate_character_image(
|
||||
)
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate edited image - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
# If model is specified and starts with "wavespeed", use wavespeed provider
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider (REUSES provider pattern)
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Image editing failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate face swap - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
image_base64=base_image_base64, # Use base image for validation
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get model cost
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None, # Face swap doesn't use prompts
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Face swap failed",
|
||||
"message": str(api_error)
|
||||
}
|
||||
)
|
||||
|
||||
# 6. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get cost from result metadata or estimate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
# Default WaveSpeed edit cost
|
||||
estimated_cost = 0.02 # Default for most editing models
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or model or "unknown",
|
||||
operation_type="image-edit",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/edit",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from .campaign_storage import CampaignStorageService
|
||||
from .product_image_service import ProductImageService
|
||||
from .product_animation_service import ProductAnimationService, ProductAnimationRequest
|
||||
from .product_video_service import ProductVideoService, ProductVideoRequest
|
||||
from .product_avatar_service import ProductAvatarService, ProductAvatarRequest
|
||||
|
||||
__all__ = [
|
||||
"ProductMarketingOrchestrator",
|
||||
@@ -16,5 +19,11 @@ __all__ = [
|
||||
"ChannelPackService",
|
||||
"CampaignStorageService",
|
||||
"ProductImageService",
|
||||
"ProductAnimationService",
|
||||
"ProductAnimationRequest",
|
||||
"ProductVideoService",
|
||||
"ProductVideoRequest",
|
||||
"ProductAvatarService",
|
||||
"ProductAvatarRequest",
|
||||
]
|
||||
|
||||
|
||||
@@ -163,6 +163,7 @@ class ProductMarketingOrchestrator:
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"recommended_template": recommended_template.get('id') if recommended_template else None,
|
||||
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
|
||||
@@ -170,6 +171,67 @@ class ProductMarketingOrchestrator:
|
||||
"concept_summary": self._generate_concept_summary(enhanced_prompt),
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "video":
|
||||
# Video asset proposals - determine if animation (image-to-video) or demo (text-to-video)
|
||||
# Default to animation if we have product image, otherwise demo
|
||||
video_subtype = asset_proposal.get('video_subtype', 'animation') if 'asset_proposal' in locals() else 'demo'
|
||||
|
||||
# For demo videos (text-to-video), we need product description
|
||||
if video_subtype == "demo" or not product_context or not product_context.get('product_image_base64'):
|
||||
# Text-to-video demo video
|
||||
video_type = "demo" # Default, can be customized
|
||||
if asset_node.channel in ["tiktok", "instagram"]:
|
||||
video_type = "storytelling" # Storytelling for social media
|
||||
elif asset_node.channel in ["linkedin", "youtube"]:
|
||||
video_type = "feature_highlight" # Feature highlights for professional
|
||||
|
||||
# Estimate cost for text-to-video (WAN 2.5: $0.05-$0.15/second)
|
||||
duration = 10 # Default 10s for demo videos
|
||||
resolution = "720p" # Default
|
||||
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
|
||||
cost_estimate = duration * cost_per_second
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"video_subtype": "demo", # Text-to-video
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id,
|
||||
"video_type": video_type,
|
||||
"duration": duration,
|
||||
"resolution": resolution,
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": f"Product {video_type} video optimized for {asset_node.channel}",
|
||||
"note": "Text-to-video demo - requires product description",
|
||||
}
|
||||
else:
|
||||
# Image-to-video animation
|
||||
animation_type = "reveal" # Default
|
||||
if asset_node.channel in ["tiktok", "instagram", "youtube"]:
|
||||
animation_type = "demo" # Demo animations for social media
|
||||
elif asset_node.channel in ["linkedin", "facebook"]:
|
||||
animation_type = "reveal" # Professional reveal for B2B
|
||||
|
||||
# Estimate cost for image-to-video (WAN 2.5: $0.05-$0.15/second)
|
||||
duration = 5 # Default 5s for animations
|
||||
resolution = "720p" # Default
|
||||
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
|
||||
cost_estimate = duration * cost_per_second
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"video_subtype": "animation", # Image-to-video
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id,
|
||||
"animation_type": animation_type,
|
||||
"duration": duration,
|
||||
"resolution": resolution,
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": f"Product {animation_type} animation optimized for {asset_node.channel}",
|
||||
"note": "Requires product image - will be provided during generation",
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "text":
|
||||
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
|
||||
@@ -184,6 +246,7 @@ class ProductMarketingOrchestrator:
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"cost_estimate": 0.0, # Text generation cost is minimal
|
||||
"concept_summary": "Marketing copy optimized for channel and persona",
|
||||
@@ -242,6 +305,124 @@ class ProductMarketingOrchestrator:
|
||||
],
|
||||
}
|
||||
|
||||
elif asset_type == "video":
|
||||
# Check video subtype: "animation" (image-to-video) or "demo" (text-to-video)
|
||||
video_subtype = asset_proposal.get('video_subtype', 'animation')
|
||||
|
||||
if video_subtype == "demo":
|
||||
# Text-to-video: Product demo video from description
|
||||
from .product_video_service import ProductVideoService, ProductVideoRequest
|
||||
|
||||
# Get product info from context
|
||||
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
|
||||
product_description = product_context.get('product_description', '') if product_context else ''
|
||||
|
||||
if not product_description:
|
||||
raise ValueError("Product description required for text-to-video demo generation")
|
||||
|
||||
# Get brand context
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
brand_context = {
|
||||
"visual_identity": brand_dna.get("visual_identity", {}),
|
||||
"persona": brand_dna.get("persona", {}),
|
||||
}
|
||||
|
||||
# Get video type from proposal or default
|
||||
video_type = asset_proposal.get('video_type', 'demo')
|
||||
|
||||
# Create video service
|
||||
video_service = ProductVideoService()
|
||||
|
||||
# Create video request
|
||||
video_request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type=video_type,
|
||||
resolution=asset_proposal.get('resolution', '720p'),
|
||||
duration=asset_proposal.get('duration', 10),
|
||||
audio_base64=asset_proposal.get('audio_base64'),
|
||||
brand_context=brand_context,
|
||||
additional_context=asset_proposal.get('additional_context'),
|
||||
)
|
||||
|
||||
# Generate video using unified ai_video_generate()
|
||||
result = await video_service.generate_product_video(video_request, user_id)
|
||||
|
||||
# Extract campaign_id for metadata
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "video",
|
||||
"video_subtype": "demo",
|
||||
"video_url": result.get('file_url'),
|
||||
"video_filename": result.get('filename'),
|
||||
"cost": result.get('cost', 0.0),
|
||||
"video_type": video_type,
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
}
|
||||
|
||||
else:
|
||||
# Image-to-video: Product animation
|
||||
from .product_animation_service import ProductAnimationService, ProductAnimationRequest
|
||||
|
||||
# Get product image from proposal or product context
|
||||
product_image_base64 = asset_proposal.get('product_image_base64')
|
||||
if not product_image_base64 and product_context:
|
||||
product_image_base64 = product_context.get('product_image_base64')
|
||||
|
||||
if not product_image_base64:
|
||||
raise ValueError("Product image required for image-to-video animation generation")
|
||||
|
||||
# Get animation type from proposal or default to "reveal"
|
||||
animation_type = asset_proposal.get('animation_type', 'reveal')
|
||||
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
|
||||
product_description = product_context.get('product_description') if product_context else None
|
||||
|
||||
# Get brand context
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
brand_context = {
|
||||
"visual_identity": brand_dna.get("visual_identity", {}),
|
||||
"persona": brand_dna.get("persona", {}),
|
||||
}
|
||||
|
||||
# Create animation service
|
||||
animation_service = ProductAnimationService()
|
||||
|
||||
# Create animation request
|
||||
animation_request = ProductAnimationRequest(
|
||||
product_image_base64=product_image_base64,
|
||||
animation_type=animation_type,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
resolution=asset_proposal.get('resolution', '720p'),
|
||||
duration=asset_proposal.get('duration', 5),
|
||||
audio_base64=asset_proposal.get('audio_base64'),
|
||||
brand_context=brand_context,
|
||||
additional_context=asset_proposal.get('additional_context'),
|
||||
)
|
||||
|
||||
# Generate video
|
||||
result = await animation_service.animate_product(animation_request, user_id)
|
||||
|
||||
# Extract campaign_id for metadata
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "video",
|
||||
"video_subtype": "animation",
|
||||
"video_url": result.get('video_url'),
|
||||
"video_filename": result.get('filename'),
|
||||
"cost": result.get('cost', 0.0),
|
||||
"animation_type": animation_type,
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
}
|
||||
|
||||
elif asset_type == "text":
|
||||
# Import text generation service and tracker
|
||||
import asyncio
|
||||
@@ -457,6 +638,10 @@ Return only the final copy text without explanations or markdown formatting."""
|
||||
if asset_type == "image":
|
||||
# Premium quality image: ~5-6 credits
|
||||
return 5.0
|
||||
elif asset_type == "video":
|
||||
# WAN 2.5 Image-to-Video: $0.05-$0.15/second
|
||||
# Default: 5 seconds at 720p = $0.50
|
||||
return 0.50
|
||||
elif asset_type == "text":
|
||||
return 0.0 # Text generation is typically included
|
||||
else:
|
||||
|
||||
221
backend/services/product_marketing/product_animation_service.py
Normal file
221
backend/services/product_marketing/product_animation_service.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Product Animation Service
|
||||
Handles product animation workflows using Transform Studio (WAN 2.5 Image-to-Video).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.image_studio.transform_service import TransformStudioService, TransformImageToVideoRequest
|
||||
from services.image_studio.studio_manager import ImageStudioManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("product_marketing.animation")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductAnimationRequest:
|
||||
"""Request for product animation."""
|
||||
product_image_base64: str
|
||||
animation_type: str # "reveal", "rotation", "demo", "lifestyle"
|
||||
product_name: str
|
||||
product_description: Optional[str] = None
|
||||
resolution: str = "720p" # 480p, 720p, 1080p
|
||||
duration: int = 5 # 5 or 10 seconds
|
||||
audio_base64: Optional[str] = None
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
additional_context: Optional[str] = None
|
||||
|
||||
|
||||
class ProductAnimationService:
|
||||
"""Service for product animation workflows."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Animation Service."""
|
||||
self.transform_service = TransformStudioService()
|
||||
self.image_studio = ImageStudioManager()
|
||||
logger.info("[Product Animation Service] Initialized")
|
||||
|
||||
def _build_animation_prompt(
|
||||
self,
|
||||
animation_type: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
brand_context: Optional[Dict[str, Any]],
|
||||
additional_context: Optional[str]
|
||||
) -> str:
|
||||
"""
|
||||
Build animation prompt based on animation type and product context.
|
||||
|
||||
Args:
|
||||
animation_type: Type of animation (reveal, rotation, demo, lifestyle)
|
||||
product_name: Product name
|
||||
product_description: Product description
|
||||
brand_context: Brand DNA context
|
||||
additional_context: Additional context
|
||||
|
||||
Returns:
|
||||
Animation prompt
|
||||
"""
|
||||
base_prompt = f"{product_name}"
|
||||
if product_description:
|
||||
base_prompt += f": {product_description}"
|
||||
|
||||
# Animation-specific prompts
|
||||
animation_prompts = {
|
||||
"reveal": f"{base_prompt} elegantly revealing, smooth camera movement, professional product showcase, cinematic lighting",
|
||||
"rotation": f"{base_prompt} slowly rotating 360 degrees, smooth rotation, professional product photography, studio lighting, clean background",
|
||||
"demo": f"{base_prompt} in use, demonstrating features, dynamic movement, engaging presentation, professional product demo",
|
||||
"lifestyle": f"{base_prompt} in realistic lifestyle setting, natural environment, authentic use case, relatable scenario",
|
||||
}
|
||||
|
||||
prompt = animation_prompts.get(animation_type, base_prompt)
|
||||
|
||||
# Add brand context if available
|
||||
if brand_context:
|
||||
visual_identity = brand_context.get("visual_identity", {})
|
||||
if visual_identity.get("color_palette"):
|
||||
colors = ", ".join(visual_identity["color_palette"][:3]) # First 3 colors
|
||||
prompt += f", {colors} color scheme"
|
||||
|
||||
if visual_identity.get("style_guidelines"):
|
||||
style = visual_identity["style_guidelines"].get("aesthetic", "")
|
||||
if style:
|
||||
prompt += f", {style} style"
|
||||
|
||||
# Add additional context
|
||||
if additional_context:
|
||||
prompt += f", {additional_context}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def animate_product(
|
||||
self,
|
||||
request: ProductAnimationRequest,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a product image into a video.
|
||||
|
||||
Args:
|
||||
request: Product animation request
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Animation result with video URL and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[Product Animation] Animating product '{request.product_name}' "
|
||||
f"with type '{request.animation_type}' for user {user_id}"
|
||||
)
|
||||
|
||||
# Build animation prompt
|
||||
animation_prompt = self._build_animation_prompt(
|
||||
animation_type=request.animation_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
# Create transform request
|
||||
transform_request = TransformImageToVideoRequest(
|
||||
image_base64=request.product_image_base64,
|
||||
prompt=animation_prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
enable_prompt_expansion=True, # Expand prompt for better results
|
||||
)
|
||||
|
||||
# Generate video using Transform Studio
|
||||
result = await self.transform_service.transform_image_to_video(
|
||||
request=transform_request,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Add product-specific metadata
|
||||
result["product_name"] = request.product_name
|
||||
result["animation_type"] = request.animation_type
|
||||
result["source_module"] = "product_marketing"
|
||||
|
||||
logger.info(
|
||||
f"[Product Animation] ✅ Product animation completed: "
|
||||
f"cost=${result.get('cost', 0):.2f}, video_url={result.get('video_url', 'N/A')}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Animation] ❌ Error animating product: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_product_reveal(
|
||||
self,
|
||||
product_image_base64: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product reveal animation."""
|
||||
request = ProductAnimationRequest(
|
||||
product_image_base64=product_image_base64,
|
||||
animation_type="reveal",
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.animate_product(request, user_id)
|
||||
|
||||
async def create_product_rotation(
|
||||
self,
|
||||
product_image_base64: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 10, # Longer for full rotation
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create 360° product rotation animation."""
|
||||
request = ProductAnimationRequest(
|
||||
product_image_base64=product_image_base64,
|
||||
animation_type="rotation",
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.animate_product(request, user_id)
|
||||
|
||||
async def create_product_demo(
|
||||
self,
|
||||
product_image_base64: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 10,
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product demo video."""
|
||||
request = ProductAnimationRequest(
|
||||
product_image_base64=product_image_base64,
|
||||
animation_type="demo",
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.animate_product(request, user_id)
|
||||
380
backend/services/product_marketing/product_avatar_service.py
Normal file
380
backend/services/product_marketing/product_avatar_service.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Product Avatar Service
|
||||
Handles product explainer video generation using InfiniteTalk (talking avatars).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import os
|
||||
import base64
|
||||
|
||||
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("product_marketing.avatar")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductAvatarRequest:
|
||||
"""Request for product explainer video with talking avatar."""
|
||||
avatar_image_base64: str # Product image, brand spokesperson, or brand mascot
|
||||
script_text: Optional[str] = None # Text script to convert to audio
|
||||
audio_base64: Optional[str] = None # Pre-generated audio (alternative to script_text)
|
||||
product_name: str = "Product"
|
||||
product_description: Optional[str] = None
|
||||
explainer_type: str = "product_overview" # product_overview, feature_explainer, tutorial, brand_message
|
||||
resolution: str = "720p" # 480p or 720p
|
||||
prompt: Optional[str] = None # Optional prompt for expression/style
|
||||
mask_image_base64: Optional[str] = None # Optional mask for animatable regions
|
||||
seed: Optional[int] = None
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
additional_context: Optional[str] = None
|
||||
|
||||
|
||||
class ProductAvatarService:
|
||||
"""Service for product explainer video generation using InfiniteTalk."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Avatar Service."""
|
||||
self.infinitetalk_service = InfiniteTalkService()
|
||||
self.audio_service = StoryAudioGenerationService()
|
||||
logger.info("[Product Avatar Service] Initialized")
|
||||
|
||||
def _build_avatar_prompt(
|
||||
self,
|
||||
explainer_type: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
brand_context: Optional[Dict[str, Any]],
|
||||
additional_context: Optional[str]
|
||||
) -> str:
|
||||
"""
|
||||
Build avatar prompt based on explainer type and product context.
|
||||
|
||||
Args:
|
||||
explainer_type: Type of explainer (product_overview, feature_explainer, tutorial, brand_message)
|
||||
product_name: Product name
|
||||
product_description: Product description
|
||||
brand_context: Brand DNA context
|
||||
additional_context: Additional context
|
||||
|
||||
Returns:
|
||||
Avatar animation prompt
|
||||
"""
|
||||
base_description = f"{product_name}"
|
||||
if product_description:
|
||||
base_description += f": {product_description}"
|
||||
|
||||
# Explainer type-specific prompts
|
||||
explainer_prompts = {
|
||||
"product_overview": (
|
||||
f"Professional product presentation of {base_description}, "
|
||||
f"engaging and informative, clear communication, confident expression, "
|
||||
f"professional setting, modern and clean aesthetic"
|
||||
),
|
||||
"feature_explainer": (
|
||||
f"Demonstrating features of {base_description}, "
|
||||
f"detailed explanation, pointing gestures, clear visual communication, "
|
||||
f"educational and informative, professional presentation"
|
||||
),
|
||||
"tutorial": (
|
||||
f"Tutorial presentation for {base_description}, "
|
||||
f"step-by-step explanation, instructional and clear, "
|
||||
f"friendly and approachable, educational setting"
|
||||
),
|
||||
"brand_message": (
|
||||
f"Brand message delivery for {base_description}, "
|
||||
f"authentic and compelling, brand storytelling, "
|
||||
f"emotional connection, professional brand representation"
|
||||
),
|
||||
}
|
||||
|
||||
prompt = explainer_prompts.get(explainer_type, base_description)
|
||||
|
||||
# Add brand context if available
|
||||
if brand_context:
|
||||
visual_identity = brand_context.get("visual_identity", {})
|
||||
if visual_identity.get("style_guidelines"):
|
||||
style = visual_identity["style_guidelines"].get("aesthetic", "")
|
||||
if style:
|
||||
prompt += f", {style} style"
|
||||
|
||||
# Add brand values if available
|
||||
if visual_identity.get("brand_values"):
|
||||
values = ", ".join(visual_identity["brand_values"][:2]) # First 2 values
|
||||
prompt += f", embodying {values}"
|
||||
|
||||
# Add additional context
|
||||
if additional_context:
|
||||
prompt += f", {additional_context}"
|
||||
|
||||
return prompt
|
||||
|
||||
def _generate_audio_from_script(
|
||||
self,
|
||||
script_text: str,
|
||||
user_id: str,
|
||||
output_dir: Path
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate audio from script text using TTS.
|
||||
|
||||
Args:
|
||||
script_text: Text to convert to speech
|
||||
user_id: User ID for tracking
|
||||
output_dir: Directory to save temporary audio file
|
||||
|
||||
Returns:
|
||||
Audio bytes
|
||||
"""
|
||||
try:
|
||||
# Create temporary audio file
|
||||
audio_filename = f"avatar_audio_{uuid.uuid4().hex[:8]}.mp3"
|
||||
audio_path = output_dir / audio_filename
|
||||
|
||||
# Generate audio using gTTS (free, always available)
|
||||
# Note: For premium voices, we could integrate Minimax voice clone here
|
||||
success = self.audio_service._generate_audio_gtts(
|
||||
text=script_text,
|
||||
output_path=audio_path,
|
||||
lang="en",
|
||||
slow=False
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("Failed to generate audio from script")
|
||||
|
||||
# Read audio bytes
|
||||
with open(audio_path, 'rb') as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.remove(audio_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"[Product Avatar] Generated audio from script: {len(audio_bytes)} bytes")
|
||||
return audio_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Avatar] Error generating audio: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def generate_product_explainer(
|
||||
self,
|
||||
request: ProductAvatarRequest,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate product explainer video using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
request: Product avatar request
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Explainer video result with video URL and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[Product Avatar] Generating {request.explainer_type} explainer for product '{request.product_name}' "
|
||||
f"for user {user_id}"
|
||||
)
|
||||
|
||||
# Prepare audio
|
||||
audio_base64 = request.audio_base64
|
||||
if not audio_base64 and request.script_text:
|
||||
# Generate audio from script
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
temp_dir = base_dir / "temp_audio"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
audio_bytes = self._generate_audio_from_script(
|
||||
script_text=request.script_text,
|
||||
user_id=user_id,
|
||||
output_dir=temp_dir
|
||||
)
|
||||
|
||||
# Convert to base64
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
|
||||
audio_base64 = f"data:audio/mpeg;base64,{audio_base64}"
|
||||
|
||||
if not audio_base64:
|
||||
raise ValueError("Either audio_base64 or script_text must be provided")
|
||||
|
||||
# Build avatar prompt
|
||||
avatar_prompt = request.prompt
|
||||
if not avatar_prompt:
|
||||
avatar_prompt = self._build_avatar_prompt(
|
||||
explainer_type=request.explainer_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
# Generate video using InfiniteTalk
|
||||
result = await self.infinitetalk_service.create_talking_avatar(
|
||||
image_base64=request.avatar_image_base64,
|
||||
audio_base64=audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=avatar_prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Extract video bytes and save to user directory
|
||||
video_bytes = result.get("video_bytes")
|
||||
if not video_bytes:
|
||||
raise ValueError("Avatar generation returned no video bytes")
|
||||
|
||||
# Save video file
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
output_dir = base_dir / "product_avatars"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create user-specific directory
|
||||
user_dir = output_dir / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
safe_product_name = "".join(c for c in request.product_name if c.isalnum() or c in (' ', '-', '_')).strip()[:30]
|
||||
filename = f"explainer_{safe_product_name}_{request.explainer_type}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filename = filename.replace(" ", "_").replace("/", "_").replace("\\", "_")
|
||||
|
||||
# Save file
|
||||
file_path = user_dir / filename
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(video_bytes)
|
||||
|
||||
# Check file size (500MB max)
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > 500 * 1024 * 1024:
|
||||
os.remove(file_path)
|
||||
raise RuntimeError(f"Video file too large: {file_size / (1024*1024):.2f}MB (max 500MB)")
|
||||
|
||||
file_url = f"/api/product-marketing/avatars/{user_id}/{filename}"
|
||||
|
||||
# Add product-specific metadata
|
||||
result["product_name"] = request.product_name
|
||||
result["explainer_type"] = request.explainer_type
|
||||
result["source_module"] = "product_marketing"
|
||||
result["filename"] = filename
|
||||
result["file_path"] = str(file_path)
|
||||
result["file_url"] = file_url
|
||||
result["file_size"] = file_size
|
||||
result["duration"] = result.get("duration", 0.0)
|
||||
|
||||
logger.info(
|
||||
f"[Product Avatar] ✅ Product explainer video generated successfully: "
|
||||
f"cost=${result.get('cost', 0):.2f}, duration={result.get('duration', 0):.1f}s, "
|
||||
f"video_url={file_url}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Avatar] ❌ Error generating product explainer: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_product_overview(
|
||||
self,
|
||||
avatar_image_base64: str,
|
||||
script_text: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product overview explainer video."""
|
||||
request = ProductAvatarRequest(
|
||||
avatar_image_base64=avatar_image_base64,
|
||||
script_text=script_text,
|
||||
audio_base64=audio_base64,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
explainer_type="product_overview",
|
||||
resolution=resolution,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_explainer(request, user_id)
|
||||
|
||||
async def create_feature_explainer(
|
||||
self,
|
||||
avatar_image_base64: str,
|
||||
script_text: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product feature explainer video."""
|
||||
request = ProductAvatarRequest(
|
||||
avatar_image_base64=avatar_image_base64,
|
||||
script_text=script_text,
|
||||
audio_base64=audio_base64,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
explainer_type="feature_explainer",
|
||||
resolution=resolution,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_explainer(request, user_id)
|
||||
|
||||
async def create_tutorial(
|
||||
self,
|
||||
avatar_image_base64: str,
|
||||
script_text: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product tutorial video."""
|
||||
request = ProductAvatarRequest(
|
||||
avatar_image_base64=avatar_image_base64,
|
||||
script_text=script_text,
|
||||
audio_base64=audio_base64,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
explainer_type="tutorial",
|
||||
resolution=resolution,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_explainer(request, user_id)
|
||||
|
||||
async def create_brand_message(
|
||||
self,
|
||||
avatar_image_base64: str,
|
||||
script_text: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create brand message video."""
|
||||
request = ProductAvatarRequest(
|
||||
avatar_image_base64=avatar_image_base64,
|
||||
script_text=script_text,
|
||||
audio_base64=audio_base64,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
explainer_type="brand_message",
|
||||
resolution=resolution,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_explainer(request, user_id)
|
||||
312
backend/services/product_marketing/product_video_service.py
Normal file
312
backend/services/product_marketing/product_video_service.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Product Video Service
|
||||
Handles product demo video generation using WAN 2.5 Text-to-Video via main_video_generation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("product_marketing.video")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductVideoRequest:
|
||||
"""Request for product demo video generation."""
|
||||
product_name: str
|
||||
product_description: str
|
||||
video_type: str # "demo", "storytelling", "feature_highlight", "launch"
|
||||
resolution: str = "720p" # 480p, 720p, 1080p
|
||||
duration: int = 10 # 5 or 10 seconds
|
||||
audio_base64: Optional[str] = None
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
additional_context: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class ProductVideoService:
|
||||
"""Service for product demo video generation using WAN 2.5 Text-to-Video."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Video Service."""
|
||||
logger.info("[Product Video Service] Initialized")
|
||||
|
||||
def _build_video_prompt(
|
||||
self,
|
||||
video_type: str,
|
||||
product_name: str,
|
||||
product_description: str,
|
||||
brand_context: Optional[Dict[str, Any]],
|
||||
additional_context: Optional[str]
|
||||
) -> str:
|
||||
"""
|
||||
Build video prompt based on video type and product context.
|
||||
|
||||
Args:
|
||||
video_type: Type of video (demo, storytelling, feature_highlight, launch)
|
||||
product_name: Product name
|
||||
product_description: Product description
|
||||
brand_context: Brand DNA context
|
||||
additional_context: Additional context
|
||||
|
||||
Returns:
|
||||
Video generation prompt
|
||||
"""
|
||||
base_description = f"{product_name}"
|
||||
if product_description:
|
||||
base_description += f": {product_description}"
|
||||
|
||||
# Video type-specific prompts
|
||||
video_prompts = {
|
||||
"demo": (
|
||||
f"{base_description} being demonstrated in use, showcasing key features and benefits, "
|
||||
f"professional product demonstration, dynamic camera movement, engaging presentation, "
|
||||
f"clear product visibility, modern and clean aesthetic"
|
||||
),
|
||||
"storytelling": (
|
||||
f"Story of {base_description}, narrative-driven product showcase, emotional connection, "
|
||||
f"cinematic storytelling, compelling visual narrative, professional cinematography, "
|
||||
f"engaging product story"
|
||||
),
|
||||
"feature_highlight": (
|
||||
f"{base_description} highlighting key features, close-up shots of important details, "
|
||||
f"feature-focused presentation, professional product photography, clear feature visibility, "
|
||||
f"modern and sleek aesthetic"
|
||||
),
|
||||
"launch": (
|
||||
f"{base_description} product launch reveal, exciting unveiling, dynamic presentation, "
|
||||
f"professional product showcase, launch event aesthetic, engaging and energetic, "
|
||||
f"modern and premium feel"
|
||||
),
|
||||
}
|
||||
|
||||
prompt = video_prompts.get(video_type, base_description)
|
||||
|
||||
# Add brand context if available
|
||||
if brand_context:
|
||||
visual_identity = brand_context.get("visual_identity", {})
|
||||
if visual_identity.get("color_palette"):
|
||||
colors = ", ".join(visual_identity["color_palette"][:3]) # First 3 colors
|
||||
prompt += f", {colors} color scheme"
|
||||
|
||||
if visual_identity.get("style_guidelines"):
|
||||
style = visual_identity["style_guidelines"].get("aesthetic", "")
|
||||
if style:
|
||||
prompt += f", {style} style"
|
||||
|
||||
# Add brand values if available
|
||||
if visual_identity.get("brand_values"):
|
||||
values = ", ".join(visual_identity["brand_values"][:2]) # First 2 values
|
||||
prompt += f", embodying {values}"
|
||||
|
||||
# Add additional context
|
||||
if additional_context:
|
||||
prompt += f", {additional_context}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def generate_product_video(
|
||||
self,
|
||||
request: ProductVideoRequest,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate product demo video using WAN 2.5 Text-to-Video.
|
||||
|
||||
This method uses the unified ai_video_generate() entry point which handles:
|
||||
- Pre-flight validation
|
||||
- Usage tracking
|
||||
- Cost tracking
|
||||
- Error handling
|
||||
|
||||
Args:
|
||||
request: Product video request
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Video generation result with video URL and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"[Product Video] Generating {request.video_type} video for product '{request.product_name}' "
|
||||
f"for user {user_id}"
|
||||
)
|
||||
|
||||
# Build video prompt
|
||||
video_prompt = self._build_video_prompt(
|
||||
video_type=request.video_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
# Build negative prompt (default to avoid common issues)
|
||||
negative_prompt = request.negative_prompt or (
|
||||
"blurry, low quality, distorted, deformed, ugly, bad anatomy, "
|
||||
"watermark, text overlay, logo, signature"
|
||||
)
|
||||
|
||||
# Generate video using unified entry point
|
||||
# This handles pre-flight validation, usage tracking, and cost tracking automatically
|
||||
result = await ai_video_generate(
|
||||
prompt=video_prompt,
|
||||
operation_type="text-to-video",
|
||||
provider="wavespeed",
|
||||
user_id=user_id,
|
||||
model="alibaba/wan-2.5/text-to-video", # WAN 2.5 Text-to-Video
|
||||
duration=request.duration,
|
||||
resolution=request.resolution,
|
||||
audio_base64=request.audio_base64,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=True, # Enable prompt optimization
|
||||
)
|
||||
|
||||
# Extract video bytes and save to user directory
|
||||
video_bytes = result.get("video_bytes")
|
||||
if not video_bytes:
|
||||
raise ValueError("Video generation returned no video bytes")
|
||||
|
||||
# Save video file (similar to Transform Studio)
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import os
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
output_dir = base_dir / "product_videos"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create user-specific directory
|
||||
user_dir = output_dir / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename (sanitize to avoid issues)
|
||||
safe_product_name = "".join(c for c in request.product_name if c.isalnum() or c in (' ', '-', '_')).strip()[:30]
|
||||
filename = f"product_{safe_product_name}_{request.video_type}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filename = filename.replace(" ", "_").replace("/", "_").replace("\\", "_")
|
||||
|
||||
# Save file
|
||||
file_path = user_dir / filename
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(video_bytes)
|
||||
|
||||
# Check file size (500MB max)
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > 500 * 1024 * 1024:
|
||||
os.remove(file_path)
|
||||
raise RuntimeError(f"Video file too large: {file_size / (1024*1024):.2f}MB (max 500MB)")
|
||||
|
||||
file_url = f"/api/product-marketing/videos/{user_id}/{filename}"
|
||||
|
||||
# Add product-specific metadata
|
||||
result["product_name"] = request.product_name
|
||||
result["video_type"] = request.video_type
|
||||
result["source_module"] = "product_marketing"
|
||||
result["filename"] = filename
|
||||
result["file_path"] = str(file_path)
|
||||
result["file_url"] = file_url
|
||||
result["file_size"] = len(video_bytes)
|
||||
|
||||
logger.info(
|
||||
f"[Product Video] ✅ Product video generated successfully: "
|
||||
f"cost=${result.get('cost', 0):.2f}, video_url={file_url}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Video] ❌ Error generating product video: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_product_demo(
|
||||
self,
|
||||
product_name: str,
|
||||
product_description: str,
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 10,
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product demo video (product in use, demonstrating features)."""
|
||||
request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type="demo",
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_video(request, user_id)
|
||||
|
||||
async def create_product_storytelling(
|
||||
self,
|
||||
product_name: str,
|
||||
product_description: str,
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 10,
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product storytelling video (narrative-driven product showcase)."""
|
||||
request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type="storytelling",
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_video(request, user_id)
|
||||
|
||||
async def create_product_feature_highlight(
|
||||
self,
|
||||
product_name: str,
|
||||
product_description: str,
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 10,
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product feature highlight video (close-up shots of key features)."""
|
||||
request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type="feature_highlight",
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_video(request, user_id)
|
||||
|
||||
async def create_product_launch(
|
||||
self,
|
||||
product_name: str,
|
||||
product_description: str,
|
||||
user_id: str,
|
||||
resolution: str = "1080p", # Higher quality for launch
|
||||
duration: int = 10,
|
||||
audio_base64: Optional[str] = None,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create product launch video (exciting unveiling, launch event aesthetic)."""
|
||||
request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type="launch",
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
brand_context=brand_context
|
||||
)
|
||||
return await self.generate_product_video(request, user_id)
|
||||
@@ -50,6 +50,7 @@ class IntentAwareAnalyzer:
|
||||
raw_results: Dict[str, Any],
|
||||
intent: ResearchIntent,
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> IntentDrivenResearchResult:
|
||||
"""
|
||||
Analyze raw research results based on user intent.
|
||||
@@ -84,7 +85,7 @@ class IntentAwareAnalyzer:
|
||||
result = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=analysis_schema,
|
||||
user_id=None
|
||||
user_id=user_id # Required for subscription checking
|
||||
)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
|
||||
@@ -151,6 +151,8 @@ Analyze the user's input and infer their research intent. Determine:
|
||||
|
||||
11. **CONFIDENCE**: How confident are you in this inference? (0.0-1.0)
|
||||
- If < 0.7, set needs_clarification to true and provide clarifying_questions
|
||||
- Provide a brief reason for your confidence level
|
||||
- If confidence is low, provide an example of what a great input would look like
|
||||
|
||||
## OUTPUT FORMAT
|
||||
|
||||
@@ -168,6 +170,8 @@ Return a JSON object:
|
||||
"perspective": "target perspective or null",
|
||||
"time_sensitivity": "real_time|recent|historical|evergreen",
|
||||
"confidence": 0.85,
|
||||
"confidence_reason": "Brief explanation of why this confidence level (e.g., 'User provided clear keywords and context' or 'Input is vague, missing specific goals')",
|
||||
"great_example": "Example of what a great input would look like for this research (only if confidence < 0.8)",
|
||||
"needs_clarification": false,
|
||||
"clarifying_questions": [],
|
||||
"analysis_summary": "Brief summary of what the user wants"
|
||||
|
||||
@@ -39,6 +39,7 @@ class IntentQueryGenerator:
|
||||
self,
|
||||
intent: ResearchIntent,
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate targeted research queries based on intent.
|
||||
@@ -89,7 +90,7 @@ class IntentQueryGenerator:
|
||||
result = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=query_schema,
|
||||
user_id=None
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
|
||||
@@ -51,6 +51,7 @@ class ResearchIntentInference:
|
||||
competitor_data: Optional[List[Dict]] = None,
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> IntentInferenceResponse:
|
||||
"""
|
||||
Analyze user input and infer their research intent.
|
||||
@@ -96,13 +97,15 @@ class ResearchIntentInference:
|
||||
"perspective": {"type": "string"},
|
||||
"time_sensitivity": {"type": "string"},
|
||||
"confidence": {"type": "number"},
|
||||
"confidence_reason": {"type": "string"},
|
||||
"great_example": {"type": "string"},
|
||||
"needs_clarification": {"type": "boolean"},
|
||||
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"analysis_summary": {"type": "string"}
|
||||
},
|
||||
"required": [
|
||||
"input_type", "primary_question", "purpose", "content_output",
|
||||
"expected_deliverables", "depth", "confidence", "analysis_summary"
|
||||
"expected_deliverables", "depth", "confidence", "confidence_reason", "analysis_summary"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -112,7 +115,7 @@ class ResearchIntentInference:
|
||||
result = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=intent_schema,
|
||||
user_id=None
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
@@ -134,6 +137,8 @@ class ResearchIntentInference:
|
||||
suggested_keywords=self._extract_keywords_from_input(user_input, keywords),
|
||||
suggested_angles=result.get("focus_areas", []),
|
||||
quick_options=quick_options,
|
||||
confidence_reason=result.get("confidence_reason", ""),
|
||||
great_example=result.get("great_example", ""),
|
||||
)
|
||||
|
||||
logger.info(f"Intent inferred: purpose={intent.purpose}, confidence={intent.confidence}")
|
||||
@@ -166,7 +171,7 @@ class ResearchIntentInference:
|
||||
if not expected_deliverables:
|
||||
expected_deliverables = self._infer_deliverables_from_purpose(purpose)
|
||||
|
||||
return ResearchIntent(
|
||||
intent = ResearchIntent(
|
||||
primary_question=result.get("primary_question", user_input),
|
||||
secondary_questions=result.get("secondary_questions", []),
|
||||
purpose=purpose.value,
|
||||
@@ -179,9 +184,13 @@ class ResearchIntentInference:
|
||||
input_type=input_type.value,
|
||||
original_input=user_input,
|
||||
confidence=float(result.get("confidence", 0.7)),
|
||||
confidence_reason=result.get("confidence_reason"),
|
||||
great_example=result.get("great_example"),
|
||||
needs_clarification=result.get("needs_clarification", False),
|
||||
clarifying_questions=result.get("clarifying_questions", []),
|
||||
)
|
||||
|
||||
return intent
|
||||
|
||||
def _safe_enum(self, enum_class, value: str, default):
|
||||
"""Safely convert string to enum, returning default if invalid."""
|
||||
|
||||
559
backend/services/research/intent/unified_research_analyzer.py
Normal file
559
backend/services/research/intent/unified_research_analyzer.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
Unified Research Analyzer
|
||||
|
||||
Combines intent inference, query generation, and parameter optimization
|
||||
into a single AI call with justifications for each decision.
|
||||
|
||||
This reduces 2 LLM calls to 1, improves coherence, and provides
|
||||
user-friendly justifications for all settings.
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from loguru import logger
|
||||
|
||||
from models.research_intent_models import (
|
||||
ResearchIntent,
|
||||
ResearchQuery,
|
||||
IntentInferenceResponse,
|
||||
ResearchPurpose,
|
||||
ContentOutput,
|
||||
ExpectedDeliverable,
|
||||
ResearchDepthLevel,
|
||||
InputType,
|
||||
)
|
||||
from models.research_persona_models import ResearchPersona
|
||||
|
||||
|
||||
class UnifiedResearchAnalyzer:
|
||||
"""
|
||||
Unified AI-driven analyzer that performs:
|
||||
1. Intent inference (what user wants)
|
||||
2. Query generation (how to search)
|
||||
3. Parameter optimization (Exa/Tavily settings)
|
||||
|
||||
All in a single LLM call with justifications.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the unified analyzer."""
|
||||
logger.info("UnifiedResearchAnalyzer initialized")
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
user_input: str,
|
||||
keywords: Optional[List[str]] = None,
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
competitor_data: Optional[List[Dict]] = None,
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform unified analysis of user research request.
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- intent: ResearchIntent
|
||||
- queries: List[ResearchQuery]
|
||||
- exa_config: Dict with settings and justifications
|
||||
- tavily_config: Dict with settings and justifications
|
||||
- recommended_provider: str
|
||||
- provider_justification: str
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Unified analysis for: {user_input[:100]}...")
|
||||
|
||||
keywords = keywords or []
|
||||
|
||||
# Build the unified prompt
|
||||
prompt = self._build_unified_prompt(
|
||||
user_input=user_input,
|
||||
keywords=keywords,
|
||||
research_persona=research_persona,
|
||||
competitor_data=competitor_data,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
)
|
||||
|
||||
# Define the comprehensive JSON schema
|
||||
unified_schema = self._build_unified_schema()
|
||||
|
||||
# Call LLM (single call for everything)
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
result = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=unified_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
logger.error(f"Unified analysis failed: {result.get('error')}")
|
||||
return self._create_fallback_response(user_input, keywords)
|
||||
|
||||
# Parse the unified result
|
||||
return self._parse_unified_result(result, user_input)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unified analysis: {e}")
|
||||
return self._create_fallback_response(user_input, keywords or [])
|
||||
|
||||
def _build_unified_prompt(
|
||||
self,
|
||||
user_input: str,
|
||||
keywords: List[str],
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
competitor_data: Optional[List[Dict]] = None,
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build the unified prompt for intent + queries + parameters."""
|
||||
|
||||
# Build persona context
|
||||
persona_context = self._build_persona_context(research_persona, industry, target_audience)
|
||||
|
||||
# Build competitor context
|
||||
competitor_context = self._build_competitor_context(competitor_data)
|
||||
|
||||
prompt = f'''You are an expert AI research strategist. Analyze the user's research request and provide a complete research plan including intent understanding, search queries, and optimal API settings.
|
||||
|
||||
## USER INPUT
|
||||
"{user_input}"
|
||||
{f"KEYWORDS: {', '.join(keywords)}" if keywords else ""}
|
||||
|
||||
## USER CONTEXT
|
||||
{persona_context}
|
||||
{competitor_context}
|
||||
|
||||
## YOUR TASK: Provide a Complete Research Plan
|
||||
|
||||
### PART 1: INTENT ANALYSIS
|
||||
Understand what the user really wants from their research.
|
||||
|
||||
### PART 2: SEARCH QUERIES
|
||||
Generate 4-8 targeted search queries optimized for semantic search.
|
||||
|
||||
### PART 3: PROVIDER SETTINGS
|
||||
Configure Exa and Tavily API parameters with justifications.
|
||||
|
||||
### PART 4: GOOGLE TRENDS KEYWORDS (if trends in deliverables)
|
||||
If "trends" is in expected_deliverables OR purpose is "explore_trends":
|
||||
- Suggest 1-3 optimized keywords for Google Trends analysis
|
||||
- These may differ from research queries (trends need broader, searchable terms)
|
||||
- Consider: What keywords will show meaningful trends over time?
|
||||
- Consider: What timeframe will show relevant trends? (1 year, 12 months, etc.)
|
||||
- Consider: What geographic region is most relevant for the user?
|
||||
- Explain what insights trends will uncover for content generation:
|
||||
* Search interest trends over time (optimal publication timing)
|
||||
* Regional interest distribution (audience targeting)
|
||||
* Related topics for content expansion
|
||||
* Related queries for FAQ sections
|
||||
* Rising topics for timely content opportunities
|
||||
|
||||
---
|
||||
|
||||
## AVAILABLE PROVIDER OPTIONS
|
||||
|
||||
### EXA API OPTIONS (Semantic Search Engine)
|
||||
| Parameter | Options | Description |
|
||||
|-----------|---------|-------------|
|
||||
| type | "auto", "neural", "fast", "deep" | "neural" = semantic understanding, "deep" = comprehensive with query expansion |
|
||||
| category | "company", "research paper", "news", "github", "tweet", "personal site", "pdf", "financial report", "people" | Focus on specific content types |
|
||||
| numResults | 5-25 | Number of results (10 recommended) |
|
||||
| includeDomains | string[] | Domains to include (e.g., ["arxiv.org", "nature.com"]) |
|
||||
| excludeDomains | string[] | Domains to exclude |
|
||||
| startPublishedDate | ISO date | Filter by publish date (e.g., "2024-01-01T00:00:00.000Z") |
|
||||
| text | boolean | Include full text content |
|
||||
| highlights | boolean | Extract key highlights |
|
||||
| context | boolean | Return as single context string for RAG |
|
||||
|
||||
**WHEN TO USE EXA:**
|
||||
- Semantic understanding needed (finding similar content)
|
||||
- Academic/research papers
|
||||
- Company/competitor research
|
||||
- Deep, comprehensive results
|
||||
- Historical content
|
||||
|
||||
### TAVILY API OPTIONS (AI-Powered Search)
|
||||
| Parameter | Options | Description |
|
||||
|-----------|---------|-------------|
|
||||
| topic | "general", "news", "finance" | Search topic category |
|
||||
| search_depth | "basic", "advanced" | "advanced" = multiple semantic snippets per URL |
|
||||
| include_answer | false, true, "basic", "advanced" | AI-generated answer from results |
|
||||
| include_raw_content | false, true, "markdown", "text" | Raw page content format |
|
||||
| time_range | "day", "week", "month", "year" | Filter by recency |
|
||||
| max_results | 5-20 | Number of results |
|
||||
| include_domains | string[] | Domains to include |
|
||||
| exclude_domains | string[] | Domains to exclude |
|
||||
|
||||
**WHEN TO USE TAVILY:**
|
||||
- Real-time/current events
|
||||
- News and trending topics
|
||||
- Quick facts with AI answers
|
||||
- Financial data
|
||||
- Recent time-sensitive content
|
||||
|
||||
---
|
||||
|
||||
## OUTPUT FORMAT
|
||||
|
||||
Return a JSON object with this exact structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"intent": {{
|
||||
"input_type": "keywords|question|goal|mixed",
|
||||
"primary_question": "The main question to answer",
|
||||
"secondary_questions": ["question 1", "question 2"],
|
||||
"purpose": "learn|create_content|make_decision|compare|solve_problem|find_data|explore_trends|validate|generate_ideas",
|
||||
"content_output": "blog|podcast|video|social_post|newsletter|presentation|report|whitepaper|email|general",
|
||||
"expected_deliverables": ["key_statistics", "expert_quotes", "case_studies", "trends", "best_practices"],
|
||||
"depth": "overview|detailed|expert",
|
||||
"focus_areas": ["area1", "area2"],
|
||||
"perspective": "target perspective or null",
|
||||
"time_sensitivity": "real_time|recent|historical|evergreen",
|
||||
"confidence": 0.85,
|
||||
"confidence_reason": "Why this confidence level",
|
||||
"great_example": "Example of better input if confidence < 0.8",
|
||||
"needs_clarification": false,
|
||||
"clarifying_questions": [],
|
||||
"analysis_summary": "Brief summary of research plan"
|
||||
}},
|
||||
"queries": [
|
||||
{{
|
||||
"query": "Optimized search query string",
|
||||
"purpose": "key_statistics|expert_quotes|case_studies|trends|etc",
|
||||
"provider": "exa|tavily",
|
||||
"priority": 5,
|
||||
"expected_results": "What we expect to find",
|
||||
"justification": "Why this query and provider"
|
||||
}}
|
||||
],
|
||||
"enhanced_keywords": ["expanded", "related", "keywords"],
|
||||
"research_angles": ["Angle 1: ...", "Angle 2: ..."],
|
||||
"recommended_provider": "exa|tavily",
|
||||
"provider_justification": "Why this provider is best for this research",
|
||||
"exa_config": {{
|
||||
"enabled": true,
|
||||
"type": "auto|neural|fast|deep",
|
||||
"type_justification": "Why this search type",
|
||||
"category": "news|research paper|company|etc or null",
|
||||
"category_justification": "Why this category or null",
|
||||
"numResults": 10,
|
||||
"numResults_justification": "Why this number",
|
||||
"includeDomains": [],
|
||||
"includeDomains_justification": "Why these domains or empty",
|
||||
"startPublishedDate": "2024-01-01T00:00:00.000Z or null",
|
||||
"date_justification": "Why this date filter or null",
|
||||
"highlights": true,
|
||||
"highlights_justification": "Why enable/disable highlights",
|
||||
"context": true,
|
||||
"context_justification": "Why enable/disable context string"
|
||||
}},
|
||||
"tavily_config": {{
|
||||
"enabled": true,
|
||||
"topic": "general|news|finance",
|
||||
"topic_justification": "Why this topic",
|
||||
"search_depth": "basic|advanced",
|
||||
"search_depth_justification": "Why this depth",
|
||||
"include_answer": "true|false|basic|advanced",
|
||||
"include_answer_justification": "Why this answer mode",
|
||||
"time_range": "day|week|month|year|null",
|
||||
"time_range_justification": "Why this time range or null",
|
||||
"max_results": 10,
|
||||
"max_results_justification": "Why this number",
|
||||
"include_raw_content": "false|true|markdown|text",
|
||||
"include_raw_content_justification": "Why this content mode"
|
||||
}},
|
||||
"trends_config": {{
|
||||
"enabled": true|false,
|
||||
"keywords": ["keyword1", "keyword2"],
|
||||
"keywords_justification": "Why these keywords for trends analysis",
|
||||
"timeframe": "today 1-y|today 12-m|all",
|
||||
"timeframe_justification": "Why this timeframe",
|
||||
"geo": "US|GB|IN|etc",
|
||||
"geo_justification": "Why this geographic region",
|
||||
"expected_insights": [
|
||||
"Search interest trends over the past year",
|
||||
"Regional interest distribution",
|
||||
"Related topics for content expansion",
|
||||
"Related queries for FAQ sections",
|
||||
"Optimal publication timing based on interest peaks"
|
||||
]
|
||||
}}
|
||||
}}
|
||||
```
|
||||
|
||||
## DECISION RULES
|
||||
|
||||
1. **Provider Selection:**
|
||||
- Use EXA for: academic research, competitor analysis, deep understanding, finding similar content
|
||||
- Use TAVILY for: news, current events, quick facts, financial data, real-time info
|
||||
|
||||
2. **Query Optimization:**
|
||||
- Include relevant keywords for semantic matching
|
||||
- Add context words based on deliverables (e.g., "statistics 2024" for key_statistics)
|
||||
- Match query style to provider (natural language for Exa, keyword-rich for Tavily)
|
||||
|
||||
3. **Parameter Selection:**
|
||||
- ALWAYS provide justification for each parameter choice
|
||||
- Consider time sensitivity when setting date filters
|
||||
- Match category/topic to content type
|
||||
- Use "advanced" depth when quality matters more than speed
|
||||
|
||||
4. **Google Trends Keywords (if trends enabled):**
|
||||
- Suggest 1-3 keywords optimized for trends analysis
|
||||
- Keywords should be broader than research queries (e.g., "AI marketing" vs "AI marketing tools for small businesses")
|
||||
- Consider what will show meaningful search interest trends
|
||||
- Choose timeframe based on content type (12 months for blogs, 1 year for comprehensive)
|
||||
- Select geo based on user's target audience or industry
|
||||
- List specific insights trends will uncover
|
||||
|
||||
5. **Justifications:**
|
||||
- Keep justifications concise (1 sentence)
|
||||
- Explain the "why" not the "what"
|
||||
- Reference user's intent when relevant
|
||||
'''
|
||||
|
||||
return prompt
|
||||
|
||||
def _build_unified_schema(self) -> Dict[str, Any]:
|
||||
"""Build the JSON schema for unified response."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"intent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_type": {"type": "string", "enum": ["keywords", "question", "goal", "mixed"]},
|
||||
"primary_question": {"type": "string"},
|
||||
"secondary_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"purpose": {"type": "string"},
|
||||
"content_output": {"type": "string"},
|
||||
"expected_deliverables": {"type": "array", "items": {"type": "string"}},
|
||||
"depth": {"type": "string", "enum": ["overview", "detailed", "expert"]},
|
||||
"focus_areas": {"type": "array", "items": {"type": "string"}},
|
||||
"perspective": {"type": "string"},
|
||||
"time_sensitivity": {"type": "string"},
|
||||
"confidence": {"type": "number"},
|
||||
"confidence_reason": {"type": "string"},
|
||||
"great_example": {"type": "string"},
|
||||
"needs_clarification": {"type": "boolean"},
|
||||
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"analysis_summary": {"type": "string"}
|
||||
},
|
||||
"required": ["primary_question", "purpose", "expected_deliverables", "confidence"]
|
||||
},
|
||||
"queries": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"purpose": {"type": "string"},
|
||||
"provider": {"type": "string"},
|
||||
"priority": {"type": "integer"},
|
||||
"expected_results": {"type": "string"},
|
||||
"justification": {"type": "string"}
|
||||
},
|
||||
"required": ["query", "purpose", "provider", "priority"]
|
||||
}
|
||||
},
|
||||
"enhanced_keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"research_angles": {"type": "array", "items": {"type": "string"}},
|
||||
"recommended_provider": {"type": "string"},
|
||||
"provider_justification": {"type": "string"},
|
||||
"exa_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"type": {"type": "string"},
|
||||
"type_justification": {"type": "string"},
|
||||
"category": {"type": "string"},
|
||||
"category_justification": {"type": "string"},
|
||||
"numResults": {"type": "integer"},
|
||||
"numResults_justification": {"type": "string"},
|
||||
"includeDomains": {"type": "array", "items": {"type": "string"}},
|
||||
"includeDomains_justification": {"type": "string"},
|
||||
"startPublishedDate": {"type": "string"},
|
||||
"date_justification": {"type": "string"},
|
||||
"highlights": {"type": "boolean"},
|
||||
"highlights_justification": {"type": "string"},
|
||||
"context": {"type": "boolean"},
|
||||
"context_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"tavily_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"topic": {"type": "string"},
|
||||
"topic_justification": {"type": "string"},
|
||||
"search_depth": {"type": "string"},
|
||||
"search_depth_justification": {"type": "string"},
|
||||
"include_answer": {"type": "string"},
|
||||
"include_answer_justification": {"type": "string"},
|
||||
"time_range": {"type": "string"},
|
||||
"time_range_justification": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
"max_results_justification": {"type": "string"},
|
||||
"include_raw_content": {"type": "string"},
|
||||
"include_raw_content_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"trends_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"keywords_justification": {"type": "string"},
|
||||
"timeframe": {"type": "string"},
|
||||
"timeframe_justification": {"type": "string"},
|
||||
"geo": {"type": "string"},
|
||||
"geo_justification": {"type": "string"},
|
||||
"expected_insights": {"type": "array", "items": {"type": "string"}}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["intent", "queries", "recommended_provider", "exa_config", "tavily_config"]
|
||||
}
|
||||
|
||||
def _build_persona_context(
|
||||
self,
|
||||
research_persona: Optional[ResearchPersona],
|
||||
industry: Optional[str],
|
||||
target_audience: Optional[str],
|
||||
) -> str:
|
||||
"""Build persona context section."""
|
||||
parts = []
|
||||
|
||||
if research_persona:
|
||||
if research_persona.default_industry:
|
||||
parts.append(f"Industry: {research_persona.default_industry}")
|
||||
if research_persona.default_target_audience:
|
||||
parts.append(f"Target Audience: {research_persona.default_target_audience}")
|
||||
if research_persona.research_angles:
|
||||
parts.append(f"Preferred Research Angles: {', '.join(research_persona.research_angles[:3])}")
|
||||
if research_persona.suggested_keywords:
|
||||
parts.append(f"Relevant Keywords: {', '.join(research_persona.suggested_keywords[:5])}")
|
||||
else:
|
||||
if industry:
|
||||
parts.append(f"Industry: {industry}")
|
||||
if target_audience:
|
||||
parts.append(f"Target Audience: {target_audience}")
|
||||
|
||||
if not parts:
|
||||
return "No specific user context available. Use general best practices."
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_competitor_context(self, competitor_data: Optional[List[Dict]]) -> str:
|
||||
"""Build competitor context section."""
|
||||
if not competitor_data:
|
||||
return ""
|
||||
|
||||
competitor_names = [c.get("name", c.get("url", "")) for c in competitor_data[:5]]
|
||||
if competitor_names:
|
||||
return f"\nKnown Competitors: {', '.join(competitor_names)}"
|
||||
return ""
|
||||
|
||||
def _parse_unified_result(self, result: Dict[str, Any], user_input: str) -> Dict[str, Any]:
|
||||
"""Parse the unified LLM result into structured response."""
|
||||
|
||||
intent_data = result.get("intent", {})
|
||||
|
||||
# Build ResearchIntent
|
||||
intent = ResearchIntent(
|
||||
primary_question=intent_data.get("primary_question", user_input),
|
||||
secondary_questions=intent_data.get("secondary_questions", []),
|
||||
purpose=intent_data.get("purpose", "learn"),
|
||||
content_output=intent_data.get("content_output", "general"),
|
||||
expected_deliverables=intent_data.get("expected_deliverables", ["key_statistics"]),
|
||||
depth=intent_data.get("depth", "detailed"),
|
||||
focus_areas=intent_data.get("focus_areas", []),
|
||||
perspective=intent_data.get("perspective"),
|
||||
time_sensitivity=intent_data.get("time_sensitivity"),
|
||||
input_type=intent_data.get("input_type", "keywords"),
|
||||
original_input=user_input,
|
||||
confidence=float(intent_data.get("confidence", 0.7)),
|
||||
confidence_reason=intent_data.get("confidence_reason"),
|
||||
great_example=intent_data.get("great_example"),
|
||||
needs_clarification=intent_data.get("needs_clarification", False),
|
||||
clarifying_questions=intent_data.get("clarifying_questions", []),
|
||||
)
|
||||
|
||||
# Build queries
|
||||
queries = []
|
||||
for q in result.get("queries", []):
|
||||
try:
|
||||
queries.append(ResearchQuery(
|
||||
query=q.get("query", ""),
|
||||
purpose=q.get("purpose", "key_statistics"),
|
||||
provider=q.get("provider", "exa"),
|
||||
priority=int(q.get("priority", 3)),
|
||||
expected_results=q.get("expected_results", ""),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse query: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": intent,
|
||||
"queries": queries,
|
||||
"enhanced_keywords": result.get("enhanced_keywords", []),
|
||||
"research_angles": result.get("research_angles", []),
|
||||
"recommended_provider": result.get("recommended_provider", "exa"),
|
||||
"provider_justification": result.get("provider_justification", ""),
|
||||
"exa_config": result.get("exa_config", {}),
|
||||
"tavily_config": result.get("tavily_config", {}),
|
||||
"trends_config": result.get("trends_config", {}), # NEW: Google Trends configuration
|
||||
"analysis_summary": intent_data.get("analysis_summary", ""),
|
||||
}
|
||||
|
||||
def _create_fallback_response(self, user_input: str, keywords: List[str]) -> Dict[str, Any]:
|
||||
"""Create fallback response when analysis fails."""
|
||||
return {
|
||||
"success": False,
|
||||
"intent": ResearchIntent(
|
||||
primary_question=f"What are the key insights about: {user_input}?",
|
||||
purpose="learn",
|
||||
content_output="general",
|
||||
expected_deliverables=["key_statistics", "best_practices"],
|
||||
depth="detailed",
|
||||
original_input=user_input,
|
||||
confidence=0.5,
|
||||
),
|
||||
"queries": [
|
||||
ResearchQuery(
|
||||
query=user_input,
|
||||
purpose="key_statistics",
|
||||
provider="exa",
|
||||
priority=5,
|
||||
expected_results="General research results",
|
||||
)
|
||||
],
|
||||
"enhanced_keywords": keywords,
|
||||
"research_angles": [],
|
||||
"recommended_provider": "exa",
|
||||
"provider_justification": "Default fallback to Exa for semantic search",
|
||||
"exa_config": {
|
||||
"enabled": True,
|
||||
"type": "auto",
|
||||
"type_justification": "Auto mode for balanced results",
|
||||
"numResults": 10,
|
||||
"highlights": True,
|
||||
},
|
||||
"tavily_config": {
|
||||
"enabled": True,
|
||||
"topic": "general",
|
||||
"search_depth": "advanced",
|
||||
"include_answer": True,
|
||||
},
|
||||
"trends_config": {
|
||||
"enabled": False, # Disabled in fallback
|
||||
},
|
||||
}
|
||||
@@ -34,39 +34,81 @@ class ResearchPersonaService:
|
||||
user_id: str
|
||||
) -> Optional[ResearchPersona]:
|
||||
"""
|
||||
Get research persona for user ONLY if it exists in cache.
|
||||
This method NEVER generates - it only returns cached personas.
|
||||
Get research persona for user if it exists in database (regardless of cache validity).
|
||||
This method NEVER generates - it only returns existing personas.
|
||||
Use this for config endpoints to avoid triggering rate limit checks.
|
||||
|
||||
Note: Returns persona even if cache is expired - cache validity only matters for regeneration.
|
||||
|
||||
Args:
|
||||
user_id: User ID (Clerk string)
|
||||
|
||||
Returns:
|
||||
ResearchPersona if cached and valid, None otherwise
|
||||
ResearchPersona if exists in database, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get persona data record
|
||||
persona_data = self._get_persona_data_record(user_id)
|
||||
|
||||
if not persona_data:
|
||||
logger.debug(f"No persona data found for user {user_id}")
|
||||
logger.debug(f"[get_cached_only] No persona data record found for user {user_id}")
|
||||
return None
|
||||
|
||||
# Only return if cache is valid and persona exists
|
||||
if self.is_cache_valid(persona_data) and persona_data.research_persona:
|
||||
# Check if research_persona field exists and is not None/empty
|
||||
# Handle cases where it might be None, empty dict {}, or empty string ""
|
||||
research_persona_raw = persona_data.research_persona
|
||||
has_persona = (
|
||||
research_persona_raw is not None
|
||||
and research_persona_raw != {}
|
||||
and research_persona_raw != ""
|
||||
and (isinstance(research_persona_raw, dict) and len(research_persona_raw) > 0)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[get_cached_only] Checking research persona for user {user_id}: "
|
||||
f"persona_data exists=True, research_persona_raw={research_persona_raw is not None}, "
|
||||
f"research_persona type={type(research_persona_raw)}, "
|
||||
f"has_persona={has_persona}, "
|
||||
f"generated_at={persona_data.research_persona_generated_at}"
|
||||
)
|
||||
|
||||
# Return persona if it exists, regardless of cache validity
|
||||
# Cache validity only matters when deciding whether to regenerate
|
||||
if has_persona:
|
||||
try:
|
||||
logger.debug(f"Returning cached research persona for user {user_id}")
|
||||
return ResearchPersona(**persona_data.research_persona)
|
||||
cache_valid = self.is_cache_valid(persona_data)
|
||||
cache_status = "valid" if cache_valid else "expired"
|
||||
logger.info(
|
||||
f"[get_cached_only] ✅ Returning research persona for user {user_id} "
|
||||
f"(cache: {cache_status}, generated_at: {persona_data.research_persona_generated_at})"
|
||||
)
|
||||
# Ensure we're passing a dict to ResearchPersona
|
||||
if not isinstance(research_persona_raw, dict):
|
||||
logger.error(f"[get_cached_only] research_persona_raw is not a dict: {type(research_persona_raw)}")
|
||||
return None
|
||||
parsed_persona = ResearchPersona(**research_persona_raw)
|
||||
logger.info(
|
||||
f"[get_cached_only] ✅ Successfully parsed persona for user {user_id}: "
|
||||
f"industry={parsed_persona.default_industry}, "
|
||||
f"target_audience={parsed_persona.default_target_audience}"
|
||||
)
|
||||
return parsed_persona
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached research persona: {e}")
|
||||
logger.error(f"[get_cached_only] ❌ Failed to parse research persona for user {user_id}: {e}", exc_info=True)
|
||||
logger.debug(
|
||||
f"[get_cached_only] Persona data details: "
|
||||
f"type={type(research_persona_raw)}, "
|
||||
f"is_dict={isinstance(research_persona_raw, dict)}, "
|
||||
f"value sample: {str(research_persona_raw)[:500] if research_persona_raw else 'None'}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Cache invalid or persona missing - return None (don't generate)
|
||||
logger.debug(f"No valid cached research persona for user {user_id}")
|
||||
# Persona doesn't exist in database
|
||||
logger.info(f"[get_cached_only] ⚠️ No research persona found in database for user {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cached research persona for user {user_id}: {e}")
|
||||
logger.error(f"[get_cached_only] ❌ Error getting research persona for user {user_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_or_generate(
|
||||
@@ -92,25 +134,40 @@ class ResearchPersonaService:
|
||||
logger.warning(f"No persona data found for user {user_id}, cannot generate research persona")
|
||||
return None
|
||||
|
||||
# Check cache if not forcing refresh
|
||||
if not force_refresh and self.is_cache_valid(persona_data):
|
||||
if persona_data.research_persona:
|
||||
# Check if persona exists in database
|
||||
if persona_data.research_persona:
|
||||
# Persona exists - check if we should return it or regenerate
|
||||
cache_valid = self.is_cache_valid(persona_data)
|
||||
|
||||
if not force_refresh and cache_valid:
|
||||
# Cache is valid - return existing persona
|
||||
logger.info(f"Using cached research persona for user {user_id}")
|
||||
try:
|
||||
return ResearchPersona(**persona_data.research_persona)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached research persona: {e}, regenerating...")
|
||||
# Fall through to regeneration
|
||||
# Fall through to regeneration if parsing fails
|
||||
elif not force_refresh:
|
||||
# Persona exists but cache expired - return it anyway (don't regenerate unless forced)
|
||||
logger.info(f"Research persona exists for user {user_id} but cache expired - returning existing persona (use force_refresh=true to regenerate)")
|
||||
try:
|
||||
return ResearchPersona(**persona_data.research_persona)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse existing research persona: {e}, regenerating...")
|
||||
# Fall through to regeneration if parsing fails
|
||||
else:
|
||||
logger.info(f"Research persona missing for user {user_id}, generating...")
|
||||
else:
|
||||
if force_refresh:
|
||||
# force_refresh=True - regenerate even though persona exists
|
||||
logger.info(f"Forcing refresh of research persona for user {user_id}")
|
||||
else:
|
||||
logger.info(f"Cache expired for user {user_id}, regenerating...")
|
||||
else:
|
||||
# Persona doesn't exist - generate new one
|
||||
logger.info(f"Research persona missing for user {user_id}, generating...")
|
||||
|
||||
# Generate new research persona
|
||||
# Generate new research persona (only reaches here if:
|
||||
# 1. Persona doesn't exist, OR
|
||||
# 2. force_refresh=True, OR
|
||||
# 3. Parsing of existing persona failed
|
||||
try:
|
||||
logger.info(f"Generating research persona for user {user_id}")
|
||||
research_persona = self.generate_research_persona(user_id)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) so they propagate to API
|
||||
|
||||
9
backend/services/research/trends/__init__.py
Normal file
9
backend/services/research/trends/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Google Trends Research Service
|
||||
|
||||
Provides Google Trends data integration for the Research Engine.
|
||||
"""
|
||||
|
||||
from .google_trends_service import GoogleTrendsService
|
||||
|
||||
__all__ = ['GoogleTrendsService']
|
||||
380
backend/services/research/trends/google_trends_service.py
Normal file
380
backend/services/research/trends/google_trends_service.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Google Trends Service
|
||||
|
||||
Provides Google Trends data integration for the Research Engine.
|
||||
Handles rate limiting, caching, error handling, and data serialization.
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
from pytrends.request import TrendReq
|
||||
PYTrends_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTrends_AVAILABLE = False
|
||||
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
|
||||
|
||||
from .rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class GoogleTrendsService:
|
||||
"""
|
||||
Service for fetching and analyzing Google Trends data.
|
||||
|
||||
Features:
|
||||
- Interest over time
|
||||
- Interest by region
|
||||
- Related topics
|
||||
- Related queries
|
||||
- Rate limiting (1 req/sec)
|
||||
- Caching (24-hour TTL)
|
||||
- Async support
|
||||
- Error handling with retry logic
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Google Trends service."""
|
||||
if not PYTrends_AVAILABLE:
|
||||
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
|
||||
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0) # 1 request per second
|
||||
self.cache: Dict[str, Dict[str, Any]] = {} # Simple in-memory cache
|
||||
self.cache_ttl = timedelta(hours=24) # 24-hour cache
|
||||
|
||||
logger.info("GoogleTrendsService initialized")
|
||||
|
||||
async def analyze_trends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str = "today 12-m",
|
||||
geo: str = "US",
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Comprehensive trends analysis.
|
||||
|
||||
Fetches all trends data in a single optimized call:
|
||||
- Interest over time
|
||||
- Interest by region
|
||||
- Related topics (top & rising)
|
||||
- Related queries (top & rising)
|
||||
|
||||
Args:
|
||||
keywords: List of keywords to analyze (1-5 keywords recommended)
|
||||
timeframe: Timeframe string (e.g., "today 12-m", "today 1-y", "all")
|
||||
geo: Country code (e.g., "US", "GB", "IN")
|
||||
user_id: User ID for subscription checks (optional for now)
|
||||
|
||||
Returns:
|
||||
Dict containing all trends data in serializable format
|
||||
|
||||
Raises:
|
||||
ValueError: If keywords list is empty or too long
|
||||
RuntimeError: If pytrends is not available or API fails
|
||||
"""
|
||||
if not keywords:
|
||||
raise ValueError("Keywords list cannot be empty")
|
||||
|
||||
if len(keywords) > 5:
|
||||
logger.warning(f"Too many keywords ({len(keywords)}), using first 5")
|
||||
keywords = keywords[:5]
|
||||
|
||||
# Check cache first
|
||||
cache_key = self._build_cache_key(keywords, timeframe, geo)
|
||||
cached_data = self._get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached trends data for: {keywords}")
|
||||
return {**cached_data, "cached": True}
|
||||
|
||||
# Rate limit
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
try:
|
||||
logger.info(f"Fetching Google Trends data for: {keywords} (timeframe: {timeframe}, geo: {geo})")
|
||||
|
||||
# Initialize pytrends (sync operation, run in thread)
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._initialize_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo
|
||||
)
|
||||
|
||||
# Fetch all data in parallel (pytrends methods are sync, so use to_thread)
|
||||
interest_over_time_task = asyncio.to_thread(
|
||||
lambda: self._safe_interest_over_time(pytrends)
|
||||
)
|
||||
interest_by_region_task = asyncio.to_thread(
|
||||
lambda: self._safe_interest_by_region(pytrends)
|
||||
)
|
||||
related_topics_task = asyncio.to_thread(
|
||||
lambda: self._safe_related_topics(pytrends, keywords)
|
||||
)
|
||||
related_queries_task = asyncio.to_thread(
|
||||
lambda: self._safe_related_queries(pytrends, keywords)
|
||||
)
|
||||
|
||||
# Wait for all tasks
|
||||
interest_over_time, interest_by_region, related_topics, related_queries = await asyncio.gather(
|
||||
interest_over_time_task,
|
||||
interest_by_region_task,
|
||||
related_topics_task,
|
||||
related_queries_task,
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle exceptions
|
||||
if isinstance(interest_over_time, Exception):
|
||||
logger.error(f"Interest over time failed: {interest_over_time}")
|
||||
interest_over_time = []
|
||||
if isinstance(interest_by_region, Exception):
|
||||
logger.error(f"Interest by region failed: {interest_by_region}")
|
||||
interest_by_region = []
|
||||
if isinstance(related_topics, Exception):
|
||||
logger.error(f"Related topics failed: {related_topics}")
|
||||
related_topics = {"top": [], "rising": []}
|
||||
if isinstance(related_queries, Exception):
|
||||
logger.error(f"Related queries failed: {related_queries}")
|
||||
related_queries = {"top": [], "rising": []}
|
||||
|
||||
# Build result
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False
|
||||
}
|
||||
|
||||
# Cache result
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
logger.info(f"Google Trends data fetched successfully: {len(interest_over_time)} time points, {len(interest_by_region)} regions")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
# Return fallback response
|
||||
return self._create_fallback_response(keywords, timeframe, geo, str(e))
|
||||
|
||||
def _initialize_pytrends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str
|
||||
) -> TrendReq:
|
||||
"""Initialize pytrends and build payload (sync operation)."""
|
||||
pytrends = TrendReq(hl='en-US', tz=360)
|
||||
pytrends.build_payload(kw_list=keywords, timeframe=timeframe, geo=geo)
|
||||
return pytrends
|
||||
|
||||
def _safe_interest_over_time(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
|
||||
"""Safely fetch interest over time data."""
|
||||
try:
|
||||
df = pytrends.interest_over_time()
|
||||
if df.empty:
|
||||
return []
|
||||
return self._format_dataframe(df.reset_index())
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching interest over time: {e}")
|
||||
return []
|
||||
|
||||
def _safe_interest_by_region(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
|
||||
"""Safely fetch interest by region data."""
|
||||
try:
|
||||
df = pytrends.interest_by_region(resolution='COUNTRY', inc_low_vol=True, inc_geo_code=False)
|
||||
if df.empty:
|
||||
return []
|
||||
return self._format_dataframe(df.reset_index())
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching interest by region: {e}")
|
||||
return []
|
||||
|
||||
def _safe_related_topics(
|
||||
self,
|
||||
pytrends: TrendReq,
|
||||
keywords: List[str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Safely fetch related topics."""
|
||||
try:
|
||||
topics_data = pytrends.related_topics()
|
||||
result = {"top": [], "rising": []}
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in topics_data and isinstance(topics_data[keyword], dict):
|
||||
keyword_topics = topics_data[keyword]
|
||||
|
||||
if "top" in keyword_topics and not keyword_topics["top"].empty:
|
||||
top_df = keyword_topics["top"]
|
||||
# Select relevant columns
|
||||
if "topic_title" in top_df.columns and "value" in top_df.columns:
|
||||
top_data = top_df[["topic_title", "value"]].to_dict('records')
|
||||
result["top"].extend(top_data)
|
||||
|
||||
if "rising" in keyword_topics and not keyword_topics["rising"].empty:
|
||||
rising_df = keyword_topics["rising"]
|
||||
if "topic_title" in rising_df.columns and "value" in rising_df.columns:
|
||||
rising_data = rising_df[["topic_title", "value"]].to_dict('records')
|
||||
result["rising"].extend(rising_data)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching related topics: {e}")
|
||||
return {"top": [], "rising": []}
|
||||
|
||||
def _safe_related_queries(
|
||||
self,
|
||||
pytrends: TrendReq,
|
||||
keywords: List[str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Safely fetch related queries."""
|
||||
try:
|
||||
queries_data = pytrends.related_queries()
|
||||
result = {"top": [], "rising": []}
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in queries_data and isinstance(queries_data[keyword], dict):
|
||||
keyword_queries = queries_data[keyword]
|
||||
|
||||
if "top" in keyword_queries and not keyword_queries["top"].empty:
|
||||
top_df = keyword_queries["top"]
|
||||
result["top"].extend(top_df.to_dict('records'))
|
||||
|
||||
if "rising" in keyword_queries and not keyword_queries["rising"].empty:
|
||||
rising_df = keyword_queries["rising"]
|
||||
result["rising"].extend(rising_df.to_dict('records'))
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching related queries: {e}")
|
||||
return {"top": [], "rising": []}
|
||||
|
||||
def _format_dataframe(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
|
||||
"""Convert DataFrame to list of dicts (serializable format)."""
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
# Convert datetime columns to strings
|
||||
for col in df.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||
df[col] = df[col].astype(str)
|
||||
|
||||
# Convert to dict records
|
||||
return df.to_dict('records')
|
||||
|
||||
def _build_cache_key(self, keywords: List[str], timeframe: str, geo: str) -> str:
|
||||
"""Build cache key from parameters."""
|
||||
keywords_str = ":".join(sorted(keywords))
|
||||
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get data from cache if not expired."""
|
||||
if cache_key not in self.cache:
|
||||
return None
|
||||
|
||||
cached_entry = self.cache[cache_key]
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
# Expired, remove from cache
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
# Return cached data (without cached flag)
|
||||
result = {**cached_entry}
|
||||
result.pop("cached", None)
|
||||
return result
|
||||
|
||||
def _save_to_cache(self, cache_key: str, data: Dict[str, Any]):
|
||||
"""Save data to cache."""
|
||||
# Store with timestamp
|
||||
cache_entry = {
|
||||
**data,
|
||||
"cached_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
self.cache[cache_key] = cache_entry
|
||||
|
||||
# Clean up old cache entries periodically
|
||||
if len(self.cache) > 100: # Limit cache size
|
||||
self._cleanup_cache()
|
||||
|
||||
def _cleanup_cache(self):
|
||||
"""Remove expired cache entries."""
|
||||
now = datetime.utcnow()
|
||||
expired_keys = []
|
||||
|
||||
for key, entry in self.cache.items():
|
||||
cached_time = datetime.fromisoformat(entry.get("cached_at", entry.get("timestamp", "")))
|
||||
if now - cached_time > self.cache_ttl:
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del self.cache[key]
|
||||
|
||||
logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
|
||||
|
||||
def _create_fallback_response(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str,
|
||||
error_message: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create fallback response when trends analysis fails."""
|
||||
return {
|
||||
"interest_over_time": [],
|
||||
"interest_by_region": [],
|
||||
"related_topics": {"top": [], "rising": []},
|
||||
"related_queries": {"top": [], "rising": []},
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
"error": error_message
|
||||
}
|
||||
|
||||
async def get_trending_searches(
|
||||
self,
|
||||
country: str = "united_states",
|
||||
user_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get current trending searches for a country.
|
||||
|
||||
Args:
|
||||
country: Country name (e.g., "united_states", "united_kingdom")
|
||||
user_id: User ID for subscription checks
|
||||
|
||||
Returns:
|
||||
List of trending search terms
|
||||
"""
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
try:
|
||||
pytrends = TrendReq(hl='en-US', tz=360)
|
||||
trending_df = await asyncio.to_thread(
|
||||
lambda: pytrends.trending_searches(pn=country)
|
||||
)
|
||||
|
||||
if trending_df.empty:
|
||||
return []
|
||||
|
||||
# Return as list of strings
|
||||
return trending_df[0].tolist() if len(trending_df.columns) > 0 else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching trending searches: {e}")
|
||||
return []
|
||||
57
backend/services/research/trends/rate_limiter.py
Normal file
57
backend/services/research/trends/rate_limiter.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Rate Limiter for Google Trends API
|
||||
|
||||
Ensures we don't exceed Google Trends rate limits (1 request per second).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from time import time
|
||||
from collections import deque
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Simple rate limiter for Google Trends API.
|
||||
|
||||
Limits requests to max_calls per period (in seconds).
|
||||
"""
|
||||
|
||||
def __init__(self, max_calls: int = 1, period: float = 1.0):
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
max_calls: Maximum number of calls allowed
|
||||
period: Time period in seconds
|
||||
"""
|
||||
self.max_calls = max_calls
|
||||
self.period = period
|
||||
self.calls = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self):
|
||||
"""
|
||||
Acquire permission to make a request.
|
||||
|
||||
Will wait if rate limit would be exceeded.
|
||||
"""
|
||||
async with self._lock:
|
||||
now = time()
|
||||
|
||||
# Remove old calls outside the period
|
||||
while self.calls and self.calls[0] < now - self.period:
|
||||
self.calls.popleft()
|
||||
|
||||
# If at limit, wait until oldest call expires
|
||||
if len(self.calls) >= self.max_calls:
|
||||
sleep_time = self.period - (now - self.calls[0])
|
||||
if sleep_time > 0:
|
||||
logger.debug(f"Rate limit reached, waiting {sleep_time:.2f}s")
|
||||
await asyncio.sleep(sleep_time)
|
||||
# Recursively try again after waiting
|
||||
return await self.acquire()
|
||||
|
||||
# Record this call
|
||||
self.calls.append(time())
|
||||
logger.debug(f"Rate limit check passed, {len(self.calls)}/{self.max_calls} calls in period")
|
||||
557
backend/services/video_studio/edit_service.py
Normal file
557
backend/services/video_studio/edit_service.py
Normal file
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
Edit Studio Service - Video editing operations.
|
||||
|
||||
Phase 1: Basic FFmpeg operations (Trim/Cut, Speed Control, Stabilization)
|
||||
Phase 2: Text Overlay & Captions, Audio Enhancement, Noise Reduction
|
||||
Phase 3: AI Features (Background Replacement, Object Removal, Color Grading)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.services.video_studio.video_processors import (
|
||||
trim_video,
|
||||
adjust_speed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditService:
|
||||
"""Service for video editing operations."""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[EditService] Service initialized")
|
||||
|
||||
def calculate_cost(self, edit_type: str, duration: float = 10.0) -> float:
|
||||
"""Calculate cost for video editing operation. FFmpeg operations are free."""
|
||||
return 0.0
|
||||
|
||||
async def trim_video(
|
||||
self,
|
||||
video_data: bytes,
|
||||
start_time: float = 0.0,
|
||||
end_time: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
trim_mode: str = "beginning",
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Trim video to specified duration or time range."""
|
||||
try:
|
||||
logger.info(f"[EditService] Video trim: user={user_id}, start={start_time}, end={end_time}")
|
||||
|
||||
processed_video_bytes = await asyncio.to_thread(
|
||||
trim_video,
|
||||
video_bytes=video_data,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
max_duration=max_duration,
|
||||
trim_mode=trim_mode,
|
||||
)
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_trim_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "trim", "start_time": start_time, "end_time": end_time},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "trim",
|
||||
"metadata": {"start_time": start_time, "end_time": end_time},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Video trim failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Video trimming failed: {str(e)}")
|
||||
|
||||
async def adjust_speed(
|
||||
self,
|
||||
video_data: bytes,
|
||||
speed_factor: float,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Adjust video playback speed."""
|
||||
try:
|
||||
logger.info(f"[EditService] Speed adjustment: user={user_id}, factor={speed_factor}")
|
||||
|
||||
if speed_factor <= 0:
|
||||
raise HTTPException(status_code=400, detail="Speed factor must be greater than 0")
|
||||
if speed_factor > 4.0:
|
||||
raise HTTPException(status_code=400, detail="Speed factor cannot exceed 4.0")
|
||||
|
||||
processed_video_bytes = await asyncio.to_thread(
|
||||
adjust_speed,
|
||||
video_bytes=video_data,
|
||||
speed_factor=speed_factor,
|
||||
)
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_speed_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "speed", "speed_factor": speed_factor},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "speed",
|
||||
"metadata": {"speed_factor": speed_factor},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Speed adjustment failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Speed adjustment failed: {str(e)}")
|
||||
|
||||
async def stabilize_video(
|
||||
self,
|
||||
video_data: bytes,
|
||||
smoothing: int = 10,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Stabilize video using FFmpeg vidstab."""
|
||||
try:
|
||||
logger.info(f"[EditService] Stabilization: user={user_id}, smoothing={smoothing}")
|
||||
|
||||
smoothing = max(1, min(100, smoothing))
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
|
||||
input_file.write(video_data)
|
||||
input_path = input_file.name
|
||||
|
||||
transforms_file = tempfile.NamedTemporaryFile(suffix=".trf", delete=False, delete_on_close=False)
|
||||
transforms_path = transforms_file.name
|
||||
transforms_file.close()
|
||||
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
detect_cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-vf", f"vidstabdetect=stepsize=6:shakiness=10:accuracy=15:result={transforms_path}",
|
||||
"-f", "null", "-"
|
||||
]
|
||||
subprocess.run(detect_cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
|
||||
output_path = output_file.name
|
||||
|
||||
transform_cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-vf", f"vidstabtransform=input={transforms_path}:smoothing={smoothing}:zoom=1:optzoom=1",
|
||||
"-c:v", "libx264", "-preset", "medium", "-crf", "23",
|
||||
"-c:a", "copy", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(transform_cmd, capture_output=True, text=True, timeout=600)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Stabilization failed: {result.stderr}")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
processed_video_bytes = f.read()
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_stabilized_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "stabilize", "smoothing": smoothing},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "stabilize",
|
||||
"metadata": {"smoothing": smoothing},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
Path(transforms_path).unlink(missing_ok=True)
|
||||
if output_path:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
raise HTTPException(status_code=504, detail="Stabilization timed out")
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Stabilization failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Stabilization failed: {str(e)}")
|
||||
|
||||
# Phase 2: Text and Audio operations
|
||||
|
||||
async def add_text_overlay(
|
||||
self,
|
||||
video_data: bytes,
|
||||
text: str,
|
||||
position: str = "center",
|
||||
font_size: int = 48,
|
||||
font_color: str = "white",
|
||||
background_color: str = "black@0.5",
|
||||
start_time: float = 0.0,
|
||||
end_time: Optional[float] = None,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Add text overlay to video using FFmpeg drawtext filter."""
|
||||
try:
|
||||
logger.info(f"[EditService] Text overlay: user={user_id}, text='{text[:30]}...'")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
|
||||
input_file.write(video_data)
|
||||
input_path = input_file.name
|
||||
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
position_map = {
|
||||
"top": "(w-text_w)/2:50",
|
||||
"center": "(w-text_w)/2:(h-text_h)/2",
|
||||
"bottom": "(w-text_w)/2:h-text_h-50",
|
||||
"top-left": "50:50",
|
||||
"top-right": "w-text_w-50:50",
|
||||
"bottom-left": "50:h-text_h-50",
|
||||
"bottom-right": "w-text_w-50:h-text_h-50",
|
||||
}
|
||||
pos_expr = position_map.get(position, position_map["center"])
|
||||
|
||||
escaped_text = text.replace("'", "'\\''").replace(":", "\\:")
|
||||
|
||||
drawtext_filter = (
|
||||
f"drawtext=text='{escaped_text}':"
|
||||
f"fontsize={font_size}:fontcolor={font_color}:"
|
||||
f"x={pos_expr.split(':')[0]}:y={pos_expr.split(':')[1]}:"
|
||||
f"box=1:boxcolor={background_color}:boxborderw=10"
|
||||
)
|
||||
|
||||
if start_time > 0 or end_time is not None:
|
||||
enable_expr = f"between(t,{start_time},{end_time if end_time else 9999})"
|
||||
drawtext_filter += f":enable='{enable_expr}'"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
|
||||
output_path = output_file.name
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-i", input_path, "-vf", drawtext_filter,
|
||||
"-c:v", "libx264", "-preset", "medium", "-crf", "23",
|
||||
"-c:a", "copy", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Text overlay failed: {result.stderr}")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
processed_video_bytes = f.read()
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_text_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "text_overlay", "text": text[:100], "position": position},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "text_overlay",
|
||||
"metadata": {"text": text[:100], "position": position, "font_size": font_size},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
if output_path:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Text overlay failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Text overlay failed: {str(e)}")
|
||||
|
||||
async def adjust_volume(
|
||||
self,
|
||||
video_data: bytes,
|
||||
volume_factor: float,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Adjust video audio volume using FFmpeg."""
|
||||
try:
|
||||
logger.info(f"[EditService] Volume adjustment: user={user_id}, factor={volume_factor}")
|
||||
|
||||
if volume_factor < 0:
|
||||
raise HTTPException(status_code=400, detail="Volume factor must be non-negative")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
|
||||
input_file.write(video_data)
|
||||
input_path = input_file.name
|
||||
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
|
||||
output_path = output_file.name
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-af", f"volume={volume_factor}",
|
||||
"-c:v", "copy", "-c:a", "aac", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Volume adjustment failed: {result.stderr}")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
processed_video_bytes = f.read()
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_volume_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "volume", "volume_factor": volume_factor},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "volume",
|
||||
"metadata": {"volume_factor": volume_factor},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
if output_path:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Volume adjustment failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Volume adjustment failed: {str(e)}")
|
||||
|
||||
async def normalize_audio(
|
||||
self,
|
||||
video_data: bytes,
|
||||
target_level: float = -14.0,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Normalize audio levels using FFmpeg loudnorm filter (EBU R128)."""
|
||||
try:
|
||||
logger.info(f"[EditService] Audio normalization: user={user_id}, level={target_level} LUFS")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
|
||||
input_file.write(video_data)
|
||||
input_path = input_file.name
|
||||
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
|
||||
output_path = output_file.name
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-af", f"loudnorm=I={target_level}:TP=-1.5:LRA=11",
|
||||
"-c:v", "copy", "-c:a", "aac", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Audio normalization failed: {result.stderr}")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
processed_video_bytes = f.read()
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_normalized_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "normalize", "target_level": target_level},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "normalize",
|
||||
"metadata": {"target_level": target_level},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
if output_path:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Audio normalization failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Audio normalization failed: {str(e)}")
|
||||
|
||||
async def reduce_noise(
|
||||
self,
|
||||
video_data: bytes,
|
||||
noise_reduction_strength: float = 0.5,
|
||||
user_id: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Reduce audio noise using FFmpeg's anlmdn filter."""
|
||||
try:
|
||||
logger.info(f"[EditService] Noise reduction: user={user_id}, strength={noise_reduction_strength}")
|
||||
|
||||
strength = max(0.0, min(1.0, noise_reduction_strength))
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as input_file:
|
||||
input_file.write(video_data)
|
||||
input_path = input_file.name
|
||||
|
||||
output_path = None
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
|
||||
output_path = output_file.name
|
||||
|
||||
sigma = 0.0001 + (strength * 0.005)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-af", f"anlmdn=s={sigma}:p=0.002:r=0.002",
|
||||
"-c:v", "copy", "-c:a", "aac", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
|
||||
if result.returncode != 0:
|
||||
# Fallback to highpass/lowpass
|
||||
cmd = [
|
||||
"ffmpeg", "-i", input_path,
|
||||
"-af", "highpass=f=80,lowpass=f=12000",
|
||||
"-c:v", "copy", "-c:a", "aac", "-y", output_path
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Noise reduction failed: {result.stderr}")
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
processed_video_bytes = f.read()
|
||||
|
||||
from backend.services.content_assets.content_asset_service import ContentAssetService
|
||||
from backend.database.database import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
asset_service = ContentAssetService(db)
|
||||
filename = f"edited_denoised_{uuid.uuid4().hex[:8]}.mp4"
|
||||
|
||||
asset_result = asset_service.save_video_asset(
|
||||
user_id=user_id,
|
||||
video_data=processed_video_bytes,
|
||||
filename=filename,
|
||||
asset_type="video_edit",
|
||||
metadata={"edit_type": "noise_reduction", "strength": strength},
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": asset_result.get("url"),
|
||||
"asset_id": asset_result.get("asset_id"),
|
||||
"cost": 0.0,
|
||||
"edit_type": "noise_reduction",
|
||||
"metadata": {"strength": strength},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
if output_path:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EditService] Noise reduction failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Noise reduction failed: {str(e)}")
|
||||
9
backend/services/wavespeed/generators/video/__init__.py
Normal file
9
backend/services/wavespeed/generators/video/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Video generation generator for WaveSpeed API.
|
||||
|
||||
Modular implementation with separate modules for different video operations.
|
||||
"""
|
||||
|
||||
from .generator import VideoGenerator
|
||||
|
||||
__all__ = ["VideoGenerator"]
|
||||
244
backend/services/wavespeed/generators/video/audio.py
Normal file
244
backend/services/wavespeed/generators/video/audio.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Video audio generation operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.audio")
|
||||
|
||||
|
||||
class VideoAudio(VideoBase):
|
||||
"""Video audio generation operations."""
|
||||
|
||||
def hunyuan_video_foley(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
|
||||
seed: int = -1, # Random seed (-1 for random)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate realistic Foley and ambient audio from video using Hunyuan Video Foley.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional text prompt describing desired sounds (e.g., "ocean waves, seagulls")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with generated audio
|
||||
|
||||
Raises:
|
||||
HTTPException: If the audio generation fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/hunyuan-video-foley"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Hunyuan Video Foley request via {url} "
|
||||
f"(has_prompt={prompt is not None}, seed={seed})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Hunyuan Video Foley submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Hunyuan Video Foley submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in Hunyuan Video Foley response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Hunyuan Video Foley response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Hunyuan Video Foley task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download video with audio from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Hunyuan Video Foley completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for Hunyuan Video Foley",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
|
||||
def think_sound(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
|
||||
seed: int = -1, # Random seed (-1 for random)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate realistic sound effects and audio tracks from video using Think Sound.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional text prompt describing desired sounds (e.g., "engine roaring, footsteps on gravel")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with generated audio
|
||||
|
||||
Raises:
|
||||
HTTPException: If the audio generation fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/think-sound"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Think Sound request via {url} "
|
||||
f"(has_prompt={prompt is not None}, seed={seed})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Think Sound submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Think Sound submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in Think Sound response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Think Sound response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Think Sound task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download video with audio from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Think Sound completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for Think Sound",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
127
backend/services/wavespeed/generators/video/background.py
Normal file
127
backend/services/wavespeed/generators/video/background.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Video background removal operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.background")
|
||||
|
||||
|
||||
class VideoBackground(VideoBase):
|
||||
"""Video background removal operations."""
|
||||
|
||||
def remove_background(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
background_image: Optional[str] = None, # Base64-encoded image or URL (optional)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Remove or replace video background using Video Background Remover.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
background_image: Optional base64-encoded image data URI or public URL (replacement background)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with background removed/replaced
|
||||
|
||||
Raises:
|
||||
HTTPException: If the background removal fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/video-background-remover"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
}
|
||||
|
||||
if background_image:
|
||||
payload["background_image"] = background_image
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video background removal request via {url} "
|
||||
f"(has_background={background_image is not None})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video background removal submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video background removal submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in video background removal response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed video background removal response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Video background removal task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video background removal returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video background removal output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading processed video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download processed video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download processed video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video background removal completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video background removal",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
84
backend/services/wavespeed/generators/video/base.py
Normal file
84
backend/services/wavespeed/generators/video/base.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Base functionality for video operations.
|
||||
|
||||
Shared utilities for HTTP requests, video download, and common operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.base")
|
||||
|
||||
|
||||
class VideoBase:
|
||||
"""Base class for video operations with shared functionality."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize video base.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def _download_video(self, video_url: str, timeout: int = 180) -> bytes:
|
||||
"""Download video from URL.
|
||||
|
||||
Args:
|
||||
video_url: URL to download video from
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
bytes: Video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If download fails
|
||||
"""
|
||||
logger.info(f"[WaveSpeed] Downloading video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
}
|
||||
)
|
||||
|
||||
return video_response.content
|
||||
|
||||
def _extract_video_url(self, outputs: list) -> Optional[str]:
|
||||
"""Extract video URL from outputs array.
|
||||
|
||||
Args:
|
||||
outputs: Array of outputs (can be strings or dicts)
|
||||
|
||||
Returns:
|
||||
Optional[str]: Video URL if found, None otherwise
|
||||
"""
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
output = outputs[0]
|
||||
if isinstance(output, str):
|
||||
return output if output.startswith("http") else None
|
||||
elif isinstance(output, dict):
|
||||
return output.get("url") or output.get("video_url")
|
||||
|
||||
return None
|
||||
109
backend/services/wavespeed/generators/video/enhancement.py
Normal file
109
backend/services/wavespeed/generators/video/enhancement.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Video enhancement operations (upscaling).
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.enhancement")
|
||||
|
||||
|
||||
class VideoEnhancement(VideoBase):
|
||||
"""Video enhancement operations."""
|
||||
|
||||
def upscale_video(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
target_resolution: str = "1080p",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Upscale video using FlashVSR.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL
|
||||
target_resolution: Target resolution ("720p", "1080p", "2k", "4k")
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300 for long videos)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Upscaled video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the upscaling fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/flashvsr"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"video": video,
|
||||
"target_resolution": target_resolution,
|
||||
}
|
||||
|
||||
logger.info(f"[WaveSpeed] Upscaling video via {url} (target={target_resolution})")
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] FlashVSR submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed FlashVSR submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in FlashVSR response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed FlashVSR response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] FlashVSR task submitted: {prediction_id}")
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0, # Longer interval for upscaling (slower process)
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR returned no outputs")
|
||||
|
||||
video_url = outputs[0] if isinstance(outputs[0], str) else outputs[0].get("url")
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR output format not recognized")
|
||||
|
||||
# Download the upscaled video
|
||||
logger.info(f"[WaveSpeed] Downloading upscaled video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download upscaled video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download upscaled video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video upscaling completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
161
backend/services/wavespeed/generators/video/extension.py
Normal file
161
backend/services/wavespeed/generators/video/extension.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Video extension operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.extension")
|
||||
|
||||
|
||||
class VideoExtension(VideoBase):
|
||||
"""Video extension operations."""
|
||||
|
||||
def extend_video(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: str,
|
||||
model: str = "wan-2.5", # "wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro"
|
||||
audio: Optional[str] = None, # Optional audio URL (WAN 2.5 only)
|
||||
negative_prompt: Optional[str] = None, # WAN 2.5 only
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
enable_prompt_expansion: bool = False, # WAN 2.5 only
|
||||
generate_audio: bool = True, # Seedance 1.5 Pro only
|
||||
camera_fixed: bool = False, # Seedance 1.5 Pro only
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL
|
||||
prompt: Text prompt describing how to extend the video
|
||||
model: Model to use ("wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro")
|
||||
audio: Optional audio URL to guide generation (WAN 2.5 only)
|
||||
negative_prompt: Optional negative prompt (WAN 2.5 only)
|
||||
resolution: Output resolution (varies by model)
|
||||
duration: Duration of extended video in seconds (varies by model)
|
||||
enable_prompt_expansion: Enable prompt optimizer (WAN 2.5 only)
|
||||
generate_audio: Generate audio for extended video (Seedance 1.5 Pro only)
|
||||
camera_fixed: Fix camera position (Seedance 1.5 Pro only)
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Extended video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the extension fails
|
||||
"""
|
||||
# Determine model path
|
||||
if model in ("wan-2.2-spicy", "wavespeed-ai/wan-2.2-spicy/video-extend"):
|
||||
model_path = "wavespeed-ai/wan-2.2-spicy/video-extend"
|
||||
elif model in ("seedance-1.5-pro", "bytedance/seedance-v1.5-pro/video-extend"):
|
||||
model_path = "bytedance/seedance-v1.5-pro/video-extend"
|
||||
else:
|
||||
# Default to WAN 2.5
|
||||
model_path = "alibaba/wan-2.5/video-extend"
|
||||
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Base payload (common to all models)
|
||||
payload = {
|
||||
"video": video,
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
}
|
||||
|
||||
# Model-specific parameters
|
||||
if model_path == "alibaba/wan-2.5/video-extend":
|
||||
# WAN 2.5 specific
|
||||
payload["enable_prompt_expansion"] = enable_prompt_expansion
|
||||
if audio:
|
||||
payload["audio"] = audio
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
elif model_path == "bytedance/seedance-v1.5-pro/video-extend":
|
||||
# Seedance 1.5 Pro specific
|
||||
payload["generate_audio"] = generate_audio
|
||||
payload["camera_fixed"] = camera_fixed
|
||||
|
||||
# Seed (all models support it)
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
logger.info(f"[WaveSpeed] Extending video via {url} (duration={duration}s, resolution={resolution})")
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video extend submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video extend submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in video extend response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed video extend response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Video extend task submitted: {prediction_id}")
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video extend returned no outputs")
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video extend output format not recognized")
|
||||
|
||||
# Download the extended video
|
||||
logger.info(f"[WaveSpeed] Downloading extended video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download extended video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download extended video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video extension completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
283
backend/services/wavespeed/generators/video/face_swap.py
Normal file
283
backend/services/wavespeed/generators/video/face_swap.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Face swap operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.face_swap")
|
||||
|
||||
|
||||
class VideoFaceSwap(VideoBase):
|
||||
"""Face swap operations."""
|
||||
|
||||
def face_swap(
|
||||
self,
|
||||
image: str, # Base64-encoded image or URL
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None,
|
||||
resolution: str = "480p",
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha).
|
||||
|
||||
Args:
|
||||
image: Base64-encoded image data URI or public URL (reference character)
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional prompt to guide the swap
|
||||
resolution: Output resolution ("480p" or "720p")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Face-swapped video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the face swap fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/wan-2.1/mocha"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"image": image,
|
||||
"video": video,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
if resolution in ("480p", "720p"):
|
||||
payload["resolution"] = resolution
|
||||
else:
|
||||
payload["resolution"] = "480p" # Default
|
||||
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
else:
|
||||
payload["seed"] = -1 # Random seed
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Face swap request via {url} "
|
||||
f"(resolution={payload['resolution']}, seed={payload['seed']})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Face swap submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed face swap submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected face swap response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Face swap submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Face swap completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for face swap",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
|
||||
def video_face_swap(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
face_image: str, # Base64-encoded image or URL
|
||||
target_gender: str = "all",
|
||||
target_index: int = 0,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap).
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
face_image: Base64-encoded image data URI or public URL (reference face)
|
||||
target_gender: Filter which faces to swap ("all", "female", "male")
|
||||
target_index: Select which face to swap (0 = largest, 1 = second largest, etc.)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Face-swapped video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the face swap fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/video-face-swap"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"face_image": face_image,
|
||||
}
|
||||
|
||||
if target_gender in ("all", "female", "male"):
|
||||
payload["target_gender"] = target_gender
|
||||
else:
|
||||
payload["target_gender"] = "all" # Default
|
||||
|
||||
if 0 <= target_index <= 10:
|
||||
payload["target_index"] = target_index
|
||||
else:
|
||||
payload["target_index"] = 0 # Default
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video face swap request via {url} "
|
||||
f"(target_gender={payload['target_gender']}, target_index={payload['target_index']})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video face swap submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video face swap submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected video face swap response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Video face swap submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video face swap completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video face swap output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video face swap completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video face swap",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
333
backend/services/wavespeed/generators/video/generation.py
Normal file
333
backend/services/wavespeed/generators/video/generation.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Video generation operations (text-to-video and image-to-video).
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.generation")
|
||||
|
||||
|
||||
class VideoGeneration(VideoBase):
|
||||
"""Video generation operations."""
|
||||
|
||||
def submit_image_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Submit an image-to-video generation request.
|
||||
|
||||
Returns the prediction ID for polling.
|
||||
"""
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
logger.info(f"[WaveSpeed] Submitting request to {url}")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json().get("data")
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def submit_text_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""
|
||||
Submit a text-to-video generation request to WaveSpeed.
|
||||
|
||||
Args:
|
||||
model_path: Model path (e.g., "alibaba/wan-2.5/text-to-video")
|
||||
payload: Request payload with prompt, resolution, duration, optional audio
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Prediction ID for polling
|
||||
"""
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
logger.info(f"[WaveSpeed] Submitting text-to-video request to {url}")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed text-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json().get("data")
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected text-to-video response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Submitted text-to-video request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def generate_text_video(
|
||||
self,
|
||||
prompt: str,
|
||||
resolution: str = "720p", # 480p, 720p, 1080p
|
||||
duration: int = 5, # 5 or 10 seconds
|
||||
audio_base64: Optional[str] = None, # Optional audio for lip-sync
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 180,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate video from text prompt using WAN 2.5 text-to-video.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the video
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB) for lip-sync
|
||||
negative_prompt: Optional negative prompt
|
||||
seed: Optional random seed for reproducibility
|
||||
enable_prompt_expansion: Enable prompt optimizer
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
model_path = "alibaba/wan-2.5/text-to-video"
|
||||
|
||||
# Validate resolution
|
||||
valid_resolutions = ["480p", "720p", "1080p"]
|
||||
if resolution not in valid_resolutions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid resolution: {resolution}. Must be one of: {valid_resolutions}"
|
||||
)
|
||||
|
||||
# Validate duration
|
||||
if duration not in [5, 10]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Duration must be 5 or 10 seconds"
|
||||
)
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
"enable_prompt_expansion": enable_prompt_expansion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional audio
|
||||
if audio_base64:
|
||||
payload["audio"] = audio_base64
|
||||
|
||||
# Add optional parameters
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Submit request
|
||||
logger.info(
|
||||
f"[WaveSpeed] Generating text-to-video: resolution={resolution}, "
|
||||
f"duration={duration}s, prompt_length={len(prompt)}, sync_mode={enable_sync_mode}"
|
||||
)
|
||||
|
||||
# For sync mode, submit and get result directly
|
||||
if enable_sync_mode:
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed text-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - if "created" or "processing", we need to poll even in sync mode
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
logger.debug(
|
||||
f"[WaveSpeed] Sync mode response: status='{status}', outputs_count={len(outputs)}, "
|
||||
f"prediction_id={prediction_id}"
|
||||
)
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if status == "completed" and outputs:
|
||||
# Sync mode returned completed result - use it directly
|
||||
logger.info(f"[WaveSpeed] Got immediate video results from sync mode (status: {status})")
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
logger.error(f"[WaveSpeed] Invalid video URL format in sync mode: {video_url}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}",
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = data.get("metadata") or {}
|
||||
else:
|
||||
# Sync mode returned "created", "processing", or incomplete status - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' but no prediction ID. "
|
||||
f"Response: {response.text[:500]}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed text-to-video sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' with {len(outputs)} output(s). "
|
||||
f"Falling back to polling (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Poll for completion
|
||||
try:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] Polling completed but no outputs: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed text-to-video completed but returned no outputs",
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
logger.error(f"[WaveSpeed] Invalid video URL format after polling: {video_url}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}",
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = result.get("metadata") or {}
|
||||
else:
|
||||
# Async mode - submit and poll
|
||||
prediction_id = self.submit_text_to_video(model_path, payload, timeout=timeout)
|
||||
|
||||
# Poll for completion
|
||||
try:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
# Extract video URL
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WAN 2.5 text-to-video completed but returned no outputs"
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}"
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = result.get("metadata") or {}
|
||||
# prediction_id is already set from earlier in the function
|
||||
|
||||
# Calculate cost (same pricing as image-to-video)
|
||||
pricing = {
|
||||
"480p": 0.05,
|
||||
"720p": 0.10,
|
||||
"1080p": 0.15,
|
||||
}
|
||||
cost = pricing.get(resolution, 0.10) * duration
|
||||
|
||||
# Get video dimensions
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
"1080p": (1920, 1080),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] ✅ Generated text-to-video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": float(duration),
|
||||
"model_name": "alibaba/wan-2.5/text-to-video",
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
}
|
||||
263
backend/services/wavespeed/generators/video/generator.py
Normal file
263
backend/services/wavespeed/generators/video/generator.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Main VideoGenerator class that composes all video operation modules.
|
||||
|
||||
This class maintains backward compatibility with the original monolithic VideoGenerator
|
||||
by delegating to specialized modules for different video operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from .base import VideoBase
|
||||
from .generation import VideoGeneration
|
||||
from .enhancement import VideoEnhancement
|
||||
from .extension import VideoExtension
|
||||
from .face_swap import VideoFaceSwap
|
||||
from .translation import VideoTranslation
|
||||
from .background import VideoBackground
|
||||
from .audio import VideoAudio
|
||||
|
||||
|
||||
class VideoGenerator(VideoBase):
|
||||
"""
|
||||
Video generation generator for WaveSpeed API.
|
||||
|
||||
This class composes multiple specialized modules to provide all video operations
|
||||
while maintaining a single unified interface for backward compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize video generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
super().__init__(api_key, base_url, polling)
|
||||
|
||||
# Initialize specialized modules
|
||||
self._generation = VideoGeneration(api_key, base_url, polling)
|
||||
self._enhancement = VideoEnhancement(api_key, base_url, polling)
|
||||
self._extension = VideoExtension(api_key, base_url, polling)
|
||||
self._face_swap = VideoFaceSwap(api_key, base_url, polling)
|
||||
self._translation = VideoTranslation(api_key, base_url, polling)
|
||||
self._background = VideoBackground(api_key, base_url, polling)
|
||||
self._audio = VideoAudio(api_key, base_url, polling)
|
||||
|
||||
# Generation methods (delegated to VideoGeneration)
|
||||
def submit_image_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""Submit an image-to-video generation request."""
|
||||
return self._generation.submit_image_to_video(model_path, payload, timeout)
|
||||
|
||||
def submit_text_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""Submit a text-to-video generation request to WaveSpeed."""
|
||||
return self._generation.submit_text_to_video(model_path, payload, timeout)
|
||||
|
||||
def generate_text_video(
|
||||
self,
|
||||
prompt: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
audio_base64: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 180,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate video from text prompt using WAN 2.5 text-to-video."""
|
||||
return self._generation.generate_text_video(
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
enable_prompt_expansion=enable_prompt_expansion,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Enhancement methods (delegated to VideoEnhancement)
|
||||
def upscale_video(
|
||||
self,
|
||||
video: str,
|
||||
target_resolution: str = "1080p",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Upscale video using FlashVSR."""
|
||||
return self._enhancement.upscale_video(
|
||||
video=video,
|
||||
target_resolution=target_resolution,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extension methods (delegated to VideoExtension)
|
||||
def extend_video(
|
||||
self,
|
||||
video: str,
|
||||
prompt: str,
|
||||
model: str = "wan-2.5",
|
||||
audio: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
enable_prompt_expansion: bool = False,
|
||||
generate_audio: bool = True,
|
||||
camera_fixed: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend."""
|
||||
return self._extension.extend_video(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
audio=audio,
|
||||
negative_prompt=negative_prompt,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
enable_prompt_expansion=enable_prompt_expansion,
|
||||
generate_audio=generate_audio,
|
||||
camera_fixed=camera_fixed,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Face swap methods (delegated to VideoFaceSwap)
|
||||
def face_swap(
|
||||
self,
|
||||
image: str,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
resolution: str = "480p",
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha)."""
|
||||
return self._face_swap.face_swap(
|
||||
image=image,
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
def video_face_swap(
|
||||
self,
|
||||
video: str,
|
||||
face_image: str,
|
||||
target_gender: str = "all",
|
||||
target_index: int = 0,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap)."""
|
||||
return self._face_swap.video_face_swap(
|
||||
video=video,
|
||||
face_image=face_image,
|
||||
target_gender=target_gender,
|
||||
target_index=target_index,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Translation methods (delegated to VideoTranslation)
|
||||
def video_translate(
|
||||
self,
|
||||
video: str,
|
||||
output_language: str = "English",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 600,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Translate video to target language using HeyGen Video Translate."""
|
||||
return self._translation.video_translate(
|
||||
video=video,
|
||||
output_language=output_language,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Background methods (delegated to VideoBackground)
|
||||
def remove_background(
|
||||
self,
|
||||
video: str,
|
||||
background_image: Optional[str] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Remove or replace video background using Video Background Remover."""
|
||||
return self._background.remove_background(
|
||||
video=video,
|
||||
background_image=background_image,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Audio methods (delegated to VideoAudio)
|
||||
def hunyuan_video_foley(
|
||||
self,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
seed: int = -1,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Generate realistic Foley and ambient audio from video using Hunyuan Video Foley."""
|
||||
return self._audio.hunyuan_video_foley(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
def think_sound(
|
||||
self,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
seed: int = -1,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Generate realistic sound effects and audio tracks from video using Think Sound."""
|
||||
return self._audio.think_sound(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
133
backend/services/wavespeed/generators/video/translation.py
Normal file
133
backend/services/wavespeed/generators/video/translation.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Video translation operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.translation")
|
||||
|
||||
|
||||
class VideoTranslation(VideoBase):
|
||||
"""Video translation operations."""
|
||||
|
||||
def video_translate(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
output_language: str = "English",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 600,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Translate video to target language using HeyGen Video Translate.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
output_language: Target language for translation (default: "English")
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 600)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Translated video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the video translation fails
|
||||
"""
|
||||
model_path = "heygen/video-translate"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"output_language": output_language,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video translate request via {url} "
|
||||
f"(output_language={output_language})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video translate submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video translate submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected video translate response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Video translate submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video translate completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video translate output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading translated video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download translated video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video translate completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video translate",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user