AI Image Studio, AI podcast Maker, AI product Marketing
This commit is contained in:
@@ -38,7 +38,7 @@ class ContentAssetService:
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
cost: Optional[float] = None,
|
||||
@@ -60,7 +60,7 @@ class ContentAssetService:
|
||||
description: Asset description (optional)
|
||||
prompt: Generation prompt (optional)
|
||||
tags: List of tags (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
asset_metadata: Additional metadata (optional)
|
||||
provider: AI provider used (optional)
|
||||
model: Model used (optional)
|
||||
cost: Generation cost (optional)
|
||||
@@ -83,7 +83,7 @@ class ContentAssetService:
|
||||
description=description,
|
||||
prompt=prompt,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
asset_metadata=asset_metadata or {},
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost or 0.0,
|
||||
@@ -222,7 +222,7 @@ class ContentAssetService:
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[ContentAsset]:
|
||||
"""Update asset metadata."""
|
||||
try:
|
||||
@@ -236,8 +236,8 @@ class ContentAssetService:
|
||||
asset.description = description
|
||||
if tags is not None:
|
||||
asset.tags = tags
|
||||
if metadata is not None:
|
||||
asset.metadata = {**(asset.metadata or {}), **metadata}
|
||||
if asset_metadata is not None:
|
||||
asset.asset_metadata = {**(asset.asset_metadata or {}), **asset_metadata}
|
||||
|
||||
asset.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
@@ -21,6 +21,10 @@ from models.persona_models import Base as PersonaBase
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.user_business_info import Base as UserBusinessInfoBase
|
||||
from models.content_asset_models import Base as ContentAssetBase
|
||||
# Product Marketing models use SubscriptionBase, but import to ensure models are registered
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset
|
||||
# Product Asset models (Product Marketing Suite - product assets, not campaigns)
|
||||
from models.product_asset_models import ProductAsset, ProductStyleTemplate, EcommerceExport
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
|
||||
@@ -73,10 +77,10 @@ def init_database():
|
||||
EnhancedStrategyBase.metadata.create_all(bind=engine)
|
||||
MonitoringBase.metadata.create_all(bind=engine)
|
||||
PersonaBase.metadata.create_all(bind=engine)
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
SubscriptionBase.metadata.create_all(bind=engine) # Includes product_marketing models
|
||||
UserBusinessInfoBase.metadata.create_all(bind=engine)
|
||||
ContentAssetBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
|
||||
logger.info("Database initialized successfully with all models including subscription system, product marketing, business info, and content assets")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error initializing database: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -6,6 +6,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import PlatformTemplates, TemplateManager
|
||||
|
||||
__all__ = [
|
||||
@@ -20,6 +25,9 @@ __all__ = [
|
||||
"ControlStudioRequest",
|
||||
"SocialOptimizerService",
|
||||
"SocialOptimizerRequest",
|
||||
"TransformStudioService",
|
||||
"TransformImageToVideoRequest",
|
||||
"TalkingAvatarRequest",
|
||||
"PlatformTemplates",
|
||||
"TemplateManager",
|
||||
]
|
||||
|
||||
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""InfiniteTalk adapter for Transform Studio."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.infinitetalk")
|
||||
|
||||
|
||||
class InfiniteTalkService:
|
||||
"""Adapter for InfiniteTalk in Transform Studio context."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize InfiniteTalk service adapter."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[InfiniteTalk Adapter] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: float) -> float:
|
||||
"""Calculate cost for InfiniteTalk video.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p or 720p)
|
||||
duration: Video duration in seconds
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
# InfiniteTalk pricing: $0.03/s (480p) or $0.06/s (720p)
|
||||
# Minimum charge: 5 seconds
|
||||
cost_per_second = 0.03 if resolution == "480p" else 0.06
|
||||
actual_duration = max(5.0, duration) # Minimum 5 seconds
|
||||
return cost_per_second * actual_duration
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
image_base64: str,
|
||||
audio_base64: str,
|
||||
resolution: str = "720p",
|
||||
prompt: Optional[str] = None,
|
||||
mask_image_base64: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
user_id: str = "transform_studio",
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar video using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
image_base64: Person image in base64 or data URI
|
||||
audio_base64: Audio file in base64 or data URI
|
||||
resolution: Output resolution (480p or 720p)
|
||||
prompt: Optional prompt for expression/style
|
||||
mask_image_base64: Optional mask for animatable regions
|
||||
seed: Optional random seed
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in ["480p", "720p"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Resolution must be '480p' or '720p' for InfiniteTalk"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
import base64
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
image_mime = mime_parts.strip() or "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_mime = "image/png"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Decode audio
|
||||
try:
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
audio_mime = mime_parts.strip() or "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
audio_mime = "audio/mpeg"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Call existing InfiniteTalk function (run in thread since it's synchronous)
|
||||
# Note: We pass empty dicts for scene_data and story_context since
|
||||
# Transform Studio doesn't have story context
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
animate_scene_with_voiceover,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data={}, # Empty for Transform Studio
|
||||
story_context={}, # Empty for Transform Studio
|
||||
user_id=user_id,
|
||||
resolution=resolution,
|
||||
prompt_override=prompt,
|
||||
image_mime=image_mime,
|
||||
audio_mime=audio_mime,
|
||||
client=self.client,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[InfiniteTalk Adapter] Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"InfiniteTalk generation failed: {str(e)}"
|
||||
)
|
||||
|
||||
# Calculate actual cost based on duration
|
||||
actual_cost = self.calculate_cost(resolution, result.get("duration", 5.0))
|
||||
|
||||
# Update result with actual cost and additional metadata
|
||||
result["cost"] = actual_cost
|
||||
result["resolution"] = resolution
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
|
||||
logger.info(
|
||||
f"[InfiniteTalk Adapter] ✅ Generated talking avatar: "
|
||||
f"resolution={resolution}, duration={result.get('duration', 5.0)}s, cost=${actual_cost:.2f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -7,6 +7,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import Platform, TemplateCategory, ImageTemplate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -24,6 +29,7 @@ class ImageStudioManager:
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
self.control_service = ControlStudioService()
|
||||
self.social_optimizer_service = SocialOptimizerService()
|
||||
self.transform_service = TransformStudioService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
# ====================
|
||||
@@ -339,4 +345,35 @@ class ImageStudioManager:
|
||||
}
|
||||
|
||||
return specs.get(platform, {})
|
||||
|
||||
# ====================
|
||||
# TRANSFORM STUDIO
|
||||
# ====================
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5."""
|
||||
logger.info("[Image Studio] Transform image-to-video request from user: %s", user_id)
|
||||
return await self.transform_service.transform_image_to_video(request, user_id=user_id or "anonymous")
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk."""
|
||||
logger.info("[Image Studio] Talking avatar request from user: %s", user_id)
|
||||
return await self.transform_service.create_talking_avatar(request, user_id=user_id or "anonymous")
|
||||
|
||||
def estimate_transform_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation."""
|
||||
return self.transform_service.estimate_cost(operation, resolution, duration)
|
||||
|
||||
|
||||
379
backend/services/image_studio/transform_service.py
Normal file
379
backend/services/image_studio/transform_service.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Transform Studio service for image-to-video and talking avatar generation."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .wan25_service import WAN25Service
|
||||
from .infinitetalk_adapter import InfiniteTalkService
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.file_storage import save_file_safely, sanitize_filename
|
||||
|
||||
logger = get_service_logger("image_studio.transform")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformImageToVideoRequest:
|
||||
"""Request for WAN 2.5 image-to-video."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
audio_base64: Optional[str] = None
|
||||
resolution: str = "720p" # 480p, 720p, 1080p
|
||||
duration: int = 5 # 5 or 10 seconds
|
||||
negative_prompt: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
enable_prompt_expansion: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TalkingAvatarRequest:
|
||||
"""Request for InfiniteTalk talking avatar."""
|
||||
image_base64: str
|
||||
audio_base64: str
|
||||
resolution: str = "720p" # 480p or 720p
|
||||
prompt: Optional[str] = None
|
||||
mask_image_base64: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class TransformStudioService:
|
||||
"""Service for Transform Studio operations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Transform Studio service."""
|
||||
self.wan25_service = WAN25Service()
|
||||
self.infinitetalk_service = InfiniteTalkService()
|
||||
|
||||
# Video output directory
|
||||
# __file__ is: backend/services/image_studio/transform_service.py
|
||||
# We need: backend/transform_videos
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
self.output_dir = base_dir / "transform_videos"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not self.output_dir.exists():
|
||||
raise RuntimeError(f"Failed to create transform_videos directory: {self.output_dir}")
|
||||
|
||||
logger.info(f"[Transform Studio] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def _save_video_file(
|
||||
self,
|
||||
video_bytes: bytes,
|
||||
operation_type: str,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Save video file to disk.
|
||||
|
||||
Args:
|
||||
video_bytes: Video content as bytes
|
||||
operation_type: Type of operation (e.g., "image-to-video", "talking-avatar")
|
||||
user_id: User ID for directory organization
|
||||
|
||||
Returns:
|
||||
Dictionary with filename, file_path, and file_url
|
||||
"""
|
||||
# Create user-specific directory
|
||||
user_dir = self.output_dir / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
filename = f"{operation_type}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filename = sanitize_filename(filename)
|
||||
|
||||
# Save file
|
||||
file_path, error = save_file_safely(
|
||||
content=video_bytes,
|
||||
directory=user_dir,
|
||||
filename=filename,
|
||||
max_file_size=500 * 1024 * 1024 # 500MB max for videos
|
||||
)
|
||||
|
||||
if error:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to save video file: {error}"
|
||||
)
|
||||
|
||||
file_url = f"/api/image-studio/videos/{user_id}/{filename}"
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"file_path": str(file_path),
|
||||
"file_url": file_url,
|
||||
"file_size": len(video_bytes),
|
||||
}
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
request: Transform request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Image-to-video request from user {user_id}: "
|
||||
f"resolution={request.resolution}, duration={request.duration}s"
|
||||
)
|
||||
|
||||
# Generate video using WAN 2.5
|
||||
result = await self.wan25_service.generate_video(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="image-to-video",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result["prompt"],
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Image-to-Video ({request.resolution})",
|
||||
description=f"Generated video using WAN 2.5: {request.prompt[:100]}",
|
||||
prompt=result["prompt"],
|
||||
tags=["image_studio", "transform", "video", "image-to-video", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result["duration"],
|
||||
"operation": "image-to-video",
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result["duration"],
|
||||
"resolution": result["resolution"],
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
request: Talking avatar request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Talking avatar request from user {user_id}: "
|
||||
f"resolution={request.resolution}"
|
||||
)
|
||||
|
||||
# Generate video using InfiniteTalk
|
||||
result = await self.infinitetalk_service.create_talking_avatar(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="talking-avatar",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result.get("prompt", ""),
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Talking Avatar ({request.resolution})",
|
||||
description="Generated talking avatar video using InfiniteTalk",
|
||||
prompt=result.get("prompt", ""),
|
||||
tags=["image_studio", "transform", "video", "talking-avatar", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result.get("duration", 5.0),
|
||||
"operation": "talking-avatar",
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result.get("duration", 5.0),
|
||||
"resolution": result.get("resolution", request.resolution),
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type ("image-to-video" or "talking-avatar")
|
||||
resolution: Output resolution
|
||||
duration: Video duration in seconds (for image-to-video)
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
if operation == "image-to-video":
|
||||
if duration is None:
|
||||
duration = 5
|
||||
cost = self.wan25_service.calculate_cost(resolution, duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.wan25_service.calculate_cost(resolution, 1),
|
||||
"duration": duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "alibaba/wan-2.5/image-to-video",
|
||||
}
|
||||
elif operation == "talking-avatar":
|
||||
# InfiniteTalk minimum is 5 seconds
|
||||
estimated_duration = duration or 5.0
|
||||
cost = self.infinitetalk_service.calculate_cost(resolution, estimated_duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.infinitetalk_service.calculate_cost(resolution, 1.0),
|
||||
"duration": estimated_duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "wavespeed-ai/infinitetalk",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown operation: {operation}")
|
||||
|
||||
295
backend/services/image_studio/wan25_service.py
Normal file
295
backend/services/image_studio/wan25_service.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""WAN 2.5 service for Alibaba image-to-video generation via WaveSpeed."""
|
||||
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.wan25")
|
||||
|
||||
WAN25_MODEL_PATH = "alibaba/wan-2.5/image-to-video"
|
||||
WAN25_MODEL_NAME = "alibaba/wan-2.5/image-to-video"
|
||||
|
||||
# Pricing per second (from WaveSpeed docs)
|
||||
PRICING = {
|
||||
"480p": 0.05, # $0.05 per second
|
||||
"720p": 0.10, # $0.10 per second
|
||||
"1080p": 0.15, # $0.15 per second
|
||||
}
|
||||
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB (recommended)
|
||||
MAX_AUDIO_BYTES = 15 * 1024 * 1024 # 15MB (API limit)
|
||||
MIN_AUDIO_DURATION = 3 # seconds
|
||||
MAX_AUDIO_DURATION = 30 # seconds
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
"""Convert bytes to data URI."""
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def _decode_base64_image(image_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 image, handling data URIs."""
|
||||
if image_base64.startswith("data:"):
|
||||
# Extract mime type and base64 data
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
# Assume it's raw base64
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
mime_type = "image/png" # Default
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
|
||||
def _decode_base64_audio(audio_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 audio, handling data URIs."""
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
mime_type = "audio/mpeg" # Default
|
||||
|
||||
return audio_bytes, mime_type
|
||||
|
||||
|
||||
class WAN25Service:
|
||||
"""Service for Alibaba WAN 2.5 image-to-video generation."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize WAN 2.5 service."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[WAN 2.5] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: int) -> float:
|
||||
"""Calculate cost for video generation.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
cost_per_second = PRICING.get(resolution, PRICING["720p"])
|
||||
return cost_per_second * duration
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
audio_base64: Optional[str] = None,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
image_base64: Image in base64 or data URI format
|
||||
prompt: Text prompt describing the video
|
||||
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB)
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
negative_prompt: Optional negative prompt
|
||||
seed: Optional random seed for reproducibility
|
||||
enable_prompt_expansion: Enable prompt optimizer
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in PRICING:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid resolution: {resolution}. Must be one of: {list(PRICING.keys())}"
|
||||
)
|
||||
|
||||
# Validate duration
|
||||
if duration not in [5, 10]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid duration: {duration}. Must be 5 or 10 seconds"
|
||||
)
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not prompt.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Prompt is required and cannot be empty"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
try:
|
||||
image_bytes, image_mime = _decode_base64_image(image_base64)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Validate image size
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
"enable_prompt_expansion": enable_prompt_expansion,
|
||||
}
|
||||
|
||||
# Add optional audio
|
||||
if audio_base64:
|
||||
try:
|
||||
audio_bytes, audio_mime = _decode_base64_audio(audio_base64)
|
||||
|
||||
# Validate audio size
|
||||
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Note: Audio duration validation would require audio analysis
|
||||
# For now, we rely on API to handle it (API keeps first 5s/10s if longer)
|
||||
|
||||
payload["audio"] = _as_data_uri(audio_bytes, audio_mime)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Add optional parameters
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Submit to WaveSpeed
|
||||
logger.info(
|
||||
f"[WAN 2.5] Submitting video generation request: resolution={resolution}, duration={duration}s"
|
||||
)
|
||||
|
||||
try:
|
||||
prediction_id = self.client.submit_image_to_video(
|
||||
WAN25_MODEL_PATH,
|
||||
payload,
|
||||
timeout=60
|
||||
)
|
||||
except HTTPException as e:
|
||||
logger.error(f"[WAN 2.5] Submission failed: {e.detail}")
|
||||
raise
|
||||
|
||||
# Poll for completion
|
||||
logger.info(f"[WAN 2.5] Polling for completion: prediction_id={prediction_id}")
|
||||
|
||||
try:
|
||||
# WAN 2.5 typically takes 1-2 minutes
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180, # 3 minutes max
|
||||
interval_seconds=2.0
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
# Extract video URL
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WAN 2.5 completed but returned no outputs"
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}"
|
||||
)
|
||||
|
||||
# Download video (run synchronous request in thread)
|
||||
logger.info(f"[WAN 2.5] Downloading video from: {video_url}")
|
||||
video_response = await asyncio.to_thread(
|
||||
requests.get,
|
||||
video_url,
|
||||
timeout=180
|
||||
)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download WAN 2.5 video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
}
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
metadata = result.get("metadata") or {}
|
||||
|
||||
# Calculate cost
|
||||
cost = self.calculate_cost(resolution, duration)
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
"1080p": (1920, 1080),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
|
||||
logger.info(
|
||||
f"[WAN 2.5] ✅ Generated video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": float(duration),
|
||||
"model_name": WAN25_MODEL_NAME,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
20
backend/services/product_marketing/__init__.py
Normal file
20
backend/services/product_marketing/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Product Marketing Suite service package."""
|
||||
|
||||
from .orchestrator import ProductMarketingOrchestrator
|
||||
from .brand_dna_sync import BrandDNASyncService
|
||||
from .prompt_builder import ProductMarketingPromptBuilder
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from .campaign_storage import CampaignStorageService
|
||||
from .product_image_service import ProductImageService
|
||||
|
||||
__all__ = [
|
||||
"ProductMarketingOrchestrator",
|
||||
"BrandDNASyncService",
|
||||
"ProductMarketingPromptBuilder",
|
||||
"AssetAuditService",
|
||||
"ChannelPackService",
|
||||
"CampaignStorageService",
|
||||
"ProductImageService",
|
||||
]
|
||||
|
||||
205
backend/services/product_marketing/asset_audit.py
Normal file
205
backend/services/product_marketing/asset_audit.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Asset Audit Service
|
||||
Analyzes uploaded assets and recommends enhancement operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AssetAuditService:
|
||||
"""Service to audit assets and recommend enhancements."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Asset Audit Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Asset Audit] Service initialized")
|
||||
|
||||
def audit_asset(
|
||||
self,
|
||||
image_base64: str,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit an uploaded asset and recommend enhancement operations.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image
|
||||
asset_metadata: Optional metadata about the asset
|
||||
|
||||
Returns:
|
||||
Audit results with recommendations
|
||||
"""
|
||||
try:
|
||||
# Decode image
|
||||
image_bytes = self._decode_base64(image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Invalid image data")
|
||||
|
||||
# Analyze image
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
format_type = image.format or "PNG"
|
||||
mode = image.mode
|
||||
|
||||
# Basic quality checks
|
||||
quality_score = self._assess_quality(image, width, height)
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = []
|
||||
|
||||
# Resolution recommendations
|
||||
if width < 1080 or height < 1080:
|
||||
recommendations.append({
|
||||
"operation": "upscale",
|
||||
"priority": "high",
|
||||
"reason": f"Image resolution ({width}x{height}) is below recommended 1080p for social media",
|
||||
"suggested_mode": "fast" if width < 512 else "conservative",
|
||||
})
|
||||
|
||||
# Background recommendations
|
||||
if mode == "RGBA" and self._has_transparency(image):
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "low",
|
||||
"reason": "Image already has transparency, background removal may not be needed",
|
||||
})
|
||||
else:
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "medium",
|
||||
"reason": "Background removal can create versatile product images",
|
||||
})
|
||||
|
||||
# Enhancement recommendations based on quality
|
||||
if quality_score < 0.7:
|
||||
recommendations.append({
|
||||
"operation": "enhance",
|
||||
"priority": "high",
|
||||
"reason": f"Image quality score ({quality_score:.2f}) suggests enhancement needed",
|
||||
"suggested_operations": ["upscale", "general_edit"],
|
||||
})
|
||||
|
||||
# Format recommendations
|
||||
if format_type not in ["PNG", "JPEG"]:
|
||||
recommendations.append({
|
||||
"operation": "convert",
|
||||
"priority": "low",
|
||||
"reason": f"Format {format_type} may not be optimal for web/social media",
|
||||
"suggested_format": "PNG" if mode == "RGBA" else "JPEG",
|
||||
})
|
||||
|
||||
audit_result = {
|
||||
"asset_info": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"format": format_type,
|
||||
"mode": mode,
|
||||
"quality_score": quality_score,
|
||||
},
|
||||
"recommendations": recommendations,
|
||||
"status": "usable" if quality_score > 0.6 else "needs_enhancement",
|
||||
}
|
||||
|
||||
logger.info(f"[Asset Audit] Audited asset: {width}x{height}, quality: {quality_score:.2f}")
|
||||
return audit_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error auditing asset: {str(e)}")
|
||||
return {
|
||||
"asset_info": {},
|
||||
"recommendations": [],
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _decode_base64(self, image_base64: str) -> Optional[bytes]:
|
||||
"""Decode base64 image data."""
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
_, b64data = image_base64.split(",", 1)
|
||||
else:
|
||||
b64data = image_base64
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error decoding base64: {str(e)}")
|
||||
return None
|
||||
|
||||
def _has_transparency(self, image: Image.Image) -> bool:
|
||||
"""Check if image has transparency."""
|
||||
if image.mode in ("RGBA", "LA"):
|
||||
alpha = image.split()[-1]
|
||||
return any(pixel < 255 for pixel in alpha.getdata())
|
||||
return False
|
||||
|
||||
def _assess_quality(self, image: Image.Image, width: int, height: int) -> float:
|
||||
"""
|
||||
Assess image quality score (0.0 to 1.0).
|
||||
|
||||
Simple heuristic based on resolution and format.
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Resolution scoring
|
||||
min_dimension = min(width, height)
|
||||
if min_dimension >= 1080:
|
||||
score += 0.3
|
||||
elif min_dimension >= 512:
|
||||
score += 0.2
|
||||
elif min_dimension >= 256:
|
||||
score += 0.1
|
||||
|
||||
# Format scoring
|
||||
if image.format in ["PNG", "JPEG"]:
|
||||
score += 0.1
|
||||
|
||||
# Mode scoring
|
||||
if image.mode in ["RGB", "RGBA"]:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def batch_audit_assets(
|
||||
self,
|
||||
assets: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit multiple assets in batch.
|
||||
|
||||
Args:
|
||||
assets: List of asset dictionaries with 'image_base64' and optional 'metadata'
|
||||
|
||||
Returns:
|
||||
Batch audit results
|
||||
"""
|
||||
results = []
|
||||
for asset in assets:
|
||||
audit_result = self.audit_asset(
|
||||
asset.get('image_base64'),
|
||||
asset.get('metadata')
|
||||
)
|
||||
results.append({
|
||||
"asset_id": asset.get('id'),
|
||||
"audit": audit_result,
|
||||
})
|
||||
|
||||
# Summary statistics
|
||||
total_assets = len(results)
|
||||
usable_count = sum(1 for r in results if r["audit"]["status"] == "usable")
|
||||
needs_enhancement_count = sum(
|
||||
1 for r in results if r["audit"]["status"] == "needs_enhancement"
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"summary": {
|
||||
"total_assets": total_assets,
|
||||
"usable": usable_count,
|
||||
"needs_enhancement": needs_enhancement_count,
|
||||
"error": total_assets - usable_count - needs_enhancement_count,
|
||||
},
|
||||
}
|
||||
|
||||
176
backend/services/product_marketing/brand_dna_sync.py
Normal file
176
backend/services/product_marketing/brand_dna_sync.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Brand DNA Sync Service
|
||||
Normalizes persona data and onboarding information into reusable brand tokens.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.onboarding import OnboardingDatabaseService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class BrandDNASyncService:
|
||||
"""Service to sync and normalize brand DNA from onboarding and persona data."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Brand DNA Sync Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Brand DNA Sync] Service initialized")
|
||||
|
||||
def get_brand_dna_tokens(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract and normalize brand DNA tokens from onboarding and persona data.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch data for
|
||||
|
||||
Returns:
|
||||
Dictionary of brand DNA tokens ready for prompt injection
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
brand_tokens = {
|
||||
"writing_style": {},
|
||||
"target_audience": {},
|
||||
"visual_identity": {},
|
||||
"persona": {},
|
||||
"competitive_positioning": {},
|
||||
}
|
||||
|
||||
# Extract writing style from website analysis
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style') or {}
|
||||
target_audience = website_analysis.get('target_audience') or {}
|
||||
brand_analysis = website_analysis.get('brand_analysis') or {}
|
||||
style_guidelines = website_analysis.get('style_guidelines') or {}
|
||||
|
||||
# Ensure writing_style is a dict before accessing
|
||||
if isinstance(writing_style, dict):
|
||||
brand_tokens["writing_style"] = {
|
||||
"tone": writing_style.get('tone', 'professional'),
|
||||
"voice": writing_style.get('voice', 'authoritative'),
|
||||
"complexity": writing_style.get('complexity', 'intermediate'),
|
||||
"engagement_level": writing_style.get('engagement_level', 'moderate'),
|
||||
}
|
||||
|
||||
# Ensure target_audience is a dict before accessing
|
||||
if isinstance(target_audience, dict):
|
||||
brand_tokens["target_audience"] = {
|
||||
"demographics": target_audience.get('demographics', []),
|
||||
"industry_focus": target_audience.get('industry_focus', 'general'),
|
||||
"expertise_level": target_audience.get('expertise_level', 'intermediate'),
|
||||
}
|
||||
|
||||
# Ensure brand_analysis is a dict before accessing
|
||||
if isinstance(brand_analysis, dict) and brand_analysis:
|
||||
brand_tokens["visual_identity"] = {
|
||||
"color_palette": brand_analysis.get('color_palette', []),
|
||||
"brand_values": brand_analysis.get('brand_values', []),
|
||||
"positioning": brand_analysis.get('positioning', ''),
|
||||
}
|
||||
|
||||
# Add style_guidelines if available and visual_identity exists
|
||||
if style_guidelines and isinstance(style_guidelines, dict):
|
||||
if "visual_identity" not in brand_tokens:
|
||||
brand_tokens["visual_identity"] = {}
|
||||
brand_tokens["visual_identity"]["style_guidelines"] = style_guidelines
|
||||
|
||||
# Extract persona data
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona') or {}
|
||||
platform_personas = persona_data.get('platformPersonas') or {}
|
||||
|
||||
# Ensure core_persona is a dict before accessing
|
||||
if isinstance(core_persona, dict) and core_persona:
|
||||
brand_tokens["persona"] = {
|
||||
"persona_name": core_persona.get('persona_name', ''),
|
||||
"archetype": core_persona.get('archetype', ''),
|
||||
"core_belief": core_persona.get('core_belief', ''),
|
||||
"linguistic_fingerprint": core_persona.get('linguistic_fingerprint', {}),
|
||||
}
|
||||
|
||||
# Ensure persona dict exists before setting platform_personas
|
||||
if "persona" not in brand_tokens:
|
||||
brand_tokens["persona"] = {}
|
||||
|
||||
# Only set platform_personas if it's a valid dict
|
||||
if isinstance(platform_personas, dict):
|
||||
brand_tokens["persona"]["platform_personas"] = platform_personas
|
||||
|
||||
# Extract competitive positioning
|
||||
if competitor_analyses and isinstance(competitor_analyses, list) and len(competitor_analyses) > 0:
|
||||
# Extract differentiation points
|
||||
brand_tokens["competitive_positioning"] = {
|
||||
"differentiators": [],
|
||||
"unique_value_props": [],
|
||||
}
|
||||
|
||||
for competitor in competitor_analyses[:3]: # Top 3 competitors
|
||||
if not isinstance(competitor, dict):
|
||||
continue
|
||||
|
||||
analysis_data = competitor.get('analysis_data') or {}
|
||||
if isinstance(analysis_data, dict) and analysis_data:
|
||||
competitive_insights = analysis_data.get('competitive_analysis') or {}
|
||||
if isinstance(competitive_insights, dict) and competitive_insights:
|
||||
differentiators = competitive_insights.get('differentiators', [])
|
||||
if isinstance(differentiators, list) and differentiators:
|
||||
brand_tokens["competitive_positioning"]["differentiators"].extend(
|
||||
differentiators[:2]
|
||||
)
|
||||
|
||||
logger.info(f"[Brand DNA Sync] Extracted brand tokens for user {user_id}")
|
||||
return brand_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Brand DNA Sync] Error extracting brand tokens: {str(e)}")
|
||||
return {
|
||||
"writing_style": {"tone": "professional", "voice": "authoritative"},
|
||||
"target_audience": {"demographics": [], "expertise_level": "intermediate"},
|
||||
"visual_identity": {},
|
||||
"persona": {},
|
||||
"competitive_positioning": {},
|
||||
}
|
||||
|
||||
def get_channel_specific_dna(
|
||||
self,
|
||||
user_id: str,
|
||||
channel: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get channel-specific brand DNA adaptations.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
channel: Target channel (instagram, linkedin, tiktok, etc.)
|
||||
|
||||
Returns:
|
||||
Channel-specific brand DNA tokens
|
||||
"""
|
||||
brand_tokens = self.get_brand_dna_tokens(user_id)
|
||||
channel_dna = brand_tokens.copy()
|
||||
|
||||
# Get platform-specific persona if available
|
||||
persona = brand_tokens.get("persona") or {}
|
||||
platform_personas = persona.get("platform_personas") or {}
|
||||
|
||||
if isinstance(platform_personas, dict) and channel in platform_personas:
|
||||
platform_persona = platform_personas[channel]
|
||||
if isinstance(platform_persona, dict):
|
||||
channel_dna["platform_adaptation"] = {
|
||||
"content_format_rules": platform_persona.get('content_format_rules') or {},
|
||||
"engagement_patterns": platform_persona.get('engagement_patterns') or {},
|
||||
"visual_identity": platform_persona.get('visual_identity') or {},
|
||||
}
|
||||
|
||||
return channel_dna
|
||||
|
||||
222
backend/services/product_marketing/campaign_storage.py
Normal file
222
backend/services/product_marketing/campaign_storage.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Campaign Storage Service
|
||||
Handles database persistence for campaigns, proposals, and assets.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset, CampaignStatus
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class CampaignStorageService:
|
||||
"""Service for storing and retrieving campaigns from database."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Campaign Storage Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Campaign Storage] Service initialized")
|
||||
|
||||
def save_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> Campaign:
|
||||
"""
|
||||
Save campaign blueprint to database.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign blueprint data
|
||||
|
||||
Returns:
|
||||
Saved Campaign object
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign_id = campaign_data.get('campaign_id')
|
||||
|
||||
# Check if campaign exists
|
||||
existing = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing campaign
|
||||
existing.campaign_name = campaign_data.get('campaign_name', existing.campaign_name)
|
||||
existing.goal = campaign_data.get('goal', existing.goal)
|
||||
existing.kpi = campaign_data.get('kpi', existing.kpi)
|
||||
existing.status = campaign_data.get('status', existing.status)
|
||||
existing.phases = campaign_data.get('phases', existing.phases)
|
||||
existing.channels = campaign_data.get('channels', existing.channels)
|
||||
existing.asset_nodes = campaign_data.get('asset_nodes', existing.asset_nodes)
|
||||
existing.product_context = campaign_data.get('product_context', existing.product_context)
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id}")
|
||||
return existing
|
||||
else:
|
||||
# Create new campaign
|
||||
campaign = Campaign(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
campaign_name=campaign_data.get('campaign_name'),
|
||||
goal=campaign_data.get('goal'),
|
||||
kpi=campaign_data.get('kpi'),
|
||||
status=campaign_data.get('status', 'draft'),
|
||||
phases=campaign_data.get('phases'),
|
||||
channels=campaign_data.get('channels', []),
|
||||
asset_nodes=campaign_data.get('asset_nodes', []),
|
||||
product_context=campaign_data.get('product_context'),
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
logger.info(f"[Campaign Storage] Saved new campaign {campaign_id}")
|
||||
return campaign
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving campaign: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> Optional[Campaign]:
|
||||
"""Get campaign by ID."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
return campaign
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting campaign: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def list_campaigns(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Campaign]:
|
||||
"""List campaigns for user."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
query = db.query(Campaign).filter(Campaign.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
|
||||
campaigns = query.order_by(desc(Campaign.created_at)).limit(limit).all()
|
||||
return campaigns
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error listing campaigns: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def save_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
proposals: Dict[str, Any]
|
||||
) -> List[CampaignProposal]:
|
||||
"""Save asset proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Delete existing proposals for this campaign
|
||||
db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).delete()
|
||||
|
||||
# Create new proposals
|
||||
saved_proposals = []
|
||||
for asset_id, proposal_data in proposals.get('proposals', {}).items():
|
||||
proposal = CampaignProposal(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
asset_node_id=asset_id,
|
||||
asset_type=proposal_data.get('asset_type'),
|
||||
channel=proposal_data.get('channel'),
|
||||
proposed_prompt=proposal_data.get('proposed_prompt'),
|
||||
recommended_template=proposal_data.get('recommended_template'),
|
||||
recommended_provider=proposal_data.get('recommended_provider'),
|
||||
recommended_model=proposal_data.get('recommended_model'),
|
||||
cost_estimate=proposal_data.get('cost_estimate', 0.0),
|
||||
concept_summary=proposal_data.get('concept_summary'),
|
||||
status='proposed',
|
||||
)
|
||||
db.add(proposal)
|
||||
saved_proposals.append(proposal)
|
||||
|
||||
db.commit()
|
||||
for proposal in saved_proposals:
|
||||
db.refresh(proposal)
|
||||
|
||||
logger.info(f"[Campaign Storage] Saved {len(saved_proposals)} proposals for campaign {campaign_id}")
|
||||
return saved_proposals
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving proposals: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> List[CampaignProposal]:
|
||||
"""Get proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
proposals = db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).all()
|
||||
return proposals
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting proposals: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def update_campaign_status(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
status: str
|
||||
) -> bool:
|
||||
"""Update campaign status."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if campaign:
|
||||
campaign.status = status
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id} status to {status}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error updating status: {str(e)}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
180
backend/services/product_marketing/channel_pack.py
Normal file
180
backend/services/product_marketing/channel_pack.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Channel Pack Service
|
||||
Maps channels to templates, copy frameworks, and platform-specific optimizations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio.templates import Platform, TemplateManager
|
||||
from services.image_studio.social_optimizer_service import SocialOptimizerService
|
||||
|
||||
|
||||
class ChannelPackService:
|
||||
"""Service to build channel-specific asset packs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Channel Pack Service."""
|
||||
self.template_manager = TemplateManager()
|
||||
self.social_optimizer = SocialOptimizerService()
|
||||
self.logger = logger
|
||||
logger.info("[Channel Pack] Service initialized")
|
||||
|
||||
def get_channel_pack(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str = "social_post"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get channel-specific pack configuration.
|
||||
|
||||
Args:
|
||||
channel: Target channel (instagram, linkedin, tiktok, facebook, twitter, pinterest, youtube)
|
||||
asset_type: Type of asset (social_post, story, reel, cover, etc.)
|
||||
|
||||
Returns:
|
||||
Channel pack configuration with templates, dimensions, copy frameworks
|
||||
"""
|
||||
try:
|
||||
# Map channel string to Platform enum
|
||||
platform_map = {
|
||||
'instagram': Platform.INSTAGRAM,
|
||||
'linkedin': Platform.LINKEDIN,
|
||||
'tiktok': Platform.TIKTOK,
|
||||
'facebook': Platform.FACEBOOK,
|
||||
'twitter': Platform.TWITTER,
|
||||
'pinterest': Platform.PINTEREST,
|
||||
'youtube': Platform.YOUTUBE,
|
||||
}
|
||||
|
||||
platform = platform_map.get(channel.lower())
|
||||
if not platform:
|
||||
raise ValueError(f"Unsupported channel: {channel}")
|
||||
|
||||
# Get templates for this platform
|
||||
templates = self.template_manager.get_platform_templates().get(platform, [])
|
||||
|
||||
# Get platform formats
|
||||
formats = self.social_optimizer.get_platform_formats(platform)
|
||||
|
||||
# Build channel pack
|
||||
pack = {
|
||||
"channel": channel,
|
||||
"platform": platform.value,
|
||||
"asset_type": asset_type,
|
||||
"templates": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"dimensions": f"{t.aspect_ratio.width}x{t.aspect_ratio.height}",
|
||||
"aspect_ratio": t.aspect_ratio.ratio,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"quality": t.quality,
|
||||
}
|
||||
for t in templates
|
||||
],
|
||||
"formats": formats,
|
||||
"copy_framework": self._get_copy_framework(channel, asset_type),
|
||||
"optimization_tips": self._get_optimization_tips(channel),
|
||||
}
|
||||
|
||||
logger.info(f"[Channel Pack] Built pack for {channel} ({asset_type})")
|
||||
return pack
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Channel Pack] Error building pack: {str(e)}")
|
||||
return {
|
||||
"channel": channel,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _get_copy_framework(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get copy framework for channel and asset type."""
|
||||
frameworks = {
|
||||
"instagram": {
|
||||
"social_post": {
|
||||
"caption_length": "125-150 words optimal",
|
||||
"hashtags": "5-10 relevant hashtags",
|
||||
"cta": "Clear call-to-action in first line",
|
||||
"emoji": "Use 1-3 emojis strategically",
|
||||
},
|
||||
"story": {
|
||||
"text_overlay": "Keep text minimal, readable at small size",
|
||||
"cta": "Swipe-up or link sticker",
|
||||
},
|
||||
},
|
||||
"linkedin": {
|
||||
"social_post": {
|
||||
"length": "150-300 words for maximum engagement",
|
||||
"hashtags": "3-5 professional hashtags",
|
||||
"tone": "Professional, thought-leadership focused",
|
||||
"cta": "Engage with question or call-to-action",
|
||||
},
|
||||
},
|
||||
"tiktok": {
|
||||
"video": {
|
||||
"hook": "Strong hook in first 3 seconds",
|
||||
"caption": "Short, engaging, use trending hashtags",
|
||||
"hashtags": "3-5 trending hashtags",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return frameworks.get(channel, {}).get(asset_type, {})
|
||||
|
||||
def _get_optimization_tips(self, channel: str) -> List[str]:
|
||||
"""Get optimization tips for channel."""
|
||||
tips = {
|
||||
"instagram": [
|
||||
"Use square (1:1) or portrait (4:5) for feed posts",
|
||||
"Include text overlay safe zones (15% top/bottom, 10% left/right)",
|
||||
"Optimize for mobile viewing",
|
||||
],
|
||||
"linkedin": {
|
||||
"Use landscape (1.91:1) for feed posts",
|
||||
"Professional photography style",
|
||||
"Include clear value proposition",
|
||||
},
|
||||
"tiktok": {
|
||||
"Vertical format (9:16) required",
|
||||
"Eye-catching first frame",
|
||||
"Fast-paced, engaging content",
|
||||
},
|
||||
}
|
||||
|
||||
return tips.get(channel, [])
|
||||
|
||||
def build_multi_channel_pack(
|
||||
self,
|
||||
channels: List[str],
|
||||
source_image_base64: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build optimized asset pack for multiple channels from single source.
|
||||
|
||||
Args:
|
||||
channels: List of target channels
|
||||
source_image_base64: Source image to optimize
|
||||
|
||||
Returns:
|
||||
Multi-channel pack with optimized variants
|
||||
"""
|
||||
pack_results = []
|
||||
|
||||
for channel in channels:
|
||||
pack = self.get_channel_pack(channel)
|
||||
pack_results.append({
|
||||
"channel": channel,
|
||||
"pack": pack,
|
||||
})
|
||||
|
||||
return {
|
||||
"source_image": "provided",
|
||||
"channels": pack_results,
|
||||
"total_variants": len(channels),
|
||||
}
|
||||
|
||||
469
backend/services/product_marketing/orchestrator.py
Normal file
469
backend/services/product_marketing/orchestrator.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Product Marketing Orchestrator
|
||||
Main service that orchestrates campaign workflows and asset generation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio import ImageStudioManager, CreateStudioRequest
|
||||
from .prompt_builder import ProductMarketingPromptBuilder
|
||||
from .brand_dna_sync import BrandDNASyncService
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from services.database import SessionLocal
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignAssetNode:
|
||||
"""Represents an asset node in the campaign graph."""
|
||||
asset_id: str
|
||||
asset_type: str # image, video, text, audio
|
||||
channel: str
|
||||
status: str # draft, generating, ready, approved
|
||||
prompt: Optional[str] = None
|
||||
template_id: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
cost_estimate: Optional[float] = None
|
||||
generated_asset_id: Optional[int] = None # Asset Library ID
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignBlueprint:
|
||||
"""Campaign blueprint with phases and asset nodes."""
|
||||
campaign_id: str
|
||||
campaign_name: str
|
||||
goal: str
|
||||
kpi: Optional[str] = None
|
||||
phases: List[Dict[str, Any]] = None # teaser, launch, nurture
|
||||
asset_nodes: List[CampaignAssetNode] = None
|
||||
channels: List[str] = None
|
||||
status: str = "draft" # draft, generating, ready, published
|
||||
|
||||
|
||||
class ProductMarketingOrchestrator:
|
||||
"""Main orchestrator for Product Marketing Suite."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Marketing Orchestrator."""
|
||||
self.image_studio = ImageStudioManager()
|
||||
self.prompt_builder = ProductMarketingPromptBuilder()
|
||||
self.brand_dna_sync = BrandDNASyncService()
|
||||
self.asset_audit = AssetAuditService()
|
||||
self.channel_pack = ChannelPackService()
|
||||
self.logger = logger
|
||||
logger.info("[Product Marketing Orchestrator] Initialized")
|
||||
|
||||
def create_campaign_blueprint(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> CampaignBlueprint:
|
||||
"""
|
||||
Create campaign blueprint from user input and onboarding data.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign information (name, goal, channels, etc.)
|
||||
|
||||
Returns:
|
||||
Campaign blueprint with asset nodes
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
campaign_id = campaign_data.get('campaign_id') or f"campaign_{user_id}_{int(time.time())}"
|
||||
campaign_name = campaign_data.get('campaign_name', 'New Campaign')
|
||||
goal = campaign_data.get('goal', 'product_launch')
|
||||
channels = campaign_data.get('channels', [])
|
||||
|
||||
# Get brand DNA for personalization
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
|
||||
# Build campaign phases
|
||||
phases = self._build_campaign_phases(goal, channels)
|
||||
|
||||
# Generate asset nodes for each phase and channel
|
||||
asset_nodes = []
|
||||
for phase in phases:
|
||||
phase_name = phase.get('name')
|
||||
for channel in channels:
|
||||
# Determine required assets for this phase + channel
|
||||
required_assets = self._get_required_assets(phase_name, channel)
|
||||
|
||||
for asset_type in required_assets:
|
||||
asset_node = CampaignAssetNode(
|
||||
asset_id=f"{campaign_id}_{phase_name}_{channel}_{asset_type}",
|
||||
asset_type=asset_type,
|
||||
channel=channel,
|
||||
status="draft",
|
||||
)
|
||||
asset_nodes.append(asset_node)
|
||||
|
||||
blueprint = CampaignBlueprint(
|
||||
campaign_id=campaign_id,
|
||||
campaign_name=campaign_name,
|
||||
goal=goal,
|
||||
kpi=campaign_data.get('kpi'),
|
||||
phases=phases,
|
||||
asset_nodes=asset_nodes,
|
||||
channels=channels,
|
||||
status="draft",
|
||||
)
|
||||
|
||||
logger.info(f"[Orchestrator] Created blueprint for campaign {campaign_id} with {len(asset_nodes)} assets")
|
||||
return blueprint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error creating blueprint: {str(e)}")
|
||||
raise
|
||||
|
||||
def generate_asset_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint,
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate AI proposals for each asset node in the blueprint.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Dictionary with proposals for each asset node
|
||||
"""
|
||||
try:
|
||||
proposals = {}
|
||||
|
||||
for asset_node in blueprint.asset_nodes:
|
||||
# Build specialized prompt based on asset type and channel
|
||||
if asset_node.asset_type == "image":
|
||||
base_prompt = product_context.get('product_description', 'Product image') if product_context else 'Marketing image'
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_image_prompt(
|
||||
base_prompt=base_prompt,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
asset_type="hero_image",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
# Get channel pack for template recommendations
|
||||
channel_pack = self.channel_pack.get_channel_pack(asset_node.channel)
|
||||
recommended_template = channel_pack.get('templates', [{}])[0] if channel_pack.get('templates') else None
|
||||
|
||||
# Estimate cost
|
||||
cost_estimate = self._estimate_asset_cost("image", asset_node.channel)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"recommended_template": recommended_template.get('id') if recommended_template else None,
|
||||
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": self._generate_concept_summary(enhanced_prompt),
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "text":
|
||||
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
|
||||
base_request=base_request,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
content_type="caption",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"cost_estimate": 0.0, # Text generation cost is minimal
|
||||
"concept_summary": "Marketing copy optimized for channel and persona",
|
||||
}
|
||||
|
||||
logger.info(f"[Orchestrator] Generated {len(proposals)} asset proposals")
|
||||
return {"proposals": proposals, "total_assets": len(proposals)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating proposals: {str(e)}")
|
||||
raise
|
||||
|
||||
async def generate_asset(
|
||||
self,
|
||||
user_id: str,
|
||||
asset_proposal: Dict[str, Any],
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a single asset using Image Studio APIs.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
asset_proposal: Asset proposal from generate_asset_proposals
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Generated asset result
|
||||
"""
|
||||
try:
|
||||
asset_type = asset_proposal.get('asset_type')
|
||||
|
||||
if asset_type == "image":
|
||||
# Build CreateStudioRequest
|
||||
create_request = CreateStudioRequest(
|
||||
prompt=asset_proposal.get('proposed_prompt'),
|
||||
template_id=asset_proposal.get('recommended_template'),
|
||||
provider=asset_proposal.get('recommended_provider', 'wavespeed'),
|
||||
quality="premium",
|
||||
enhance_prompt=True,
|
||||
use_persona=True,
|
||||
num_variations=1,
|
||||
)
|
||||
|
||||
# Generate image using Image Studio
|
||||
result = await self.image_studio.create_image(create_request, user_id=user_id)
|
||||
|
||||
# Asset is automatically tracked in Asset Library via Image Studio
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "image",
|
||||
"result": result,
|
||||
"asset_library_ids": [
|
||||
r.get('asset_id') for r in result.get('results', [])
|
||||
if r.get('asset_id')
|
||||
],
|
||||
}
|
||||
|
||||
elif asset_type == "text":
|
||||
# Import text generation service and tracker
|
||||
import asyncio
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
from services.database import SessionLocal
|
||||
|
||||
# Get enhanced prompt from proposal
|
||||
text_prompt = asset_proposal.get('proposed_prompt', '')
|
||||
channel = asset_proposal.get('channel', 'social')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
# Extract campaign_id - try from asset_proposal first, then from asset_id
|
||||
# asset_id format: {campaign_id}_{phase}_{channel}_{type}
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
if not campaign_id and asset_id and '_' in asset_id:
|
||||
# Try to extract: asset_id might be "campaign_user123_1234567890_teaser_instagram_text"
|
||||
# We need to find where phase_name starts (common phases: teaser, launch, nurture)
|
||||
parts = asset_id.split('_')
|
||||
# Find phase indicator (usually one of: teaser, launch, nurture)
|
||||
phase_indicators = ['teaser', 'launch', 'nurture', 'prelaunch', 'postlaunch']
|
||||
phase_idx = None
|
||||
for i, part in enumerate(parts):
|
||||
if part.lower() in phase_indicators:
|
||||
phase_idx = i
|
||||
break
|
||||
if phase_idx and phase_idx > 0:
|
||||
# Campaign ID is everything before the phase
|
||||
campaign_id = '_'.join(parts[:phase_idx])
|
||||
|
||||
# If still not found, use None (metadata will work without it)
|
||||
if not campaign_id:
|
||||
logger.warning(f"[Orchestrator] Could not extract campaign_id from asset_id: {asset_id}")
|
||||
|
||||
# Build system prompt for marketing copy
|
||||
system_prompt = f"""You are an expert marketing copywriter specializing in {channel} content.
|
||||
Generate compelling, on-brand marketing copy that:
|
||||
- Is optimized for {channel} platform best practices
|
||||
- Includes a clear call-to-action
|
||||
- Uses appropriate tone and style for the platform
|
||||
- Is concise and engaging
|
||||
- Aligns with the product marketing context provided
|
||||
|
||||
Return only the final copy text without explanations or markdown formatting."""
|
||||
|
||||
# Run synchronous llm_text_gen in thread pool
|
||||
logger.info(f"[Orchestrator] Generating text asset for channel: {channel}")
|
||||
generated_text = await asyncio.to_thread(
|
||||
llm_text_gen,
|
||||
prompt=text_prompt,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not generated_text or not generated_text.strip():
|
||||
raise ValueError("Text generation returned empty content")
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
asset_library_id = None
|
||||
try:
|
||||
asset_library_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=generated_text.strip(),
|
||||
source_module="product_marketing",
|
||||
title=f"{channel.title()} Copy: {asset_id.split('_')[-1] if '_' in asset_id else 'Marketing Copy'}",
|
||||
description=f"Marketing copy for {channel} platform generated from campaign proposal",
|
||||
prompt=text_prompt,
|
||||
tags=["product_marketing", channel.lower(), "text", "copy"],
|
||||
asset_metadata={
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
"asset_type": "text",
|
||||
"channel": channel,
|
||||
"concept_summary": asset_proposal.get('concept_summary'),
|
||||
},
|
||||
subdirectory="campaigns",
|
||||
file_extension=".txt"
|
||||
)
|
||||
|
||||
if asset_library_id:
|
||||
logger.info(f"[Orchestrator] ✅ Text asset saved to library: ID={asset_library_id}")
|
||||
else:
|
||||
logger.warning(f"[Orchestrator] ⚠️ Text asset tracking returned None")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Orchestrator] ⚠️ Failed to save text asset to library: {str(save_error)}")
|
||||
# Continue even if save fails - text is still generated
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "text",
|
||||
"content": generated_text.strip(),
|
||||
"asset_library_id": asset_library_id,
|
||||
"channel": channel,
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported asset type: {asset_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating asset: {str(e)}")
|
||||
raise
|
||||
|
||||
def validate_campaign_preflight(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate campaign blueprint against subscription limits before generation.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
|
||||
Returns:
|
||||
Pre-flight validation results
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Count operations needed
|
||||
image_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "image")
|
||||
text_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "text")
|
||||
|
||||
# Estimate total cost
|
||||
total_cost = 0.0
|
||||
for node in blueprint.asset_nodes:
|
||||
if node.cost_estimate:
|
||||
total_cost += node.cost_estimate
|
||||
|
||||
# Validate image generation limits
|
||||
operations = []
|
||||
if image_count > 0:
|
||||
operations.append({
|
||||
'provider': 'stability', # Default provider
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'wavespeed',
|
||||
'operation_type': 'image_generation',
|
||||
})
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations * image_count if operations else []
|
||||
)
|
||||
|
||||
return {
|
||||
"can_proceed": can_proceed,
|
||||
"message": message,
|
||||
"error_details": error_details,
|
||||
"summary": {
|
||||
"total_assets": len(blueprint.asset_nodes),
|
||||
"image_count": image_count,
|
||||
"text_count": text_count,
|
||||
"estimated_cost": total_cost,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error in pre-flight validation: {str(e)}")
|
||||
return {
|
||||
"can_proceed": False,
|
||||
"message": f"Validation error: {str(e)}",
|
||||
"error_details": {},
|
||||
}
|
||||
|
||||
def _build_campaign_phases(
|
||||
self,
|
||||
goal: str,
|
||||
channels: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Build campaign phases based on goal."""
|
||||
if goal == "product_launch":
|
||||
return [
|
||||
{"name": "teaser", "duration_days": 7, "purpose": "Build anticipation"},
|
||||
{"name": "launch", "duration_days": 3, "purpose": "Official launch"},
|
||||
{"name": "nurture", "duration_days": 14, "purpose": "Sustain engagement"},
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{"name": "campaign", "duration_days": 30, "purpose": "Campaign execution"},
|
||||
]
|
||||
|
||||
def _get_required_assets(
|
||||
self,
|
||||
phase: str,
|
||||
channel: str
|
||||
) -> List[str]:
|
||||
"""Get required asset types for phase and channel."""
|
||||
# Default: image for all phases and channels
|
||||
assets = ["image"]
|
||||
|
||||
# Add text/copy for social channels
|
||||
if channel in ["instagram", "linkedin", "facebook", "twitter"]:
|
||||
assets.append("text")
|
||||
|
||||
return assets
|
||||
|
||||
def _estimate_asset_cost(
|
||||
self,
|
||||
asset_type: str,
|
||||
channel: str
|
||||
) -> float:
|
||||
"""Estimate cost for asset generation."""
|
||||
if asset_type == "image":
|
||||
# Premium quality image: ~5-6 credits
|
||||
return 5.0
|
||||
elif asset_type == "text":
|
||||
return 0.0 # Text generation is typically included
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def _generate_concept_summary(self, prompt: str) -> str:
|
||||
"""Generate a brief concept summary from prompt."""
|
||||
# Simple extraction: take first 100 chars
|
||||
return prompt[:100] + "..." if len(prompt) > 100 else prompt
|
||||
|
||||
634
backend/services/product_marketing/product_image_service.py
Normal file
634
backend/services/product_marketing/product_image_service.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
Product Image Service
|
||||
Specialized service for generating product-focused images using AI models.
|
||||
Optimized for e-commerce product photography, product showcases, and product marketing assets.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class ProductImageServiceError(Exception):
|
||||
"""Base exception for Product Image Service errors."""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(ProductImageServiceError):
|
||||
"""Validation error for invalid requests."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageGenerationError(ProductImageServiceError):
|
||||
"""Error during image generation."""
|
||||
pass
|
||||
|
||||
|
||||
class StorageError(ProductImageServiceError):
|
||||
"""Error saving image to storage."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductImageRequest:
|
||||
"""Request for product image generation."""
|
||||
product_name: str
|
||||
product_description: str
|
||||
environment: str = "studio" # studio, lifestyle, outdoor, minimalist, luxury
|
||||
background_style: str = "white" # white, transparent, lifestyle, branded
|
||||
lighting: str = "natural" # natural, studio, dramatic, soft
|
||||
product_variant: Optional[str] = None # color, size, etc.
|
||||
angle: Optional[str] = None # front, side, top, 360, etc.
|
||||
style: str = "photorealistic" # photorealistic, minimalist, luxury, technical
|
||||
resolution: str = "1024x1024" # 1024x1024, 1280x720, etc.
|
||||
num_variations: int = 1
|
||||
brand_colors: Optional[List[str]] = None # Brand color palette
|
||||
additional_context: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductImageResult:
|
||||
"""Result from product image generation."""
|
||||
success: bool
|
||||
product_name: str
|
||||
image_url: Optional[str] = None
|
||||
image_bytes: Optional[bytes] = None
|
||||
asset_id: Optional[int] = None # Asset Library ID
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
generation_time: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ProductImageService:
|
||||
"""Service for generating product marketing images."""
|
||||
|
||||
# Product photography style presets
|
||||
ENVIRONMENT_PROMPTS = {
|
||||
"studio": "professional studio photography, clean white background, even lighting",
|
||||
"lifestyle": "lifestyle photography, product in use, natural environment, relatable setting",
|
||||
"outdoor": "outdoor photography, natural lighting, outdoor environment, dynamic setting",
|
||||
"minimalist": "minimalist product photography, simple composition, clean aesthetic",
|
||||
"luxury": "luxury product photography, premium aesthetic, sophisticated lighting, high-end",
|
||||
}
|
||||
|
||||
BACKGROUND_STYLES = {
|
||||
"white": "clean white background",
|
||||
"transparent": "transparent background, isolated product",
|
||||
"lifestyle": "lifestyle background, contextual environment",
|
||||
"branded": "branded background with brand colors",
|
||||
}
|
||||
|
||||
LIGHTING_STYLES = {
|
||||
"natural": "natural lighting, soft shadows, balanced exposure",
|
||||
"studio": "professional studio lighting, even illumination, no harsh shadows",
|
||||
"dramatic": "dramatic lighting, high contrast, artistic shadows",
|
||||
"soft": "soft diffused lighting, gentle shadows, elegant",
|
||||
}
|
||||
|
||||
# Valid values for request parameters
|
||||
VALID_ENVIRONMENTS = {"studio", "lifestyle", "outdoor", "minimalist", "luxury"}
|
||||
VALID_BACKGROUND_STYLES = {"white", "transparent", "lifestyle", "branded"}
|
||||
VALID_LIGHTING_STYLES = {"natural", "studio", "dramatic", "soft"}
|
||||
VALID_STYLES = {"photorealistic", "minimalist", "luxury", "technical"}
|
||||
VALID_ANGLES = {"front", "side", "top", "360"}
|
||||
|
||||
# Maximum values
|
||||
MAX_RESOLUTION = (4096, 4096)
|
||||
MIN_RESOLUTION = (256, 256)
|
||||
MAX_NUM_VARIATIONS = 10
|
||||
MAX_PRODUCT_NAME_LENGTH = 500
|
||||
MAX_PRODUCT_DESCRIPTION_LENGTH = 2000
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Image Service."""
|
||||
try:
|
||||
self.wavespeed_client = WaveSpeedClient()
|
||||
logger.info("[Product Image Service] Initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Image Service] Failed to initialize WaveSpeed client: {str(e)}")
|
||||
raise ProductImageServiceError(f"Failed to initialize service: {str(e)}") from e
|
||||
|
||||
def validate_request(self, request: ProductImageRequest) -> None:
|
||||
"""
|
||||
Validate product image generation request.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
|
||||
Raises:
|
||||
ValidationError: If request is invalid
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Validate product_name
|
||||
if not request.product_name or not request.product_name.strip():
|
||||
errors.append("Product name is required")
|
||||
elif len(request.product_name) > self.MAX_PRODUCT_NAME_LENGTH:
|
||||
errors.append(f"Product name must be <= {self.MAX_PRODUCT_NAME_LENGTH} characters")
|
||||
|
||||
# Validate product_description
|
||||
if request.product_description and len(request.product_description) > self.MAX_PRODUCT_DESCRIPTION_LENGTH:
|
||||
errors.append(f"Product description must be <= {self.MAX_PRODUCT_DESCRIPTION_LENGTH} characters")
|
||||
|
||||
# Validate environment
|
||||
if request.environment not in self.VALID_ENVIRONMENTS:
|
||||
errors.append(f"Invalid environment: {request.environment}. Valid: {', '.join(self.VALID_ENVIRONMENTS)}")
|
||||
|
||||
# Validate background_style
|
||||
if request.background_style not in self.VALID_BACKGROUND_STYLES:
|
||||
errors.append(f"Invalid background_style: {request.background_style}. Valid: {', '.join(self.VALID_BACKGROUND_STYLES)}")
|
||||
|
||||
# Validate lighting
|
||||
if request.lighting not in self.VALID_LIGHTING_STYLES:
|
||||
errors.append(f"Invalid lighting: {request.lighting}. Valid: {', '.join(self.VALID_LIGHTING_STYLES)}")
|
||||
|
||||
# Validate style
|
||||
if request.style not in self.VALID_STYLES:
|
||||
errors.append(f"Invalid style: {request.style}. Valid: {', '.join(self.VALID_STYLES)}")
|
||||
|
||||
# Validate angle
|
||||
if request.angle and request.angle not in self.VALID_ANGLES:
|
||||
errors.append(f"Invalid angle: {request.angle}. Valid: {', '.join(self.VALID_ANGLES)}")
|
||||
|
||||
# Validate num_variations
|
||||
if request.num_variations < 1:
|
||||
errors.append("num_variations must be >= 1")
|
||||
elif request.num_variations > self.MAX_NUM_VARIATIONS:
|
||||
errors.append(f"num_variations must be <= {self.MAX_NUM_VARIATIONS}")
|
||||
|
||||
# Validate resolution
|
||||
try:
|
||||
width, height = self._parse_resolution(request.resolution)
|
||||
if width < self.MIN_RESOLUTION[0] or height < self.MIN_RESOLUTION[1]:
|
||||
errors.append(f"Resolution must be >= {self.MIN_RESOLUTION[0]}x{self.MIN_RESOLUTION[1]}")
|
||||
if width > self.MAX_RESOLUTION[0] or height > self.MAX_RESOLUTION[1]:
|
||||
errors.append(f"Resolution must be <= {self.MAX_RESOLUTION[0]}x{self.MAX_RESOLUTION[1]}")
|
||||
except Exception as e:
|
||||
errors.append(f"Invalid resolution format: {request.resolution}. Error: {str(e)}")
|
||||
|
||||
if errors:
|
||||
raise ValidationError(f"Validation failed: {'; '.join(errors)}")
|
||||
|
||||
def build_product_prompt(
|
||||
self,
|
||||
request: ProductImageRequest,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build optimized prompt for product image generation.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
brand_context: Optional brand DNA context for personalization
|
||||
|
||||
Returns:
|
||||
Optimized prompt string
|
||||
"""
|
||||
prompt_parts = []
|
||||
|
||||
# Base product description
|
||||
prompt_parts.append(f"Professional product photography of {request.product_name}")
|
||||
if request.product_description:
|
||||
prompt_parts.append(f": {request.product_description}")
|
||||
|
||||
# Product variant
|
||||
if request.product_variant:
|
||||
prompt_parts.append(f", {request.product_variant}")
|
||||
|
||||
# Environment and style
|
||||
env_prompt = self.ENVIRONMENT_PROMPTS.get(request.environment, self.ENVIRONMENT_PROMPTS["studio"])
|
||||
prompt_parts.append(f", {env_prompt}")
|
||||
|
||||
# Background
|
||||
bg_prompt = self.BACKGROUND_STYLES.get(request.background_style, self.BACKGROUND_STYLES["white"])
|
||||
if request.background_style == "branded" and request.brand_colors:
|
||||
bg_prompt += f", using brand colors: {', '.join(request.brand_colors)}"
|
||||
prompt_parts.append(f", {bg_prompt}")
|
||||
|
||||
# Lighting
|
||||
lighting_prompt = self.LIGHTING_STYLES.get(request.lighting, self.LIGHTING_STYLES["natural"])
|
||||
prompt_parts.append(f", {lighting_prompt}")
|
||||
|
||||
# Angle/view
|
||||
if request.angle:
|
||||
angle_map = {
|
||||
"front": "front view, centered composition",
|
||||
"side": "side profile view, showing depth",
|
||||
"top": "top-down view, flat lay style",
|
||||
"360": "3/4 angle view, showing multiple sides",
|
||||
}
|
||||
angle_prompt = angle_map.get(request.angle, request.angle)
|
||||
prompt_parts.append(f", {angle_prompt}")
|
||||
|
||||
# Style
|
||||
style_map = {
|
||||
"photorealistic": "photorealistic, highly detailed, professional photography",
|
||||
"minimalist": "minimalist aesthetic, clean composition, simple and elegant",
|
||||
"luxury": "luxury aesthetic, premium quality, sophisticated and refined",
|
||||
"technical": "technical product photography, detailed features, professional documentation style",
|
||||
}
|
||||
style_prompt = style_map.get(request.style, style_map["photorealistic"])
|
||||
prompt_parts.append(f", {style_prompt}")
|
||||
|
||||
# Additional context
|
||||
if request.additional_context:
|
||||
prompt_parts.append(f", {request.additional_context}")
|
||||
|
||||
# Brand DNA integration (if available)
|
||||
if brand_context:
|
||||
brand_tone = brand_context.get("visual_identity", {}).get("style_guidelines")
|
||||
if brand_tone:
|
||||
prompt_parts.append(f", brand style: {brand_tone}")
|
||||
|
||||
# Quality keywords
|
||||
prompt_parts.append(", high resolution, professional quality, sharp focus, commercial photography")
|
||||
|
||||
full_prompt = " ".join(prompt_parts)
|
||||
logger.debug(f"[Product Image Service] Built prompt: {full_prompt[:200]}...")
|
||||
|
||||
return full_prompt
|
||||
|
||||
def _generate_image_with_retry(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image with retry logic for transient failures.
|
||||
|
||||
Args:
|
||||
model: Model to use
|
||||
prompt: Generation prompt
|
||||
width: Image width
|
||||
height: Image height
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Delay between retries in seconds
|
||||
|
||||
Returns:
|
||||
Generated image bytes
|
||||
|
||||
Raises:
|
||||
ImageGenerationError: If generation fails after retries
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(f"[Product Image Service] Image generation attempt {attempt + 1}/{max_retries}")
|
||||
|
||||
image_bytes = self.wavespeed_client.generate_image(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
enable_sync_mode=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Image generation returned empty result")
|
||||
|
||||
if len(image_bytes) < 100: # Sanity check: image should be at least 100 bytes
|
||||
raise ValueError(f"Generated image too small: {len(image_bytes)} bytes")
|
||||
|
||||
logger.info(f"[Product Image Service] ✅ Image generated successfully: {len(image_bytes)} bytes")
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_msg = str(e)
|
||||
logger.warning(f"[Product Image Service] Attempt {attempt + 1} failed: {error_msg}")
|
||||
|
||||
# Don't retry on validation errors or client errors (4xx)
|
||||
if "4" in error_msg or "validation" in error_msg.lower() or "invalid" in error_msg.lower():
|
||||
logger.error(f"[Product Image Service] Non-retryable error: {error_msg}")
|
||||
raise ImageGenerationError(f"Image generation failed: {error_msg}") from e
|
||||
|
||||
# Retry on transient errors
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"[Product Image Service] Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 1.5 # Exponential backoff
|
||||
else:
|
||||
logger.error(f"[Product Image Service] All retry attempts failed")
|
||||
|
||||
raise ImageGenerationError(f"Image generation failed after {max_retries} attempts: {str(last_error)}") from last_error
|
||||
|
||||
async def generate_product_image(
|
||||
self,
|
||||
request: ProductImageRequest,
|
||||
user_id: str,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> ProductImageResult:
|
||||
"""
|
||||
Generate product image using AI models.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
user_id: User ID for tracking
|
||||
brand_context: Optional brand DNA for personalization
|
||||
|
||||
Returns:
|
||||
ProductImageResult with generated image
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Validate request
|
||||
self.validate_request(request)
|
||||
|
||||
# Validate user_id
|
||||
if not user_id or not user_id.strip():
|
||||
raise ValidationError("user_id is required")
|
||||
|
||||
# Build optimized prompt
|
||||
prompt = self.build_product_prompt(request, brand_context)
|
||||
|
||||
# Parse resolution
|
||||
width, height = self._parse_resolution(request.resolution)
|
||||
|
||||
# Select model based on style/quality needs
|
||||
model = "ideogram-v3-turbo" # Default to Ideogram V3 for photorealistic products
|
||||
if request.style == "minimalist":
|
||||
model = "ideogram-v3-turbo" # Still use Ideogram for quality
|
||||
elif request.style == "technical":
|
||||
model = "ideogram-v3-turbo"
|
||||
|
||||
logger.info(f"[Product Image Service] Generating product image for '{request.product_name}' using {model}")
|
||||
|
||||
# Generate image using WaveSpeed with retry logic
|
||||
try:
|
||||
image_bytes = self._generate_image_with_retry(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
max_retries=3,
|
||||
retry_delay=2.0
|
||||
)
|
||||
except ImageGenerationError as e:
|
||||
logger.error(f"[Product Image Service] Image generation failed: {str(e)}")
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name,
|
||||
error=f"Image generation failed: {str(e)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
# Save image to file and Asset Library
|
||||
asset_id = None
|
||||
image_url = None
|
||||
|
||||
try:
|
||||
asset_id, image_url = self._save_product_image(
|
||||
image_bytes=image_bytes,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
start_time=start_time
|
||||
)
|
||||
except StorageError as storage_error:
|
||||
logger.error(f"[Product Image Service] Storage failed: {str(storage_error)}", exc_info=True)
|
||||
# Continue with generation result even if storage fails
|
||||
# The image_bytes is still available in the result
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Product Image Service] Unexpected error saving image: {str(save_error)}", exc_info=True)
|
||||
# Continue even if save fails
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
return ProductImageResult(
|
||||
success=True,
|
||||
product_name=request.product_name,
|
||||
image_url=image_url,
|
||||
image_bytes=image_bytes,
|
||||
asset_id=asset_id,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
cost=0.10,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"[Product Image Service] Validation error: {str(ve)}")
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
|
||||
error=f"Validation error: {str(ve)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Image Service] ❌ Unexpected error generating product image: {str(e)}", exc_info=True)
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
|
||||
error=f"Unexpected error: {str(e)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
def _save_product_image(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
request: ProductImageRequest,
|
||||
user_id: str,
|
||||
prompt: str,
|
||||
model: str,
|
||||
start_time: float
|
||||
) -> tuple[Optional[int], Optional[str]]:
|
||||
"""
|
||||
Save product image to disk and Asset Library.
|
||||
|
||||
Args:
|
||||
image_bytes: Generated image bytes
|
||||
request: Product image generation request
|
||||
user_id: User ID
|
||||
prompt: Generation prompt
|
||||
model: Model used
|
||||
start_time: Generation start time
|
||||
|
||||
Returns:
|
||||
Tuple of (asset_id, image_url)
|
||||
|
||||
Raises:
|
||||
StorageError: If saving fails
|
||||
"""
|
||||
db = None
|
||||
asset_id = None
|
||||
image_url = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Generate filename
|
||||
product_hash = hashlib.md5(request.product_name.encode()).hexdigest()[:8]
|
||||
timestamp = int(start_time)
|
||||
filename = f"product_{product_hash}_{timestamp}.png"
|
||||
|
||||
# Determine base directory and create product_images folder
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
product_images_dir = base_dir / "product_images"
|
||||
|
||||
# Create directory with error handling
|
||||
try:
|
||||
product_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as pe:
|
||||
raise StorageError(f"Permission denied creating directory: {str(pe)}") from pe
|
||||
except OSError as oe:
|
||||
raise StorageError(f"Failed to create directory: {str(oe)}") from oe
|
||||
|
||||
# Check disk space (rough estimate - at least 10MB free)
|
||||
try:
|
||||
stat = shutil.disk_usage(product_images_dir)
|
||||
free_space_mb = stat.free / (1024 * 1024)
|
||||
if free_space_mb < 10:
|
||||
raise StorageError(f"Insufficient disk space: {free_space_mb:.1f}MB free (need at least 10MB)")
|
||||
except OSError as oe:
|
||||
logger.warning(f"[Product Image Service] Could not check disk space: {str(oe)}")
|
||||
|
||||
# Save image to disk
|
||||
image_path = product_images_dir / filename
|
||||
try:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
# Verify file was written
|
||||
if not image_path.exists() or image_path.stat().st_size == 0:
|
||||
raise StorageError("Image file was not written correctly")
|
||||
except PermissionError as pe:
|
||||
raise StorageError(f"Permission denied writing file: {str(pe)}") from pe
|
||||
except OSError as oe:
|
||||
raise StorageError(f"Failed to write file: {str(oe)}") from oe
|
||||
|
||||
file_size = len(image_bytes)
|
||||
image_url = f"/api/product-marketing/images/{filename}"
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="product_marketing",
|
||||
filename=filename,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=file_size,
|
||||
mime_type="image/png",
|
||||
title=f"{request.product_name} - Product Image",
|
||||
description=f"Product image: {request.product_description or request.product_name}",
|
||||
prompt=prompt,
|
||||
tags=["product_marketing", "product_image", request.environment, request.style],
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
cost=0.10, # Estimated cost for Ideogram V3
|
||||
asset_metadata={
|
||||
"product_name": request.product_name,
|
||||
"product_description": request.product_description,
|
||||
"environment": request.environment,
|
||||
"background_style": request.background_style,
|
||||
"lighting": request.lighting,
|
||||
"style": request.style,
|
||||
"variant": request.product_variant,
|
||||
"angle": request.angle,
|
||||
},
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"[Product Image Service] ✅ Saved product image to Asset Library: ID={asset_id}")
|
||||
else:
|
||||
logger.warning(f"[Product Image Service] ⚠️ Asset Library save returned None (file saved but not tracked)")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[Product Image Service] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
|
||||
# File is saved, but database tracking failed
|
||||
# This is not critical - image is still accessible
|
||||
raise StorageError(f"Failed to save to Asset Library: {str(db_error)}") from db_error
|
||||
finally:
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
logger.warning(f"[Product Image Service] Error closing database: {str(close_error)}")
|
||||
|
||||
return (asset_id, image_url)
|
||||
|
||||
except StorageError:
|
||||
# Clean up partial files on storage error
|
||||
if image_path and image_path.exists():
|
||||
try:
|
||||
image_path.unlink()
|
||||
logger.info(f"[Product Image Service] Cleaned up partial file: {image_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[Product Image Service] Failed to cleanup partial file: {str(cleanup_error)}")
|
||||
raise
|
||||
|
||||
def _parse_resolution(self, resolution: str) -> tuple[int, int]:
|
||||
"""
|
||||
Parse resolution string to width, height tuple.
|
||||
|
||||
Args:
|
||||
resolution: Resolution string (e.g., "1024x1024", "square", "landscape")
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
try:
|
||||
resolution = resolution.strip().lower()
|
||||
|
||||
if "x" in resolution:
|
||||
parts = resolution.split("x")
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid resolution format: {resolution}")
|
||||
width = int(parts[0].strip())
|
||||
height = int(parts[1].strip())
|
||||
|
||||
# Validate resolution values
|
||||
if width < 1 or height < 1:
|
||||
raise ValueError(f"Resolution dimensions must be positive: {width}x{height}")
|
||||
|
||||
return (width, height)
|
||||
elif resolution == "square":
|
||||
return (1024, 1024)
|
||||
elif resolution == "landscape":
|
||||
return (1280, 720)
|
||||
elif resolution == "portrait":
|
||||
return (720, 1280)
|
||||
else:
|
||||
# Try to parse as single number (assume square)
|
||||
try:
|
||||
size = int(resolution)
|
||||
return (size, size)
|
||||
except ValueError:
|
||||
# Default to square
|
||||
logger.warning(f"[Product Image Service] Could not parse resolution '{resolution}', defaulting to 1024x1024")
|
||||
return (1024, 1024)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Product Image Service] Error parsing resolution '{resolution}': {str(e)}, defaulting to 1024x1024")
|
||||
return (1024, 1024)
|
||||
|
||||
def estimate_cost(self, request: ProductImageRequest) -> float:
|
||||
"""Estimate cost for product image generation."""
|
||||
# Ideogram V3 Turbo: ~$0.10 per image
|
||||
# Multiply by number of variations
|
||||
base_cost = 0.10
|
||||
return base_cost * request.num_variations
|
||||
|
||||
304
backend/services/product_marketing/prompt_builder.py
Normal file
304
backend/services/product_marketing/prompt_builder.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Product Marketing Prompt Builder
|
||||
Extends AIPromptOptimizer with marketing-specific prompt enhancement.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.ai_prompt_optimizer import AIPromptOptimizer
|
||||
from services.onboarding import OnboardingDataService
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from services.persona_data_service import PersonaDataService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class ProductMarketingPromptBuilder(AIPromptOptimizer):
|
||||
"""Specialized prompt builder for marketing assets with onboarding data integration."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Marketing Prompt Builder."""
|
||||
super().__init__()
|
||||
self.onboarding_data_service = OnboardingDataService()
|
||||
self.logger = logger
|
||||
logger.info("[Product Marketing Prompt Builder] Initialized")
|
||||
|
||||
def build_marketing_image_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
asset_type: str = "hero_image",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing image prompt with brand DNA and persona data.
|
||||
|
||||
Args:
|
||||
base_prompt: Base product description or image concept
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, tiktok, etc.)
|
||||
asset_type: Type of asset (hero_image, product_photo, lifestyle, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with brand DNA, persona style, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build prompt layers
|
||||
enhanced_prompt = base_prompt
|
||||
|
||||
# Layer 1: Brand DNA (from website_analysis)
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
brand_analysis = website_analysis.get('brand_analysis', {})
|
||||
style_guidelines = website_analysis.get('style_guidelines', {})
|
||||
|
||||
# Add brand tone and style
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
brand_enhancement = f", {tone} tone, {voice} voice"
|
||||
|
||||
# Add target audience context
|
||||
demographics = target_audience.get('demographics', [])
|
||||
if demographics:
|
||||
audience_context = f", targeting {', '.join(demographics[:2])}"
|
||||
enhanced_prompt += audience_context
|
||||
|
||||
# Add brand visual identity if available
|
||||
if brand_analysis:
|
||||
color_palette = brand_analysis.get('color_palette', [])
|
||||
if color_palette:
|
||||
colors = ', '.join(color_palette[:3])
|
||||
enhanced_prompt += f", brand colors: {colors}"
|
||||
|
||||
# Layer 2: Persona Visual Style (from persona_data)
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
archetype = core_persona.get('archetype', '')
|
||||
if persona_name:
|
||||
enhanced_prompt += f", {persona_name} style"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
visual_identity = platform_persona.get('visual_identity', {})
|
||||
if visual_identity:
|
||||
aesthetic = visual_identity.get('aesthetic_preferences', '')
|
||||
if aesthetic:
|
||||
enhanced_prompt += f", {aesthetic} aesthetic"
|
||||
|
||||
# Layer 3: Channel Optimization
|
||||
channel_enhancements = {
|
||||
'instagram': ', Instagram-optimized composition, vibrant colors, engaging visual',
|
||||
'linkedin': ', professional photography, clean composition, business-focused',
|
||||
'tiktok': ', dynamic composition, eye-catching, vertical format optimized',
|
||||
'facebook': ', social media optimized, engaging, shareable visual',
|
||||
'twitter': ', Twitter card optimized, clear focal point, readable at small size',
|
||||
'pinterest': ', Pinterest-optimized, vertical format, detailed and informative',
|
||||
}
|
||||
|
||||
if channel and channel.lower() in channel_enhancements:
|
||||
enhanced_prompt += channel_enhancements[channel.lower()]
|
||||
|
||||
# Layer 4: Asset Type Specific
|
||||
asset_type_enhancements = {
|
||||
'hero_image': ', hero image style, prominent product placement, professional photography',
|
||||
'product_photo': ', product photography, clean background, detailed product showcase',
|
||||
'lifestyle': ', lifestyle photography, natural setting, authentic scene',
|
||||
'social_post': ', social media post, engaging composition, optimized for engagement',
|
||||
}
|
||||
|
||||
if asset_type in asset_type_enhancements:
|
||||
enhanced_prompt += asset_type_enhancements[asset_type]
|
||||
|
||||
# Layer 5: Competitive Differentiation
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
# Extract unique positioning from competitor analysis
|
||||
enhanced_prompt += ", unique positioning, differentiated visual style"
|
||||
|
||||
# Layer 6: Quality Descriptors
|
||||
enhanced_prompt += ", professional photography, high quality, detailed, sharp focus, natural lighting"
|
||||
|
||||
# Layer 7: Marketing Context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f", {marketing_goal} focused"
|
||||
|
||||
logger.info(f"[Marketing Prompt] Enhanced prompt for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Marketing Prompt] Error building prompt: {str(e)}")
|
||||
# Return base prompt with minimal enhancement if error
|
||||
return f"{base_prompt}, professional photography, high quality"
|
||||
|
||||
def build_marketing_copy_prompt(
|
||||
self,
|
||||
base_request: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
content_type: str = "caption",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing copy prompt with persona linguistic fingerprint.
|
||||
|
||||
Args:
|
||||
base_request: Base content request (e.g., "Write Instagram caption for product launch")
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, etc.)
|
||||
content_type: Type of content (caption, cta, email, ad_copy, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with persona style, brand voice, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build enhanced prompt
|
||||
enhanced_prompt = base_request
|
||||
|
||||
# Add persona linguistic fingerprint
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
linguistic_fingerprint = core_persona.get('linguistic_fingerprint', {})
|
||||
|
||||
if persona_name:
|
||||
enhanced_prompt += f"\n\nFollow {persona_name} persona style:"
|
||||
|
||||
if linguistic_fingerprint:
|
||||
sentence_metrics = linguistic_fingerprint.get('sentence_metrics', {})
|
||||
lexical_features = linguistic_fingerprint.get('lexical_features', {})
|
||||
|
||||
if sentence_metrics:
|
||||
avg_length = sentence_metrics.get('average_sentence_length_words', '')
|
||||
if avg_length:
|
||||
enhanced_prompt += f"\n- Average sentence length: {avg_length} words"
|
||||
|
||||
if lexical_features:
|
||||
go_to_words = lexical_features.get('go_to_words', [])
|
||||
avoid_words = lexical_features.get('avoid_words', [])
|
||||
vocabulary_level = lexical_features.get('vocabulary_level', '')
|
||||
|
||||
if go_to_words:
|
||||
enhanced_prompt += f"\n- Use these words: {', '.join(go_to_words[:5])}"
|
||||
if avoid_words:
|
||||
enhanced_prompt += f"\n- Avoid these words: {', '.join(avoid_words[:5])}"
|
||||
if vocabulary_level:
|
||||
enhanced_prompt += f"\n- Vocabulary level: {vocabulary_level}"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
content_format_rules = platform_persona.get('content_format_rules', {})
|
||||
engagement_patterns = platform_persona.get('engagement_patterns', {})
|
||||
|
||||
if content_format_rules:
|
||||
char_limit = content_format_rules.get('character_limit', '')
|
||||
hashtag_strategy = content_format_rules.get('hashtag_strategy', '')
|
||||
|
||||
if char_limit:
|
||||
enhanced_prompt += f"\n- Character limit: {char_limit}"
|
||||
if hashtag_strategy:
|
||||
enhanced_prompt += f"\n- Hashtag strategy: {hashtag_strategy}"
|
||||
|
||||
# Add brand voice
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
enhanced_prompt += f"\n- Brand tone: {tone}, Brand voice: {voice}"
|
||||
|
||||
demographics = target_audience.get('demographics', [])
|
||||
expertise_level = target_audience.get('expertise_level', 'intermediate')
|
||||
if demographics:
|
||||
enhanced_prompt += f"\n- Target audience: {', '.join(demographics[:2])}, {expertise_level} level"
|
||||
|
||||
# Add competitive positioning
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
enhanced_prompt += "\n- Differentiate from competitors, highlight unique value propositions"
|
||||
|
||||
# Add marketing context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f"\n- Marketing goal: {marketing_goal}"
|
||||
|
||||
logger.info(f"[Marketing Copy Prompt] Enhanced for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Marketing Copy Prompt] Error building prompt: {str(e)}")
|
||||
return base_request
|
||||
|
||||
def optimize_marketing_prompt(
|
||||
self,
|
||||
prompt_type: str,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Main entry point for marketing prompt optimization.
|
||||
|
||||
Args:
|
||||
prompt_type: Type of prompt (image, copy, video_script, etc.)
|
||||
base_prompt: Base prompt to enhance
|
||||
user_id: User ID for personalization
|
||||
context: Additional context (channel, asset_type, product_context, etc.)
|
||||
|
||||
Returns:
|
||||
Optimized marketing prompt
|
||||
"""
|
||||
context = context or {}
|
||||
channel = context.get('channel')
|
||||
asset_type = context.get('asset_type', 'hero_image')
|
||||
content_type = context.get('content_type', 'caption')
|
||||
product_context = context.get('product_context')
|
||||
|
||||
if prompt_type == 'image':
|
||||
return self.build_marketing_image_prompt(
|
||||
base_prompt, user_id, channel, asset_type, product_context
|
||||
)
|
||||
elif prompt_type in ['copy', 'caption', 'cta', 'email', 'ad_copy']:
|
||||
return self.build_marketing_copy_prompt(
|
||||
base_prompt, user_id, channel, content_type, product_context
|
||||
)
|
||||
else:
|
||||
# Default: minimal enhancement
|
||||
return f"{base_prompt}, professional quality, marketing optimized"
|
||||
|
||||
Reference in New Issue
Block a user