AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.
This commit is contained in:
@@ -448,7 +448,7 @@ Format as structured JSON with detailed assessment and optimization guidance.
|
||||
}
|
||||
}
|
||||
|
||||
async def _execute_ai_call(self, service_type: AIServiceType, prompt: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _execute_ai_call(self, service_type: AIServiceType, prompt: str, schema: Dict[str, Any], user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute AI call with comprehensive error handling and monitoring.
|
||||
|
||||
@@ -456,26 +456,35 @@ Format as structured JSON with detailed assessment and optimization guidance.
|
||||
service_type: Type of AI service being called
|
||||
prompt: The prompt to send to AI
|
||||
schema: Expected response schema
|
||||
user_id: Clerk user ID for subscription checking (REQUIRED - no fallback)
|
||||
|
||||
Returns:
|
||||
Dictionary with AI response or error information
|
||||
|
||||
Raises:
|
||||
RuntimeError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. All AI calls must be authenticated.")
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
success = False
|
||||
error_message = None
|
||||
|
||||
try:
|
||||
logger.info(f"🤖 Executing AI call for {service_type.value}")
|
||||
logger.info(f"🤖 Executing AI call for {service_type.value}, user_id={user_id}")
|
||||
|
||||
# Emit educational content for frontend
|
||||
await self._emit_educational_content(service_type, "start")
|
||||
|
||||
# Execute the AI call
|
||||
# Execute the AI call through llm_text_gen for subscription checks
|
||||
# Use llm_text_gen which has subscription checks and usage tracking
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
self._call_gemini_structured,
|
||||
self._call_llm_with_checks,
|
||||
prompt,
|
||||
schema,
|
||||
user_id,
|
||||
),
|
||||
timeout=self.config['timeout_seconds']
|
||||
)
|
||||
@@ -531,9 +540,48 @@ Format as structured JSON with detailed assessment and optimization guidance.
|
||||
"success": False
|
||||
}
|
||||
|
||||
def _call_llm_with_checks(self, prompt: str, schema: Dict[str, Any], user_id: str):
|
||||
"""
|
||||
Call LLM through main_text_generation with subscription checks.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to AI
|
||||
schema: Expected response schema
|
||||
user_id: Clerk user ID for subscription checking (required)
|
||||
|
||||
Returns:
|
||||
Dictionary with AI response
|
||||
"""
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking")
|
||||
|
||||
# Use llm_text_gen which has subscription checks and usage tracking
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
logger.info(f"[AIServiceManager] Calling llm_text_gen with user_id={user_id} for subscription checks")
|
||||
|
||||
# Call through main_text_generation for subscription checks
|
||||
result = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
user_id=user_id # Pass user_id for subscription checks
|
||||
)
|
||||
|
||||
# llm_text_gen returns string or dict, ensure we return dict
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"[AIServiceManager] Failed to parse JSON from llm_text_gen response")
|
||||
return {"error": "Failed to parse AI response", "raw_response": result}
|
||||
|
||||
return result if isinstance(result, dict) else {"data": result}
|
||||
|
||||
def _call_gemini_structured(self, prompt: str, schema: Dict[str, Any]):
|
||||
"""Call gemini structured JSON with flexible signature support.
|
||||
Tries extended signature first; falls back to minimal signature to avoid TypeError.
|
||||
"""
|
||||
Call gemini structured JSON directly (backward compatibility only).
|
||||
|
||||
⚠️ WARNING: This bypasses subscription checks. Use _call_llm_with_checks() instead.
|
||||
"""
|
||||
try:
|
||||
# Attempt extended signature (temperature/top_p/top_k/max_tokens/system_prompt)
|
||||
@@ -550,9 +598,25 @@ Format as structured JSON with detailed assessment and optimization guidance.
|
||||
logger.debug("Falling back to base gemini provider signature (prompt, schema)")
|
||||
return _gemini_fn(prompt, schema)
|
||||
|
||||
async def execute_structured_json_call(self, service_type: AIServiceType, prompt: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Public wrapper to execute a structured JSON AI call with a provided schema."""
|
||||
return await self._execute_ai_call(service_type, prompt, schema)
|
||||
async def execute_structured_json_call(self, service_type: AIServiceType, prompt: str, schema: Dict[str, Any], user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Public wrapper to execute a structured JSON AI call with a provided schema.
|
||||
|
||||
Args:
|
||||
service_type: Type of AI service being called
|
||||
prompt: The prompt to send to AI
|
||||
schema: Expected response schema
|
||||
user_id: Clerk user ID for subscription checking (REQUIRED - no fallback)
|
||||
|
||||
Returns:
|
||||
Dictionary with AI response or error information
|
||||
|
||||
Raises:
|
||||
RuntimeError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. All AI calls must be authenticated.")
|
||||
return await self._execute_ai_call(service_type, prompt, schema, user_id=user_id)
|
||||
|
||||
async def generate_content_gap_analysis(self, analysis_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -35,7 +35,7 @@ blog_writer/
|
||||
- Delegates to specialized modules for specific functionality
|
||||
|
||||
### Research Module (`research/`)
|
||||
- **`ResearchService`**: Orchestrates comprehensive research using Google Search grounding
|
||||
- **`ResearchService`**: Orchestrates comprehensive research using Exa neural search (currently Exa-only for testing)
|
||||
- **`KeywordAnalyzer`**: AI-powered keyword analysis and extraction
|
||||
- **`CompetitorAnalyzer`**: Competitor intelligence and market analysis
|
||||
- **`ContentAngleGenerator`**: Strategic content angle discovery
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
Research module for AI Blog Writer.
|
||||
|
||||
This module handles all research-related functionality including:
|
||||
- Google Search grounding integration
|
||||
- Exa neural search integration (primary provider for testing)
|
||||
- Keyword analysis and competitor research
|
||||
- Content angle discovery
|
||||
- Research caching and optimization
|
||||
|
||||
Note: Currently Exa-only for testing. Google Search grounding code preserved for future use.
|
||||
"""
|
||||
|
||||
from .research_service import ResearchService
|
||||
|
||||
@@ -29,10 +29,15 @@ class ExaResearchProvider(BaseProvider):
|
||||
# Determine category: use exa_category if set, otherwise map from source_types
|
||||
category = config.exa_category if config.exa_category else self._map_source_type_to_category(config.source_types)
|
||||
|
||||
# Use exa_num_results if available, otherwise fallback to max_sources
|
||||
num_results = config.exa_num_results if hasattr(config, 'exa_num_results') and config.exa_num_results else min(config.max_sources, 25)
|
||||
# Cap at 100 as per Exa API limits
|
||||
num_results = min(num_results, 100)
|
||||
|
||||
# Build search kwargs - use correct Exa API format
|
||||
search_kwargs = {
|
||||
'type': config.exa_search_type or "auto",
|
||||
'num_results': min(config.max_sources, 25),
|
||||
'num_results': num_results,
|
||||
'text': {'max_characters': 1000},
|
||||
'summary': {'query': f"Key insights about {topic}"},
|
||||
'highlights': {
|
||||
@@ -49,37 +54,133 @@ class ExaResearchProvider(BaseProvider):
|
||||
if config.exa_exclude_domains:
|
||||
search_kwargs['exclude_domains'] = config.exa_exclude_domains
|
||||
|
||||
# Add date filters if configured
|
||||
if hasattr(config, 'exa_date_filter') and config.exa_date_filter:
|
||||
search_kwargs['start_published_date'] = config.exa_date_filter
|
||||
if hasattr(config, 'exa_end_published_date') and config.exa_end_published_date:
|
||||
search_kwargs['end_published_date'] = config.exa_end_published_date
|
||||
if hasattr(config, 'exa_start_crawl_date') and config.exa_start_crawl_date:
|
||||
search_kwargs['start_crawl_date'] = config.exa_start_crawl_date
|
||||
if hasattr(config, 'exa_end_crawl_date') and config.exa_end_crawl_date:
|
||||
search_kwargs['end_crawl_date'] = config.exa_end_crawl_date
|
||||
|
||||
# Add context if configured (supports boolean or object with maxCharacters)
|
||||
if hasattr(config, 'exa_context') and config.exa_context is not None:
|
||||
if config.exa_context:
|
||||
if hasattr(config, 'exa_context_max_characters') and config.exa_context_max_characters:
|
||||
search_kwargs['context'] = {'maxCharacters': config.exa_context_max_characters}
|
||||
else:
|
||||
search_kwargs['context'] = True
|
||||
# If False, don't add context parameter (default behavior)
|
||||
|
||||
# Add text filters if configured
|
||||
if hasattr(config, 'exa_include_text') and config.exa_include_text:
|
||||
search_kwargs['include_text'] = config.exa_include_text
|
||||
if hasattr(config, 'exa_exclude_text') and config.exa_exclude_text:
|
||||
search_kwargs['exclude_text'] = config.exa_exclude_text
|
||||
|
||||
logger.info(f"[Exa Research] Executing search: {query}")
|
||||
|
||||
# Execute Exa search - pass contents parameters directly, not nested
|
||||
try:
|
||||
# Build optional parameters dict
|
||||
optional_params = {}
|
||||
if category:
|
||||
optional_params['category'] = category
|
||||
if config.exa_include_domains:
|
||||
optional_params['include_domains'] = config.exa_include_domains
|
||||
if config.exa_exclude_domains:
|
||||
optional_params['exclude_domains'] = config.exa_exclude_domains
|
||||
if hasattr(config, 'exa_date_filter') and config.exa_date_filter:
|
||||
optional_params['start_published_date'] = config.exa_date_filter
|
||||
if hasattr(config, 'exa_end_published_date') and config.exa_end_published_date:
|
||||
optional_params['end_published_date'] = config.exa_end_published_date
|
||||
if hasattr(config, 'exa_start_crawl_date') and config.exa_start_crawl_date:
|
||||
optional_params['start_crawl_date'] = config.exa_start_crawl_date
|
||||
if hasattr(config, 'exa_end_crawl_date') and config.exa_end_crawl_date:
|
||||
optional_params['end_crawl_date'] = config.exa_end_crawl_date
|
||||
# Add context if configured (supports boolean or object with maxCharacters)
|
||||
if hasattr(config, 'exa_context') and config.exa_context:
|
||||
if hasattr(config, 'exa_context_max_characters') and config.exa_context_max_characters:
|
||||
optional_params['context'] = {'maxCharacters': config.exa_context_max_characters}
|
||||
else:
|
||||
optional_params['context'] = True
|
||||
|
||||
# Add text filters if configured
|
||||
if hasattr(config, 'exa_include_text') and config.exa_include_text:
|
||||
optional_params['include_text'] = config.exa_include_text
|
||||
if hasattr(config, 'exa_exclude_text') and config.exa_exclude_text:
|
||||
optional_params['exclude_text'] = config.exa_exclude_text
|
||||
|
||||
# Add additional_queries for Deep search (only works with type="deep")
|
||||
if config.exa_search_type == 'deep' and hasattr(config, 'exa_additional_queries') and config.exa_additional_queries:
|
||||
optional_params['additional_queries'] = config.exa_additional_queries
|
||||
|
||||
# Build contents parameters (text, summary, highlights)
|
||||
text_params = {}
|
||||
if hasattr(config, 'exa_text_max_characters') and config.exa_text_max_characters:
|
||||
text_params['max_characters'] = config.exa_text_max_characters
|
||||
else:
|
||||
text_params['max_characters'] = 1000 # Default
|
||||
|
||||
summary_params = {}
|
||||
if hasattr(config, 'exa_summary_query') and config.exa_summary_query:
|
||||
summary_params['query'] = config.exa_summary_query
|
||||
else:
|
||||
summary_params['query'] = f"Key insights about {topic}" # Default
|
||||
|
||||
highlights_params = {}
|
||||
if hasattr(config, 'exa_highlights') and config.exa_highlights:
|
||||
if hasattr(config, 'exa_highlights_num_sentences') and config.exa_highlights_num_sentences:
|
||||
highlights_params['num_sentences'] = config.exa_highlights_num_sentences
|
||||
else:
|
||||
highlights_params['num_sentences'] = 2 # Default
|
||||
|
||||
if hasattr(config, 'exa_highlights_per_url') and config.exa_highlights_per_url:
|
||||
highlights_params['highlights_per_url'] = config.exa_highlights_per_url
|
||||
else:
|
||||
highlights_params['highlights_per_url'] = 3 # Default
|
||||
|
||||
results = self.exa.search_and_contents(
|
||||
query,
|
||||
text={'max_characters': 1000},
|
||||
summary={'query': f"Key insights about {topic}"},
|
||||
highlights={'num_sentences': 2, 'highlights_per_url': 3},
|
||||
text=text_params,
|
||||
summary=summary_params,
|
||||
highlights=highlights_params if highlights_params else None,
|
||||
type=config.exa_search_type or "auto",
|
||||
num_results=min(config.max_sources, 25),
|
||||
**({k: v for k, v in {
|
||||
'category': category,
|
||||
'include_domains': config.exa_include_domains,
|
||||
'exclude_domains': config.exa_exclude_domains
|
||||
}.items() if v})
|
||||
num_results=num_results,
|
||||
**optional_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Exa Research] API call failed: {e}")
|
||||
# Try simpler call without contents if the above fails
|
||||
try:
|
||||
logger.info("[Exa Research] Retrying with simplified parameters")
|
||||
# Build minimal optional parameters for retry
|
||||
optional_params = {}
|
||||
if category:
|
||||
optional_params['category'] = category
|
||||
if config.exa_include_domains:
|
||||
optional_params['include_domains'] = config.exa_include_domains
|
||||
if config.exa_exclude_domains:
|
||||
optional_params['exclude_domains'] = config.exa_exclude_domains
|
||||
if hasattr(config, 'exa_date_filter') and config.exa_date_filter:
|
||||
optional_params['start_published_date'] = config.exa_date_filter
|
||||
if hasattr(config, 'exa_end_published_date') and config.exa_end_published_date:
|
||||
optional_params['end_published_date'] = config.exa_end_published_date
|
||||
if hasattr(config, 'exa_start_crawl_date') and config.exa_start_crawl_date:
|
||||
optional_params['start_crawl_date'] = config.exa_start_crawl_date
|
||||
if hasattr(config, 'exa_end_crawl_date') and config.exa_end_crawl_date:
|
||||
optional_params['end_crawl_date'] = config.exa_end_crawl_date
|
||||
|
||||
# Add additional_queries for Deep search (only works with type="deep")
|
||||
if config.exa_search_type == 'deep' and hasattr(config, 'exa_additional_queries') and config.exa_additional_queries:
|
||||
optional_params['additional_queries'] = config.exa_additional_queries
|
||||
|
||||
results = self.exa.search_and_contents(
|
||||
query,
|
||||
type=config.exa_search_type or "auto",
|
||||
num_results=min(config.max_sources, 25),
|
||||
**({k: v for k, v in {
|
||||
'category': category,
|
||||
'include_domains': config.exa_include_domains,
|
||||
'exclude_domains': config.exa_exclude_domains
|
||||
}.items() if v})
|
||||
num_results=num_results,
|
||||
**optional_params
|
||||
)
|
||||
except Exception as retry_error:
|
||||
logger.error(f"[Exa Research] Retry also failed: {retry_error}")
|
||||
|
||||
@@ -31,7 +31,11 @@ from .research_strategies import get_strategy_for_mode
|
||||
|
||||
|
||||
class ResearchService:
|
||||
"""Service for conducting comprehensive research using Google Search grounding."""
|
||||
"""Service for conducting comprehensive research using Exa neural search.
|
||||
|
||||
Currently supports Exa as the primary and only provider for testing and debugging.
|
||||
Google Search grounding code is preserved for future use.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.keyword_analyzer = KeywordAnalyzer()
|
||||
@@ -43,9 +47,11 @@ class ResearchService:
|
||||
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
|
||||
"""
|
||||
Stage 1: Research & Strategy (AI Orchestration)
|
||||
Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
|
||||
Uses Exa neural search as the primary research provider.
|
||||
Follows LinkedIn service pattern for efficiency and cost optimization.
|
||||
Includes intelligent caching for exact keyword matches.
|
||||
|
||||
Note: Currently Exa-only for testing. Failures will raise errors instead of falling back.
|
||||
"""
|
||||
try:
|
||||
from services.cache.research_cache import research_cache
|
||||
@@ -88,7 +94,7 @@ class ResearchService:
|
||||
|
||||
# Determine research mode and get appropriate strategy
|
||||
research_mode = request.research_mode or ResearchMode.BASIC
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.EXA)
|
||||
strategy = get_strategy_for_mode(research_mode)
|
||||
|
||||
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
|
||||
@@ -96,7 +102,11 @@ class ResearchService:
|
||||
# Build research prompt based on strategy
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Route to appropriate provider
|
||||
# Currently Exa-only for testing - fail if other providers are requested
|
||||
if config.provider != ResearchProvider.EXA:
|
||||
raise ValueError(f"Only Exa provider is currently supported for testing. Requested provider: {config.provider.value}")
|
||||
|
||||
# Route to Exa provider
|
||||
if config.provider == ResearchProvider.EXA:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
@@ -145,13 +155,9 @@ class ResearchService:
|
||||
grounding_metadata = None # Exa doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "EXA_API_KEY not configured" in str(e):
|
||||
logger.warning("Exa not configured, falling back to Google")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
raw_result = None
|
||||
else:
|
||||
raise
|
||||
# Fail fast - no fallback for testing/debugging
|
||||
logger.error(f"Exa research failed: {e}")
|
||||
raise RuntimeError(f"Exa research failed: {e}. Please ensure EXA_API_KEY is configured.") from e
|
||||
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
@@ -231,41 +237,13 @@ class ResearchService:
|
||||
grounding_metadata = None # Tavily doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "TAVILY_API_KEY not configured" in str(e):
|
||||
logger.warning("Tavily not configured, falling back to Google")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
raw_result = None
|
||||
else:
|
||||
raise
|
||||
|
||||
if config.provider not in [ResearchProvider.EXA, ResearchProvider.TAVILY]:
|
||||
# Google research (existing flow) or fallback from Exa
|
||||
from .google_provider import GoogleResearchProvider
|
||||
import time
|
||||
|
||||
api_start_time = time.time()
|
||||
google_provider = GoogleResearchProvider()
|
||||
gemini_result = await google_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
|
||||
# Log API call performance
|
||||
blog_writer_logger.log_api_call(
|
||||
"gemini_grounded",
|
||||
"generate_grounded_content",
|
||||
api_duration_ms,
|
||||
token_usage=gemini_result.get("token_usage", {}),
|
||||
content_length=len(gemini_result.get("content", ""))
|
||||
)
|
||||
|
||||
# Extract sources and content
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
content = gemini_result.get("content", "")
|
||||
search_widget = gemini_result.get("search_widget", "") or ""
|
||||
search_queries = gemini_result.get("search_queries", []) or []
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
# Fail fast - no fallback for testing/debugging
|
||||
logger.error(f"Tavily research failed: {e}")
|
||||
raise RuntimeError(f"Tavily research failed: {e}. Please ensure TAVILY_API_KEY is configured.") from e
|
||||
|
||||
# Validate that we have content and sources before proceeding
|
||||
if 'content' not in locals() or 'sources' not in locals():
|
||||
raise RuntimeError(f"{config.provider.value} research did not return content or sources. Research failed.")
|
||||
|
||||
# Continue with common analysis (same for both providers)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
@@ -434,7 +412,7 @@ class ResearchService:
|
||||
|
||||
# Determine research mode and get appropriate strategy
|
||||
research_mode = request.research_mode or ResearchMode.BASIC
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.EXA)
|
||||
strategy = get_strategy_for_mode(research_mode)
|
||||
|
||||
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
|
||||
@@ -442,7 +420,11 @@ class ResearchService:
|
||||
# Build research prompt based on strategy
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Route to appropriate provider
|
||||
# Currently Exa-only for testing - fail if other providers are requested
|
||||
if config.provider != ResearchProvider.EXA:
|
||||
raise ValueError(f"Only Exa provider is currently supported for testing. Requested provider: {config.provider.value}")
|
||||
|
||||
# Route to Exa provider
|
||||
if config.provider == ResearchProvider.EXA:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
@@ -495,13 +477,10 @@ class ResearchService:
|
||||
grounding_metadata = None # Exa doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "EXA_API_KEY not configured" in str(e):
|
||||
logger.warning("Exa not configured, falling back to Google")
|
||||
await task_manager.update_progress(task_id, "⚠️ Exa not configured, falling back to Google Search")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
else:
|
||||
raise
|
||||
# Fail fast - no fallback for testing/debugging
|
||||
logger.error(f"Exa research failed: {e}")
|
||||
await task_manager.update_progress(task_id, f"❌ Exa research failed: {str(e)}")
|
||||
raise RuntimeError(f"Exa research failed: {e}. Please ensure EXA_API_KEY is configured.") from e
|
||||
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
@@ -581,43 +560,18 @@ class ResearchService:
|
||||
grounding_metadata = None # Tavily doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "TAVILY_API_KEY not configured" in str(e):
|
||||
logger.warning("Tavily not configured, falling back to Google")
|
||||
await task_manager.update_progress(task_id, "⚠️ Tavily not configured, falling back to Google Search")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
else:
|
||||
raise
|
||||
|
||||
if config.provider not in [ResearchProvider.EXA, ResearchProvider.TAVILY]:
|
||||
# Google research (existing flow)
|
||||
from .google_provider import GoogleResearchProvider
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Google Search grounding...")
|
||||
google_provider = GoogleResearchProvider()
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
try:
|
||||
gemini_result = await google_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources and content
|
||||
# Handle None result case
|
||||
if gemini_result is None:
|
||||
logger.error("gemini_result is None after search - this should not happen if HTTPException was raised")
|
||||
raise ValueError("Research result is None - search operation failed unexpectedly")
|
||||
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
content = gemini_result.get("content", "") if isinstance(gemini_result, dict) else ""
|
||||
search_widget = gemini_result.get("search_widget", "") or "" if isinstance(gemini_result, dict) else ""
|
||||
search_queries = gemini_result.get("search_queries", []) or [] if isinstance(gemini_result, dict) else []
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
# Fail fast - no fallback for testing/debugging
|
||||
logger.error(f"Tavily research failed: {e}")
|
||||
await task_manager.update_progress(task_id, f"❌ Tavily research failed: {str(e)}")
|
||||
raise RuntimeError(f"Tavily research failed: {e}. Please ensure TAVILY_API_KEY is configured.") from e
|
||||
|
||||
# Validate that we have content and sources before proceeding
|
||||
if config.provider == ResearchProvider.EXA and ('content' not in locals() or 'sources' not in locals()):
|
||||
await task_manager.update_progress(task_id, "❌ Exa research did not return content or sources")
|
||||
raise RuntimeError("Exa research did not return content or sources. Research failed.")
|
||||
elif config.provider == ResearchProvider.TAVILY and ('content' not in locals() or 'sources' not in locals()):
|
||||
await task_manager.update_progress(task_id, "❌ Tavily research did not return content or sources")
|
||||
raise RuntimeError("Tavily research did not return content or sources. Research failed.")
|
||||
|
||||
# Continue with common analysis (same for both providers)
|
||||
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
|
||||
|
||||
17
backend/services/campaign_creator/__init__.py
Normal file
17
backend/services/campaign_creator/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Campaign Creator service package."""
|
||||
|
||||
from .orchestrator import CampaignOrchestrator, CampaignBlueprint, CampaignAssetNode
|
||||
from .campaign_storage import CampaignStorageService
|
||||
from .channel_pack import ChannelPackService
|
||||
from .asset_audit import AssetAuditService
|
||||
from .prompt_builder import CampaignPromptBuilder
|
||||
|
||||
__all__ = [
|
||||
"CampaignOrchestrator",
|
||||
"CampaignBlueprint",
|
||||
"CampaignAssetNode",
|
||||
"CampaignStorageService",
|
||||
"ChannelPackService",
|
||||
"AssetAuditService",
|
||||
"CampaignPromptBuilder",
|
||||
]
|
||||
204
backend/services/campaign_creator/asset_audit.py
Normal file
204
backend/services/campaign_creator/asset_audit.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Asset Audit Service
|
||||
Analyzes uploaded assets and recommends enhancement operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AssetAuditService:
|
||||
"""Service to audit assets and recommend enhancements."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Asset Audit Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Asset Audit] Service initialized")
|
||||
|
||||
def audit_asset(
|
||||
self,
|
||||
image_base64: str,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit an uploaded asset and recommend enhancement operations.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image
|
||||
asset_metadata: Optional metadata about the asset
|
||||
|
||||
Returns:
|
||||
Audit results with recommendations
|
||||
"""
|
||||
try:
|
||||
# Decode image
|
||||
image_bytes = self._decode_base64(image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Invalid image data")
|
||||
|
||||
# Analyze image
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
format_type = image.format or "PNG"
|
||||
mode = image.mode
|
||||
|
||||
# Basic quality checks
|
||||
quality_score = self._assess_quality(image, width, height)
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = []
|
||||
|
||||
# Resolution recommendations
|
||||
if width < 1080 or height < 1080:
|
||||
recommendations.append({
|
||||
"operation": "upscale",
|
||||
"priority": "high",
|
||||
"reason": f"Image resolution ({width}x{height}) is below recommended 1080p for social media",
|
||||
"suggested_mode": "fast" if width < 512 else "conservative",
|
||||
})
|
||||
|
||||
# Background recommendations
|
||||
if mode == "RGBA" and self._has_transparency(image):
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "low",
|
||||
"reason": "Image already has transparency, background removal may not be needed",
|
||||
})
|
||||
else:
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "medium",
|
||||
"reason": "Background removal can create versatile product images",
|
||||
})
|
||||
|
||||
# Enhancement recommendations based on quality
|
||||
if quality_score < 0.7:
|
||||
recommendations.append({
|
||||
"operation": "enhance",
|
||||
"priority": "high",
|
||||
"reason": f"Image quality score ({quality_score:.2f}) suggests enhancement needed",
|
||||
"suggested_operations": ["upscale", "general_edit"],
|
||||
})
|
||||
|
||||
# Format recommendations
|
||||
if format_type not in ["PNG", "JPEG"]:
|
||||
recommendations.append({
|
||||
"operation": "convert",
|
||||
"priority": "low",
|
||||
"reason": f"Format {format_type} may not be optimal for web/social media",
|
||||
"suggested_format": "PNG" if mode == "RGBA" else "JPEG",
|
||||
})
|
||||
|
||||
audit_result = {
|
||||
"asset_info": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"format": format_type,
|
||||
"mode": mode,
|
||||
"quality_score": quality_score,
|
||||
},
|
||||
"recommendations": recommendations,
|
||||
"status": "usable" if quality_score > 0.6 else "needs_enhancement",
|
||||
}
|
||||
|
||||
logger.info(f"[Asset Audit] Audited asset: {width}x{height}, quality: {quality_score:.2f}")
|
||||
return audit_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error auditing asset: {str(e)}")
|
||||
return {
|
||||
"asset_info": {},
|
||||
"recommendations": [],
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _decode_base64(self, image_base64: str) -> Optional[bytes]:
|
||||
"""Decode base64 image data."""
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
_, b64data = image_base64.split(",", 1)
|
||||
else:
|
||||
b64data = image_base64
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error decoding base64: {str(e)}")
|
||||
return None
|
||||
|
||||
def _has_transparency(self, image: Image.Image) -> bool:
|
||||
"""Check if image has transparency."""
|
||||
if image.mode in ("RGBA", "LA"):
|
||||
alpha = image.split()[-1]
|
||||
return any(pixel < 255 for pixel in alpha.getdata())
|
||||
return False
|
||||
|
||||
def _assess_quality(self, image: Image.Image, width: int, height: int) -> float:
|
||||
"""
|
||||
Assess image quality score (0.0 to 1.0).
|
||||
|
||||
Simple heuristic based on resolution and format.
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Resolution scoring
|
||||
min_dimension = min(width, height)
|
||||
if min_dimension >= 1080:
|
||||
score += 0.3
|
||||
elif min_dimension >= 512:
|
||||
score += 0.2
|
||||
elif min_dimension >= 256:
|
||||
score += 0.1
|
||||
|
||||
# Format scoring
|
||||
if image.format in ["PNG", "JPEG"]:
|
||||
score += 0.1
|
||||
|
||||
# Mode scoring
|
||||
if image.mode in ["RGB", "RGBA"]:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def batch_audit_assets(
|
||||
self,
|
||||
assets: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit multiple assets in batch.
|
||||
|
||||
Args:
|
||||
assets: List of asset dictionaries with 'image_base64' and optional 'metadata'
|
||||
|
||||
Returns:
|
||||
Batch audit results
|
||||
"""
|
||||
results = []
|
||||
for asset in assets:
|
||||
audit_result = self.audit_asset(
|
||||
asset.get('image_base64'),
|
||||
asset.get('metadata')
|
||||
)
|
||||
results.append({
|
||||
"asset_id": asset.get('id'),
|
||||
"audit": audit_result,
|
||||
})
|
||||
|
||||
# Summary statistics
|
||||
total_assets = len(results)
|
||||
usable_count = sum(1 for r in results if r["audit"]["status"] == "usable")
|
||||
needs_enhancement_count = sum(
|
||||
1 for r in results if r["audit"]["status"] == "needs_enhancement"
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"summary": {
|
||||
"total_assets": total_assets,
|
||||
"usable": usable_count,
|
||||
"needs_enhancement": needs_enhancement_count,
|
||||
"error": total_assets - usable_count - needs_enhancement_count,
|
||||
},
|
||||
}
|
||||
295
backend/services/campaign_creator/campaign_storage.py
Normal file
295
backend/services/campaign_creator/campaign_storage.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Campaign Storage Service
|
||||
Handles database persistence for campaigns, proposals, and assets.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset, CampaignStatus
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class CampaignStorageService:
|
||||
"""Service for storing and retrieving campaigns from database."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Campaign Storage Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Campaign Storage] Service initialized")
|
||||
|
||||
def save_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> Campaign:
|
||||
"""
|
||||
Save campaign blueprint to database.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign blueprint data
|
||||
|
||||
Returns:
|
||||
Saved Campaign object
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign_id = campaign_data.get('campaign_id')
|
||||
|
||||
# Check if campaign exists
|
||||
existing = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing campaign
|
||||
existing.campaign_name = campaign_data.get('campaign_name', existing.campaign_name)
|
||||
existing.goal = campaign_data.get('goal', existing.goal)
|
||||
existing.kpi = campaign_data.get('kpi', existing.kpi)
|
||||
existing.status = campaign_data.get('status', existing.status)
|
||||
existing.phases = campaign_data.get('phases', existing.phases)
|
||||
existing.channels = campaign_data.get('channels', existing.channels)
|
||||
existing.asset_nodes = campaign_data.get('asset_nodes', existing.asset_nodes)
|
||||
existing.product_context = campaign_data.get('product_context', existing.product_context)
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id}")
|
||||
return existing
|
||||
else:
|
||||
# Create new campaign
|
||||
campaign = Campaign(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
campaign_name=campaign_data.get('campaign_name'),
|
||||
goal=campaign_data.get('goal'),
|
||||
kpi=campaign_data.get('kpi'),
|
||||
status=campaign_data.get('status', 'draft'),
|
||||
phases=campaign_data.get('phases'),
|
||||
channels=campaign_data.get('channels', []),
|
||||
asset_nodes=campaign_data.get('asset_nodes', []),
|
||||
product_context=campaign_data.get('product_context'),
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
logger.info(f"[Campaign Storage] Saved new campaign {campaign_id}")
|
||||
return campaign
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving campaign: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> Optional[Campaign]:
|
||||
"""Get campaign by ID."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
return campaign
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting campaign: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def list_campaigns(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Campaign]:
|
||||
"""List campaigns for user."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
query = db.query(Campaign).filter(Campaign.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
|
||||
campaigns = query.order_by(desc(Campaign.created_at)).limit(limit).all()
|
||||
return campaigns
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error listing campaigns: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def save_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
proposals: Dict[str, Any]
|
||||
) -> List[CampaignProposal]:
|
||||
"""Save asset proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Delete existing proposals for this campaign
|
||||
db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).delete()
|
||||
|
||||
# Create new proposals
|
||||
saved_proposals = []
|
||||
for asset_id, proposal_data in proposals.get('proposals', {}).items():
|
||||
proposal = CampaignProposal(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
asset_node_id=asset_id,
|
||||
asset_type=proposal_data.get('asset_type'),
|
||||
channel=proposal_data.get('channel'),
|
||||
proposed_prompt=proposal_data.get('proposed_prompt'),
|
||||
recommended_template=proposal_data.get('recommended_template'),
|
||||
recommended_provider=proposal_data.get('recommended_provider'),
|
||||
recommended_model=proposal_data.get('recommended_model'),
|
||||
cost_estimate=proposal_data.get('cost_estimate', 0.0),
|
||||
concept_summary=proposal_data.get('concept_summary'),
|
||||
status='proposed',
|
||||
)
|
||||
db.add(proposal)
|
||||
saved_proposals.append(proposal)
|
||||
|
||||
db.commit()
|
||||
for proposal in saved_proposals:
|
||||
db.refresh(proposal)
|
||||
|
||||
logger.info(f"[Campaign Storage] Saved {len(saved_proposals)} proposals for campaign {campaign_id}")
|
||||
return saved_proposals
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving proposals: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> List[CampaignProposal]:
|
||||
"""Get proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
proposals = db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).all()
|
||||
return proposals
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting proposals: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def update_campaign_status(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
status: str
|
||||
) -> bool:
|
||||
"""Update campaign status."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if campaign:
|
||||
campaign.status = status
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id} status to {status}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error updating status: {str(e)}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def update_asset_status(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
asset_id: str,
|
||||
status: str,
|
||||
generated_asset_id: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update status of a campaign asset and its proposal.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_id: Campaign ID
|
||||
asset_id: Asset node ID
|
||||
status: New status (generating, ready, approved, rejected)
|
||||
generated_asset_id: Optional Asset Library ID
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Update proposal status
|
||||
proposal = db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id,
|
||||
CampaignProposal.asset_node_id == asset_id
|
||||
).first()
|
||||
|
||||
if proposal:
|
||||
proposal.status = status
|
||||
if generated_asset_id:
|
||||
proposal.generated_asset_id = generated_asset_id
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Updated proposal {asset_id} status to {status}")
|
||||
|
||||
# Update or create campaign asset
|
||||
campaign_asset = db.query(CampaignAsset).filter(
|
||||
CampaignAsset.campaign_id == campaign_id,
|
||||
CampaignAsset.user_id == user_id,
|
||||
CampaignAsset.asset_node_id == asset_id
|
||||
).first()
|
||||
|
||||
if campaign_asset:
|
||||
campaign_asset.status = status
|
||||
if generated_asset_id:
|
||||
campaign_asset.generated_asset_id = generated_asset_id
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Updated campaign asset {asset_id} status to {status}")
|
||||
else:
|
||||
# Create new campaign asset if it doesn't exist
|
||||
if proposal:
|
||||
campaign_asset = CampaignAsset(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
asset_node_id=asset_id,
|
||||
asset_type=proposal.asset_type,
|
||||
channel=proposal.channel,
|
||||
status=status,
|
||||
generated_asset_id=generated_asset_id,
|
||||
)
|
||||
db.add(campaign_asset)
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Created campaign asset {asset_id}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error updating asset status: {str(e)}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
179
backend/services/campaign_creator/channel_pack.py
Normal file
179
backend/services/campaign_creator/channel_pack.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Channel Pack Service
|
||||
Maps channels to templates, copy frameworks, and platform-specific optimizations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio.templates import Platform, TemplateManager
|
||||
from services.image_studio.social_optimizer_service import SocialOptimizerService
|
||||
|
||||
|
||||
class ChannelPackService:
|
||||
"""Service to build channel-specific asset packs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Channel Pack Service."""
|
||||
self.template_manager = TemplateManager()
|
||||
self.social_optimizer = SocialOptimizerService()
|
||||
self.logger = logger
|
||||
logger.info("[Channel Pack] Service initialized")
|
||||
|
||||
def get_channel_pack(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str = "social_post"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get channel-specific pack configuration.
|
||||
|
||||
Args:
|
||||
channel: Target channel (instagram, linkedin, tiktok, facebook, twitter, pinterest, youtube)
|
||||
asset_type: Type of asset (social_post, story, reel, cover, etc.)
|
||||
|
||||
Returns:
|
||||
Channel pack configuration with templates, dimensions, copy frameworks
|
||||
"""
|
||||
try:
|
||||
# Map channel string to Platform enum
|
||||
platform_map = {
|
||||
'instagram': Platform.INSTAGRAM,
|
||||
'linkedin': Platform.LINKEDIN,
|
||||
'tiktok': Platform.TIKTOK,
|
||||
'facebook': Platform.FACEBOOK,
|
||||
'twitter': Platform.TWITTER,
|
||||
'pinterest': Platform.PINTEREST,
|
||||
'youtube': Platform.YOUTUBE,
|
||||
}
|
||||
|
||||
platform = platform_map.get(channel.lower())
|
||||
if not platform:
|
||||
raise ValueError(f"Unsupported channel: {channel}")
|
||||
|
||||
# Get templates for this platform
|
||||
templates = self.template_manager.get_platform_templates().get(platform, [])
|
||||
|
||||
# Get platform formats
|
||||
formats = self.social_optimizer.get_platform_formats(platform)
|
||||
|
||||
# Build channel pack
|
||||
pack = {
|
||||
"channel": channel,
|
||||
"platform": platform.value,
|
||||
"asset_type": asset_type,
|
||||
"templates": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"dimensions": f"{t.aspect_ratio.width}x{t.aspect_ratio.height}",
|
||||
"aspect_ratio": t.aspect_ratio.ratio,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"quality": t.quality,
|
||||
}
|
||||
for t in templates
|
||||
],
|
||||
"formats": formats,
|
||||
"copy_framework": self._get_copy_framework(channel, asset_type),
|
||||
"optimization_tips": self._get_optimization_tips(channel),
|
||||
}
|
||||
|
||||
logger.info(f"[Channel Pack] Built pack for {channel} ({asset_type})")
|
||||
return pack
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Channel Pack] Error building pack: {str(e)}")
|
||||
return {
|
||||
"channel": channel,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _get_copy_framework(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get copy framework for channel and asset type."""
|
||||
frameworks = {
|
||||
"instagram": {
|
||||
"social_post": {
|
||||
"caption_length": "125-150 words optimal",
|
||||
"hashtags": "5-10 relevant hashtags",
|
||||
"cta": "Clear call-to-action in first line",
|
||||
"emoji": "Use 1-3 emojis strategically",
|
||||
},
|
||||
"story": {
|
||||
"text_overlay": "Keep text minimal, readable at small size",
|
||||
"cta": "Swipe-up or link sticker",
|
||||
},
|
||||
},
|
||||
"linkedin": {
|
||||
"social_post": {
|
||||
"length": "150-300 words for maximum engagement",
|
||||
"hashtags": "3-5 professional hashtags",
|
||||
"tone": "Professional, thought-leadership focused",
|
||||
"cta": "Engage with question or call-to-action",
|
||||
},
|
||||
},
|
||||
"tiktok": {
|
||||
"video": {
|
||||
"hook": "Strong hook in first 3 seconds",
|
||||
"caption": "Short, engaging, use trending hashtags",
|
||||
"hashtags": "3-5 trending hashtags",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return frameworks.get(channel, {}).get(asset_type, {})
|
||||
|
||||
def _get_optimization_tips(self, channel: str) -> List[str]:
|
||||
"""Get optimization tips for channel."""
|
||||
tips = {
|
||||
"instagram": [
|
||||
"Use square (1:1) or portrait (4:5) for feed posts",
|
||||
"Include text overlay safe zones (15% top/bottom, 10% left/right)",
|
||||
"Optimize for mobile viewing",
|
||||
],
|
||||
"linkedin": [
|
||||
"Use landscape (1.91:1) for feed posts",
|
||||
"Professional photography style",
|
||||
"Include clear value proposition",
|
||||
],
|
||||
"tiktok": [
|
||||
"Vertical format (9:16) required",
|
||||
"Eye-catching first frame",
|
||||
"Fast-paced, engaging content",
|
||||
],
|
||||
}
|
||||
|
||||
return tips.get(channel, [])
|
||||
|
||||
def build_multi_channel_pack(
|
||||
self,
|
||||
channels: List[str],
|
||||
source_image_base64: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build optimized asset pack for multiple channels from single source.
|
||||
|
||||
Args:
|
||||
channels: List of target channels
|
||||
source_image_base64: Source image to optimize
|
||||
|
||||
Returns:
|
||||
Multi-channel pack with optimized variants
|
||||
"""
|
||||
pack_results = []
|
||||
|
||||
for channel in channels:
|
||||
pack = self.get_channel_pack(channel)
|
||||
pack_results.append({
|
||||
"channel": channel,
|
||||
"pack": pack,
|
||||
})
|
||||
|
||||
return {
|
||||
"source_image": "provided",
|
||||
"channels": pack_results,
|
||||
"total_variants": len(channels),
|
||||
}
|
||||
653
backend/services/campaign_creator/orchestrator.py
Normal file
653
backend/services/campaign_creator/orchestrator.py
Normal file
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
Campaign Creator Orchestrator
|
||||
Main service that orchestrates campaign workflows and asset generation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio import ImageStudioManager, CreateStudioRequest
|
||||
from .prompt_builder import CampaignPromptBuilder
|
||||
from services.product_marketing.brand_dna_sync import BrandDNASyncService
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from services.database import SessionLocal
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignAssetNode:
|
||||
"""Represents an asset node in the campaign graph."""
|
||||
asset_id: str
|
||||
asset_type: str # image, video, text, audio
|
||||
channel: str
|
||||
status: str # draft, generating, ready, approved
|
||||
prompt: Optional[str] = None
|
||||
template_id: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
cost_estimate: Optional[float] = None
|
||||
generated_asset_id: Optional[int] = None # Asset Library ID
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignBlueprint:
|
||||
"""Campaign blueprint with phases and asset nodes."""
|
||||
campaign_id: str
|
||||
campaign_name: str
|
||||
goal: str
|
||||
kpi: Optional[str] = None
|
||||
phases: List[Dict[str, Any]] = None # teaser, launch, nurture
|
||||
asset_nodes: List[CampaignAssetNode] = None
|
||||
channels: List[str] = None
|
||||
status: str = "draft" # draft, generating, ready, published
|
||||
|
||||
|
||||
class CampaignOrchestrator:
|
||||
"""Main orchestrator for Campaign Creator."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Campaign Orchestrator."""
|
||||
self.image_studio = ImageStudioManager()
|
||||
self.prompt_builder = CampaignPromptBuilder()
|
||||
self.brand_dna_sync = BrandDNASyncService()
|
||||
self.asset_audit = AssetAuditService()
|
||||
self.channel_pack = ChannelPackService()
|
||||
self.logger = logger
|
||||
logger.info("[Campaign Orchestrator] Initialized")
|
||||
|
||||
def create_campaign_blueprint(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> CampaignBlueprint:
|
||||
"""
|
||||
Create campaign blueprint from user input and onboarding data.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign information (name, goal, channels, etc.)
|
||||
|
||||
Returns:
|
||||
Campaign blueprint with asset nodes
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
campaign_id = campaign_data.get('campaign_id') or f"campaign_{user_id}_{int(time.time())}"
|
||||
campaign_name = campaign_data.get('campaign_name', 'New Campaign')
|
||||
goal = campaign_data.get('goal', 'product_launch')
|
||||
channels = campaign_data.get('channels', [])
|
||||
|
||||
# Get brand DNA for personalization
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
|
||||
# Build campaign phases
|
||||
phases = self._build_campaign_phases(goal, channels)
|
||||
|
||||
# Generate asset nodes for each phase and channel
|
||||
asset_nodes = []
|
||||
for phase in phases:
|
||||
phase_name = phase.get('name')
|
||||
for channel in channels:
|
||||
# Determine required assets for this phase + channel
|
||||
required_assets = self._get_required_assets(phase_name, channel)
|
||||
|
||||
for asset_type in required_assets:
|
||||
asset_node = CampaignAssetNode(
|
||||
asset_id=f"{campaign_id}_{phase_name}_{channel}_{asset_type}",
|
||||
asset_type=asset_type,
|
||||
channel=channel,
|
||||
status="draft",
|
||||
)
|
||||
asset_nodes.append(asset_node)
|
||||
|
||||
blueprint = CampaignBlueprint(
|
||||
campaign_id=campaign_id,
|
||||
campaign_name=campaign_name,
|
||||
goal=goal,
|
||||
kpi=campaign_data.get('kpi'),
|
||||
phases=phases,
|
||||
asset_nodes=asset_nodes,
|
||||
channels=channels,
|
||||
status="draft",
|
||||
)
|
||||
|
||||
logger.info(f"[Orchestrator] Created blueprint for campaign {campaign_id} with {len(asset_nodes)} assets")
|
||||
return blueprint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error creating blueprint: {str(e)}")
|
||||
raise
|
||||
|
||||
def generate_asset_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint,
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate AI proposals for each asset node in the blueprint.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Dictionary with proposals for each asset node
|
||||
"""
|
||||
try:
|
||||
proposals = {}
|
||||
|
||||
for asset_node in blueprint.asset_nodes:
|
||||
# Build specialized prompt based on asset type and channel
|
||||
if asset_node.asset_type == "image":
|
||||
base_prompt = product_context.get('product_description', 'Product image') if product_context else 'Marketing image'
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_image_prompt(
|
||||
base_prompt=base_prompt,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
asset_type="hero_image",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
# Get channel pack for template recommendations
|
||||
channel_pack = self.channel_pack.get_channel_pack(asset_node.channel)
|
||||
recommended_template = channel_pack.get('templates', [{}])[0] if channel_pack.get('templates') else None
|
||||
|
||||
# Estimate cost
|
||||
cost_estimate = self._estimate_asset_cost("image", asset_node.channel)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"recommended_template": recommended_template.get('id') if recommended_template else None,
|
||||
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": self._generate_concept_summary(enhanced_prompt),
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "video":
|
||||
# Video asset proposals - determine if animation (image-to-video) or demo (text-to-video)
|
||||
# Default to animation if we have product image, otherwise demo
|
||||
video_subtype = asset_proposal.get('video_subtype', 'animation') if 'asset_proposal' in locals() else 'demo'
|
||||
|
||||
# For demo videos (text-to-video), we need product description
|
||||
if video_subtype == "demo" or not product_context or not product_context.get('product_image_base64'):
|
||||
# Text-to-video demo video
|
||||
video_type = "demo" # Default, can be customized
|
||||
if asset_node.channel in ["tiktok", "instagram"]:
|
||||
video_type = "storytelling" # Storytelling for social media
|
||||
elif asset_node.channel in ["linkedin", "youtube"]:
|
||||
video_type = "feature_highlight" # Feature highlights for professional
|
||||
|
||||
# Estimate cost for text-to-video (WAN 2.5: $0.05-$0.15/second)
|
||||
duration = 10 # Default 10s for demo videos
|
||||
resolution = "720p" # Default
|
||||
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
|
||||
cost_estimate = duration * cost_per_second
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"video_subtype": "demo", # Text-to-video
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id,
|
||||
"video_type": video_type,
|
||||
"duration": duration,
|
||||
"resolution": resolution,
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": f"Product {video_type} video optimized for {asset_node.channel}",
|
||||
"note": "Text-to-video demo - requires product description",
|
||||
}
|
||||
else:
|
||||
# Image-to-video animation
|
||||
animation_type = "reveal" # Default
|
||||
if asset_node.channel in ["tiktok", "instagram", "youtube"]:
|
||||
animation_type = "demo" # Demo animations for social media
|
||||
elif asset_node.channel in ["linkedin", "facebook"]:
|
||||
animation_type = "reveal" # Professional reveal for B2B
|
||||
|
||||
# Estimate cost for image-to-video (WAN 2.5: $0.05-$0.15/second)
|
||||
duration = 5 # Default 5s for animations
|
||||
resolution = "720p" # Default
|
||||
cost_per_second = 0.10 if resolution == "720p" else (0.15 if resolution == "1080p" else 0.05)
|
||||
cost_estimate = duration * cost_per_second
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"video_subtype": "animation", # Image-to-video
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id,
|
||||
"animation_type": animation_type,
|
||||
"duration": duration,
|
||||
"resolution": resolution,
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": f"Product {animation_type} animation optimized for {asset_node.channel}",
|
||||
"note": "Requires product image - will be provided during generation",
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "text":
|
||||
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
|
||||
base_request=base_request,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
content_type="caption",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"campaign_id": blueprint.campaign_id, # Include campaign_id for tracking
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"cost_estimate": 0.0, # Text generation cost is minimal
|
||||
"concept_summary": "Marketing copy optimized for channel and persona",
|
||||
}
|
||||
|
||||
logger.info(f"[Orchestrator] Generated {len(proposals)} asset proposals")
|
||||
return {"proposals": proposals, "total_assets": len(proposals)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating proposals: {str(e)}")
|
||||
raise
|
||||
|
||||
async def generate_asset(
|
||||
self,
|
||||
user_id: str,
|
||||
asset_proposal: Dict[str, Any],
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a single asset using Image Studio APIs.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
asset_proposal: Asset proposal from generate_asset_proposals
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Generated asset result
|
||||
"""
|
||||
try:
|
||||
asset_type = asset_proposal.get('asset_type')
|
||||
|
||||
if asset_type == "image":
|
||||
# Build CreateStudioRequest
|
||||
create_request = CreateStudioRequest(
|
||||
prompt=asset_proposal.get('proposed_prompt'),
|
||||
template_id=asset_proposal.get('recommended_template'),
|
||||
provider=asset_proposal.get('recommended_provider', 'wavespeed'),
|
||||
quality="premium",
|
||||
enhance_prompt=True,
|
||||
use_persona=True,
|
||||
num_variations=1,
|
||||
)
|
||||
|
||||
# Generate image using Image Studio
|
||||
result = await self.image_studio.create_image(create_request, user_id=user_id)
|
||||
|
||||
# Asset is automatically tracked in Asset Library via Image Studio
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "image",
|
||||
"result": result,
|
||||
"asset_library_ids": [
|
||||
r.get('asset_id') for r in result.get('results', [])
|
||||
if r.get('asset_id')
|
||||
],
|
||||
}
|
||||
|
||||
elif asset_type == "video":
|
||||
# Check video subtype: "animation" (image-to-video) or "demo" (text-to-video)
|
||||
video_subtype = asset_proposal.get('video_subtype', 'animation')
|
||||
|
||||
if video_subtype == "demo":
|
||||
# Text-to-video: Product demo video from description
|
||||
from services.product_marketing.product_video_service import ProductVideoService, ProductVideoRequest
|
||||
|
||||
# Get product info from context
|
||||
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
|
||||
product_description = product_context.get('product_description', '') if product_context else ''
|
||||
|
||||
if not product_description:
|
||||
raise ValueError("Product description required for text-to-video demo generation")
|
||||
|
||||
# Get brand context
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
brand_context = {
|
||||
"visual_identity": brand_dna.get("visual_identity", {}),
|
||||
"persona": brand_dna.get("persona", {}),
|
||||
}
|
||||
|
||||
# Get video type from proposal or default
|
||||
video_type = asset_proposal.get('video_type', 'demo')
|
||||
|
||||
# Create video service
|
||||
video_service = ProductVideoService()
|
||||
|
||||
# Create video request
|
||||
video_request = ProductVideoRequest(
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
video_type=video_type,
|
||||
resolution=asset_proposal.get('resolution', '720p'),
|
||||
duration=asset_proposal.get('duration', 10),
|
||||
audio_base64=asset_proposal.get('audio_base64'),
|
||||
brand_context=brand_context,
|
||||
additional_context=asset_proposal.get('additional_context'),
|
||||
)
|
||||
|
||||
# Generate video using unified ai_video_generate()
|
||||
result = await video_service.generate_product_video(video_request, user_id)
|
||||
|
||||
# Extract campaign_id for metadata
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "video",
|
||||
"video_subtype": "demo",
|
||||
"video_url": result.get('file_url'),
|
||||
"video_filename": result.get('filename'),
|
||||
"cost": result.get('cost', 0.0),
|
||||
"video_type": video_type,
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
}
|
||||
|
||||
else:
|
||||
# Image-to-video: Product animation
|
||||
from services.product_marketing.product_animation_service import ProductAnimationService, ProductAnimationRequest
|
||||
|
||||
# Get product image from proposal or product context
|
||||
product_image_base64 = asset_proposal.get('product_image_base64')
|
||||
if not product_image_base64 and product_context:
|
||||
product_image_base64 = product_context.get('product_image_base64')
|
||||
|
||||
if not product_image_base64:
|
||||
raise ValueError("Product image required for image-to-video animation generation")
|
||||
|
||||
# Get animation type from proposal or default to "reveal"
|
||||
animation_type = asset_proposal.get('animation_type', 'reveal')
|
||||
product_name = product_context.get('product_name', 'Product') if product_context else 'Product'
|
||||
product_description = product_context.get('product_description') if product_context else None
|
||||
|
||||
# Get brand context
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
brand_context = {
|
||||
"visual_identity": brand_dna.get("visual_identity", {}),
|
||||
"persona": brand_dna.get("persona", {}),
|
||||
}
|
||||
|
||||
# Create animation service
|
||||
animation_service = ProductAnimationService()
|
||||
|
||||
# Create animation request
|
||||
animation_request = ProductAnimationRequest(
|
||||
product_image_base64=product_image_base64,
|
||||
animation_type=animation_type,
|
||||
product_name=product_name,
|
||||
product_description=product_description,
|
||||
resolution=asset_proposal.get('resolution', '720p'),
|
||||
duration=asset_proposal.get('duration', 5),
|
||||
audio_base64=asset_proposal.get('audio_base64'),
|
||||
brand_context=brand_context,
|
||||
additional_context=asset_proposal.get('additional_context'),
|
||||
)
|
||||
|
||||
# Generate video
|
||||
result = await animation_service.animate_product(animation_request, user_id)
|
||||
|
||||
# Extract campaign_id for metadata
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "video",
|
||||
"video_subtype": "animation",
|
||||
"video_url": result.get('video_url'),
|
||||
"video_filename": result.get('filename'),
|
||||
"cost": result.get('cost', 0.0),
|
||||
"animation_type": animation_type,
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
}
|
||||
|
||||
elif asset_type == "text":
|
||||
# Import text generation service and tracker
|
||||
import asyncio
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
from services.database import SessionLocal
|
||||
|
||||
# Get enhanced prompt from proposal
|
||||
text_prompt = asset_proposal.get('proposed_prompt', '')
|
||||
channel = asset_proposal.get('channel', 'social')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
# Extract campaign_id - try from asset_proposal first, then from asset_id
|
||||
# asset_id format: {campaign_id}_{phase}_{channel}_{type}
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
if not campaign_id and asset_id and '_' in asset_id:
|
||||
# Try to extract: asset_id might be "campaign_user123_1234567890_teaser_instagram_text"
|
||||
# We need to find where phase_name starts (common phases: teaser, launch, nurture)
|
||||
parts = asset_id.split('_')
|
||||
# Find phase indicator (usually one of: teaser, launch, nurture)
|
||||
phase_indicators = ['teaser', 'launch', 'nurture', 'prelaunch', 'postlaunch']
|
||||
phase_idx = None
|
||||
for i, part in enumerate(parts):
|
||||
if part.lower() in phase_indicators:
|
||||
phase_idx = i
|
||||
break
|
||||
if phase_idx and phase_idx > 0:
|
||||
# Campaign ID is everything before the phase
|
||||
campaign_id = '_'.join(parts[:phase_idx])
|
||||
|
||||
# If still not found, use None (metadata will work without it)
|
||||
if not campaign_id:
|
||||
logger.warning(f"[Orchestrator] Could not extract campaign_id from asset_id: {asset_id}")
|
||||
|
||||
# Build system prompt for marketing copy
|
||||
system_prompt = f"""You are an expert marketing copywriter specializing in {channel} content.
|
||||
Generate compelling, on-brand marketing copy that:
|
||||
- Is optimized for {channel} platform best practices
|
||||
- Includes a clear call-to-action
|
||||
- Uses appropriate tone and style for the platform
|
||||
- Is concise and engaging
|
||||
- Aligns with the product marketing context provided
|
||||
|
||||
Return only the final copy text without explanations or markdown formatting."""
|
||||
|
||||
# Run synchronous llm_text_gen in thread pool
|
||||
logger.info(f"[Orchestrator] Generating text asset for channel: {channel}")
|
||||
generated_text = await asyncio.to_thread(
|
||||
llm_text_gen,
|
||||
prompt=text_prompt,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not generated_text or not generated_text.strip():
|
||||
raise ValueError("Text generation returned empty content")
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
asset_library_id = None
|
||||
try:
|
||||
asset_library_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=generated_text.strip(),
|
||||
source_module="campaign_creator",
|
||||
title=f"{channel.title()} Copy: {asset_id.split('_')[-1] if '_' in asset_id else 'Marketing Copy'}",
|
||||
description=f"Marketing copy for {channel} platform generated from campaign proposal",
|
||||
prompt=text_prompt,
|
||||
tags=["campaign_creator", channel.lower(), "text", "copy"],
|
||||
asset_metadata={
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
"asset_type": "text",
|
||||
"channel": channel,
|
||||
"concept_summary": asset_proposal.get('concept_summary'),
|
||||
},
|
||||
subdirectory="campaigns",
|
||||
file_extension=".txt"
|
||||
)
|
||||
|
||||
if asset_library_id:
|
||||
logger.info(f"[Orchestrator] ✅ Text asset saved to library: ID={asset_library_id}")
|
||||
else:
|
||||
logger.warning(f"[Orchestrator] ⚠️ Text asset tracking returned None")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Orchestrator] ⚠️ Failed to save text asset to library: {str(save_error)}")
|
||||
# Continue even if save fails - text is still generated
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "text",
|
||||
"content": generated_text.strip(),
|
||||
"asset_library_id": asset_library_id,
|
||||
"channel": channel,
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported asset type: {asset_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating asset: {str(e)}")
|
||||
raise
|
||||
|
||||
def validate_campaign_preflight(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate campaign blueprint against subscription limits before generation.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
|
||||
Returns:
|
||||
Pre-flight validation results
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Count operations needed
|
||||
image_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "image")
|
||||
text_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "text")
|
||||
|
||||
# Estimate total cost
|
||||
total_cost = 0.0
|
||||
for node in blueprint.asset_nodes:
|
||||
if node.cost_estimate:
|
||||
total_cost += node.cost_estimate
|
||||
|
||||
# Validate image generation limits
|
||||
operations = []
|
||||
if image_count > 0:
|
||||
operations.append({
|
||||
'provider': 'stability', # Default provider
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'wavespeed',
|
||||
'operation_type': 'image_generation',
|
||||
})
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations * image_count if operations else []
|
||||
)
|
||||
|
||||
return {
|
||||
"can_proceed": can_proceed,
|
||||
"message": message,
|
||||
"error_details": error_details,
|
||||
"summary": {
|
||||
"total_assets": len(blueprint.asset_nodes),
|
||||
"image_count": image_count,
|
||||
"text_count": text_count,
|
||||
"estimated_cost": total_cost,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error in pre-flight validation: {str(e)}")
|
||||
return {
|
||||
"can_proceed": False,
|
||||
"message": f"Validation error: {str(e)}",
|
||||
"error_details": {},
|
||||
}
|
||||
|
||||
def _build_campaign_phases(
|
||||
self,
|
||||
goal: str,
|
||||
channels: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Build campaign phases based on goal."""
|
||||
if goal == "product_launch":
|
||||
return [
|
||||
{"name": "teaser", "duration_days": 7, "purpose": "Build anticipation"},
|
||||
{"name": "launch", "duration_days": 3, "purpose": "Official launch"},
|
||||
{"name": "nurture", "duration_days": 14, "purpose": "Sustain engagement"},
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{"name": "campaign", "duration_days": 30, "purpose": "Campaign execution"},
|
||||
]
|
||||
|
||||
def _get_required_assets(
|
||||
self,
|
||||
phase: str,
|
||||
channel: str
|
||||
) -> List[str]:
|
||||
"""Get required asset types for phase and channel."""
|
||||
# Default: image for all phases and channels
|
||||
assets = ["image"]
|
||||
|
||||
# Add text/copy for social channels
|
||||
if channel in ["instagram", "linkedin", "facebook", "twitter"]:
|
||||
assets.append("text")
|
||||
|
||||
return assets
|
||||
|
||||
def _estimate_asset_cost(
|
||||
self,
|
||||
asset_type: str,
|
||||
channel: str
|
||||
) -> float:
|
||||
"""Estimate cost for asset generation."""
|
||||
if asset_type == "image":
|
||||
# Premium quality image: ~5-6 credits
|
||||
return 5.0
|
||||
elif asset_type == "video":
|
||||
# WAN 2.5 Image-to-Video: $0.05-$0.15/second
|
||||
# Default: 5 seconds at 720p = $0.50
|
||||
return 0.50
|
||||
elif asset_type == "text":
|
||||
return 0.0 # Text generation is typically included
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def _generate_concept_summary(self, prompt: str) -> str:
|
||||
"""Generate a brief concept summary from prompt."""
|
||||
# Simple extraction: take first 100 chars
|
||||
return prompt[:100] + "..." if len(prompt) > 100 else prompt
|
||||
303
backend/services/campaign_creator/prompt_builder.py
Normal file
303
backend/services/campaign_creator/prompt_builder.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Campaign Creator Prompt Builder
|
||||
Extends AIPromptOptimizer with campaign-specific prompt enhancement.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.ai_prompt_optimizer import AIPromptOptimizer
|
||||
from services.onboarding import OnboardingDataService
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from services.persona_data_service import PersonaDataService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class CampaignPromptBuilder(AIPromptOptimizer):
|
||||
"""Specialized prompt builder for campaign assets with onboarding data integration."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Campaign Prompt Builder."""
|
||||
super().__init__()
|
||||
self.onboarding_data_service = OnboardingDataService()
|
||||
self.logger = logger
|
||||
logger.info("[Campaign Prompt Builder] Initialized")
|
||||
|
||||
def build_marketing_image_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
asset_type: str = "hero_image",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing image prompt with brand DNA and persona data.
|
||||
|
||||
Args:
|
||||
base_prompt: Base product description or image concept
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, tiktok, etc.)
|
||||
asset_type: Type of asset (hero_image, product_photo, lifestyle, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with brand DNA, persona style, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build prompt layers
|
||||
enhanced_prompt = base_prompt
|
||||
|
||||
# Layer 1: Brand DNA (from website_analysis)
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
brand_analysis = website_analysis.get('brand_analysis', {})
|
||||
style_guidelines = website_analysis.get('style_guidelines', {})
|
||||
|
||||
# Add brand tone and style
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
brand_enhancement = f", {tone} tone, {voice} voice"
|
||||
|
||||
# Add target audience context
|
||||
demographics = target_audience.get('demographics', [])
|
||||
if demographics:
|
||||
audience_context = f", targeting {', '.join(demographics[:2])}"
|
||||
enhanced_prompt += audience_context
|
||||
|
||||
# Add brand visual identity if available
|
||||
if brand_analysis:
|
||||
color_palette = brand_analysis.get('color_palette', [])
|
||||
if color_palette:
|
||||
colors = ', '.join(color_palette[:3])
|
||||
enhanced_prompt += f", brand colors: {colors}"
|
||||
|
||||
# Layer 2: Persona Visual Style (from persona_data)
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
archetype = core_persona.get('archetype', '')
|
||||
if persona_name:
|
||||
enhanced_prompt += f", {persona_name} style"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
visual_identity = platform_persona.get('visual_identity', {})
|
||||
if visual_identity:
|
||||
aesthetic = visual_identity.get('aesthetic_preferences', '')
|
||||
if aesthetic:
|
||||
enhanced_prompt += f", {aesthetic} aesthetic"
|
||||
|
||||
# Layer 3: Channel Optimization
|
||||
channel_enhancements = {
|
||||
'instagram': ', Instagram-optimized composition, vibrant colors, engaging visual',
|
||||
'linkedin': ', professional photography, clean composition, business-focused',
|
||||
'tiktok': ', dynamic composition, eye-catching, vertical format optimized',
|
||||
'facebook': ', social media optimized, engaging, shareable visual',
|
||||
'twitter': ', Twitter card optimized, clear focal point, readable at small size',
|
||||
'pinterest': ', Pinterest-optimized, vertical format, detailed and informative',
|
||||
}
|
||||
|
||||
if channel and channel.lower() in channel_enhancements:
|
||||
enhanced_prompt += channel_enhancements[channel.lower()]
|
||||
|
||||
# Layer 4: Asset Type Specific
|
||||
asset_type_enhancements = {
|
||||
'hero_image': ', hero image style, prominent product placement, professional photography',
|
||||
'product_photo': ', product photography, clean background, detailed product showcase',
|
||||
'lifestyle': ', lifestyle photography, natural setting, authentic scene',
|
||||
'social_post': ', social media post, engaging composition, optimized for engagement',
|
||||
}
|
||||
|
||||
if asset_type in asset_type_enhancements:
|
||||
enhanced_prompt += asset_type_enhancements[asset_type]
|
||||
|
||||
# Layer 5: Competitive Differentiation
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
# Extract unique positioning from competitor analysis
|
||||
enhanced_prompt += ", unique positioning, differentiated visual style"
|
||||
|
||||
# Layer 6: Quality Descriptors
|
||||
enhanced_prompt += ", professional photography, high quality, detailed, sharp focus, natural lighting"
|
||||
|
||||
# Layer 7: Marketing Context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f", {marketing_goal} focused"
|
||||
|
||||
logger.info(f"[Campaign Prompt] Enhanced prompt for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Prompt] Error building prompt: {str(e)}")
|
||||
# Return base prompt with minimal enhancement if error
|
||||
return f"{base_prompt}, professional photography, high quality"
|
||||
|
||||
def build_marketing_copy_prompt(
|
||||
self,
|
||||
base_request: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
content_type: str = "caption",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing copy prompt with persona linguistic fingerprint.
|
||||
|
||||
Args:
|
||||
base_request: Base content request (e.g., "Write Instagram caption for product launch")
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, etc.)
|
||||
content_type: Type of content (caption, cta, email, ad_copy, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with persona style, brand voice, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build enhanced prompt
|
||||
enhanced_prompt = base_request
|
||||
|
||||
# Add persona linguistic fingerprint
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
linguistic_fingerprint = core_persona.get('linguistic_fingerprint', {})
|
||||
|
||||
if persona_name:
|
||||
enhanced_prompt += f"\n\nFollow {persona_name} persona style:"
|
||||
|
||||
if linguistic_fingerprint:
|
||||
sentence_metrics = linguistic_fingerprint.get('sentence_metrics', {})
|
||||
lexical_features = linguistic_fingerprint.get('lexical_features', {})
|
||||
|
||||
if sentence_metrics:
|
||||
avg_length = sentence_metrics.get('average_sentence_length_words', '')
|
||||
if avg_length:
|
||||
enhanced_prompt += f"\n- Average sentence length: {avg_length} words"
|
||||
|
||||
if lexical_features:
|
||||
go_to_words = lexical_features.get('go_to_words', [])
|
||||
avoid_words = lexical_features.get('avoid_words', [])
|
||||
vocabulary_level = lexical_features.get('vocabulary_level', '')
|
||||
|
||||
if go_to_words:
|
||||
enhanced_prompt += f"\n- Use these words: {', '.join(go_to_words[:5])}"
|
||||
if avoid_words:
|
||||
enhanced_prompt += f"\n- Avoid these words: {', '.join(avoid_words[:5])}"
|
||||
if vocabulary_level:
|
||||
enhanced_prompt += f"\n- Vocabulary level: {vocabulary_level}"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
content_format_rules = platform_persona.get('content_format_rules', {})
|
||||
engagement_patterns = platform_persona.get('engagement_patterns', {})
|
||||
|
||||
if content_format_rules:
|
||||
char_limit = content_format_rules.get('character_limit', '')
|
||||
hashtag_strategy = content_format_rules.get('hashtag_strategy', '')
|
||||
|
||||
if char_limit:
|
||||
enhanced_prompt += f"\n- Character limit: {char_limit}"
|
||||
if hashtag_strategy:
|
||||
enhanced_prompt += f"\n- Hashtag strategy: {hashtag_strategy}"
|
||||
|
||||
# Add brand voice
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
enhanced_prompt += f"\n- Brand tone: {tone}, Brand voice: {voice}"
|
||||
|
||||
demographics = target_audience.get('demographics', [])
|
||||
expertise_level = target_audience.get('expertise_level', 'intermediate')
|
||||
if demographics:
|
||||
enhanced_prompt += f"\n- Target audience: {', '.join(demographics[:2])}, {expertise_level} level"
|
||||
|
||||
# Add competitive positioning
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
enhanced_prompt += "\n- Differentiate from competitors, highlight unique value propositions"
|
||||
|
||||
# Add marketing context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f"\n- Marketing goal: {marketing_goal}"
|
||||
|
||||
logger.info(f"[Campaign Copy Prompt] Enhanced for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Copy Prompt] Error building prompt: {str(e)}")
|
||||
return base_request
|
||||
|
||||
def optimize_marketing_prompt(
|
||||
self,
|
||||
prompt_type: str,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Main entry point for marketing prompt optimization.
|
||||
|
||||
Args:
|
||||
prompt_type: Type of prompt (image, copy, video_script, etc.)
|
||||
base_prompt: Base prompt to enhance
|
||||
user_id: User ID for personalization
|
||||
context: Additional context (channel, asset_type, product_context, etc.)
|
||||
|
||||
Returns:
|
||||
Optimized marketing prompt
|
||||
"""
|
||||
context = context or {}
|
||||
channel = context.get('channel')
|
||||
asset_type = context.get('asset_type', 'hero_image')
|
||||
content_type = context.get('content_type', 'caption')
|
||||
product_context = context.get('product_context')
|
||||
|
||||
if prompt_type == 'image':
|
||||
return self.build_marketing_image_prompt(
|
||||
base_prompt, user_id, channel, asset_type, product_context
|
||||
)
|
||||
elif prompt_type in ['copy', 'caption', 'cta', 'email', 'ad_copy']:
|
||||
return self.build_marketing_copy_prompt(
|
||||
base_prompt, user_id, channel, content_type, product_context
|
||||
)
|
||||
else:
|
||||
# Default: minimal enhancement
|
||||
return f"{base_prompt}, professional quality, marketing optimized"
|
||||
@@ -56,11 +56,11 @@ class CreateStudioService:
|
||||
}
|
||||
}
|
||||
|
||||
# Quality-to-provider mapping
|
||||
# Quality-to-provider mapping (OSS-focused defaults)
|
||||
QUALITY_PROVIDERS = {
|
||||
"draft": ["huggingface", "wavespeed:qwen-image"], # Fast, low cost
|
||||
"standard": ["stability:core", "wavespeed:ideogram-v3-turbo"], # Balanced
|
||||
"premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"], # Best quality
|
||||
"draft": ["wavespeed:qwen-image", "huggingface"], # OSS: Qwen Image ($0.03) - Fast, low cost
|
||||
"standard": ["wavespeed:qwen-image", "stability:core"], # OSS: Qwen Image default
|
||||
"premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"], # OSS: Ideogram V3 Turbo ($0.05)
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -30,6 +30,13 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 15,
|
||||
},
|
||||
"flux-kontext-pro": {
|
||||
"name": "FLUX Kontext Pro",
|
||||
"description": "Professional typography and text rendering with improved prompt adherence",
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 20,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,6 +184,55 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
|
||||
|
||||
def _generate_flux_kontext_pro(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using FLUX Kontext Pro.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[FLUX Kontext Pro] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed FLUX Kontext Pro API
|
||||
params = {
|
||||
"model": "flux-kontext-pro",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["flux-kontext-pro"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[FLUX Kontext Pro] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[FLUX Kontext Pro] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"FLUX Kontext Pro generation failed: {str(e)}")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
"""Generate image using WaveSpeed AI models.
|
||||
|
||||
@@ -201,6 +257,8 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
image_bytes = self._generate_ideogram_v3(options)
|
||||
elif model == "qwen-image":
|
||||
image_bytes = self._generate_qwen_image(options)
|
||||
elif model == "flux-kontext-pro":
|
||||
image_bytes = self._generate_flux_kontext_pro(options)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
|
||||
@@ -144,6 +144,9 @@ def generate_audio(
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
logger.info(f"[audio_gen] Filtered kwargs (removed None values): {filtered_kwargs}")
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
audio_bytes = client.generate_speech(
|
||||
text=text,
|
||||
@@ -155,8 +158,9 @@ def generate_audio(
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
**filtered_kwargs
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
|
||||
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes in {response_time:.2f}s")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -228,19 +232,29 @@ def generate_audio(
|
||||
# Create usage log
|
||||
# Store the text parameter in a local variable before any imports to prevent shadowing
|
||||
text_param = text # Capture function parameter before any potential shadowing
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Google, OpenAI, etc.)
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.AUDIO,
|
||||
model_name="minimax/speech-02-hd",
|
||||
endpoint="/audio-generation/wavespeed"
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
endpoint="/audio-generation/wavespeed",
|
||||
method="POST",
|
||||
model_used="minimax/speech-02-hd",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, etc.)
|
||||
tokens_input=character_count,
|
||||
tokens_output=0,
|
||||
tokens_total=character_count,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=len(text_param.encode("utf-8")), # Use captured parameter
|
||||
response_size=len(audio_bytes),
|
||||
|
||||
@@ -138,7 +138,8 @@ def _track_image_operation_usage(
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
@@ -165,6 +166,7 @@ def _track_image_operation_usage(
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
@@ -215,6 +217,13 @@ def _track_image_operation_usage(
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
@@ -223,13 +232,14 @@ def _track_image_operation_usage(
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
@@ -327,21 +337,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
|
||||
# Normalize obvious model/provider mismatches
|
||||
model_lower = (image_options.model or "").lower()
|
||||
|
||||
# Detect Wavespeed models and remap provider if needed
|
||||
wavespeed_models = ["qwen-image", "ideogram-v3-turbo", "flux-kontext-pro"]
|
||||
if model_lower in wavespeed_models and provider_name != "wavespeed":
|
||||
logger.info("Remapping provider to wavespeed for model=%s", image_options.model)
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# Detect HuggingFace models and remap provider if needed
|
||||
if provider_name == "stability" and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Remapping provider to huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
# Detect HuggingFace models when provider is not explicitly set
|
||||
if not opts.get("provider") and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Auto-detecting provider as huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
if provider_name == "huggingface" and not image_options.model:
|
||||
# Provide a sensible default HF model if none specified
|
||||
image_options.model = "black-forest-labs/FLUX.1-Krea-dev"
|
||||
|
||||
if provider_name == "wavespeed" and not image_options.model:
|
||||
# Provide a sensible default WaveSpeed model if none specified
|
||||
image_options.model = "ideogram-v3-turbo"
|
||||
# Default to cost-effective model: Qwen Image ($0.05/image, optimized for blog images)
|
||||
image_options.model = "qwen-image"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
start_time = time.time()
|
||||
result = provider.generate(image_options)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and result and result.image_bytes:
|
||||
@@ -352,12 +380,14 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
# Fallback: estimate based on provider/model (OSS-focused pricing)
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
estimated_cost = 0.05 # Qwen Image: $0.05/image
|
||||
elif result.model and "ideogram" in result.model.lower():
|
||||
estimated_cost = 0.10 # Ideogram V3 Turbo: $0.10/image
|
||||
else:
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
estimated_cost = 0.05 # Default to Qwen Image pricing
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
@@ -374,7 +404,8 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Generation]"
|
||||
log_prefix="[Image Generation]",
|
||||
response_time=response_time
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
@@ -27,6 +27,7 @@ except ImportError:
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
@@ -508,6 +509,11 @@ async def ai_video_generate(
|
||||
|
||||
# Generate video based on operation type
|
||||
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
|
||||
|
||||
# Track response time for video generation
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
@@ -620,6 +626,7 @@ async def ai_video_generate(
|
||||
|
||||
# Track usage (same pattern as text generation)
|
||||
# Use cost from result_dict if available, otherwise calculate
|
||||
response_time = time.time() - start_time
|
||||
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
@@ -628,6 +635,7 @@ async def ai_video_generate(
|
||||
prompt=result_dict.get("prompt", prompt or ""),
|
||||
video_bytes=video_bytes,
|
||||
cost_override=cost_override,
|
||||
response_time=response_time,
|
||||
)
|
||||
|
||||
# Progress callback: Complete
|
||||
@@ -662,6 +670,7 @@ def track_video_usage(
|
||||
prompt: str,
|
||||
video_bytes: bytes,
|
||||
cost_override: Optional[float] = None,
|
||||
response_time: float = 0.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Track subscription usage for any video generation (text-to-video or image-to-video).
|
||||
@@ -732,19 +741,27 @@ def track_video_usage(
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Detect actual provider name (WaveSpeed, HuggingFace, Google, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.VIDEO,
|
||||
model_name=model_name,
|
||||
endpoint=f"/video-generation/{provider}"
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.VIDEO,
|
||||
endpoint=f"/video-generation/{provider}",
|
||||
method="POST",
|
||||
model_used=model_name,
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, HuggingFace, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost_per_video,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=len((prompt or "").encode("utf-8")),
|
||||
response_size=len(video_bytes),
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
"""Product Marketing Suite service package."""
|
||||
"""Product Marketing Suite service package - Product asset creation only."""
|
||||
|
||||
from .orchestrator import ProductMarketingOrchestrator
|
||||
from .brand_dna_sync import BrandDNASyncService
|
||||
from .prompt_builder import ProductMarketingPromptBuilder
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from .campaign_storage import CampaignStorageService
|
||||
from .product_image_service import ProductImageService
|
||||
from .product_animation_service import ProductAnimationService, ProductAnimationRequest
|
||||
from .product_video_service import ProductVideoService, ProductVideoRequest
|
||||
from .product_avatar_service import ProductAvatarService, ProductAvatarRequest
|
||||
from .intelligent_prompt_builder import IntelligentPromptBuilder
|
||||
from .personalization_service import PersonalizationService
|
||||
|
||||
__all__ = [
|
||||
"ProductMarketingOrchestrator",
|
||||
"BrandDNASyncService",
|
||||
"ProductMarketingPromptBuilder",
|
||||
"AssetAuditService",
|
||||
"ChannelPackService",
|
||||
"CampaignStorageService",
|
||||
"ProductImageService",
|
||||
"ProductAnimationService",
|
||||
"ProductAnimationRequest",
|
||||
@@ -25,5 +17,7 @@ __all__ = [
|
||||
"ProductVideoRequest",
|
||||
"ProductAvatarService",
|
||||
"ProductAvatarRequest",
|
||||
"IntelligentPromptBuilder",
|
||||
"PersonalizationService",
|
||||
]
|
||||
|
||||
|
||||
454
backend/services/product_marketing/intelligent_prompt_builder.py
Normal file
454
backend/services/product_marketing/intelligent_prompt_builder.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""
|
||||
Intelligent Prompt Builder
|
||||
Infers complete requirements from minimal user input using onboarding data.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from services.database import SessionLocal
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from .product_marketing_templates import (
|
||||
ProductMarketingTemplates,
|
||||
TemplateCategory,
|
||||
ProductImageTemplate,
|
||||
ProductVideoTemplate,
|
||||
ProductAvatarTemplate,
|
||||
)
|
||||
|
||||
|
||||
class IntelligentPromptBuilder:
|
||||
"""
|
||||
Intelligent prompt builder that infers requirements from minimal user input.
|
||||
|
||||
Example:
|
||||
Input: "iPhone case for my store"
|
||||
Output: Complete configuration with all fields pre-filled
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Intelligent Prompt Builder."""
|
||||
self.logger = logger
|
||||
logger.info("[Intelligent Prompt Builder] Initialized")
|
||||
|
||||
def infer_requirements(
|
||||
self,
|
||||
user_input: str,
|
||||
user_id: str,
|
||||
asset_type: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Infer complete requirements from minimal user input.
|
||||
|
||||
Args:
|
||||
user_input: Minimal user input (e.g., "iPhone case for my store")
|
||||
user_id: User ID to fetch onboarding data
|
||||
asset_type: Optional asset type hint (image, video, animation, avatar)
|
||||
|
||||
Returns:
|
||||
Complete configuration dictionary with all fields pre-filled
|
||||
"""
|
||||
try:
|
||||
# 1. Parse user input
|
||||
parsed_input = self._parse_user_input(user_input, asset_type)
|
||||
|
||||
# 2. Get onboarding data
|
||||
onboarding_data = self._get_onboarding_data(user_id)
|
||||
|
||||
# 3. Infer requirements from context
|
||||
requirements = self._infer_from_context(parsed_input, onboarding_data, asset_type)
|
||||
|
||||
# 4. Match template
|
||||
template = self._match_template(requirements, asset_type)
|
||||
|
||||
# 5. Generate smart defaults
|
||||
defaults = self._generate_defaults(requirements, template, onboarding_data)
|
||||
|
||||
logger.info(f"[Intelligent Prompt Builder] Inferred requirements: {defaults.get('product_name', 'Unknown')}")
|
||||
return defaults
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Intelligent Prompt Builder] Error inferring requirements: {str(e)}", exc_info=True)
|
||||
# Return basic defaults on error
|
||||
return self._get_basic_defaults(user_input, asset_type)
|
||||
|
||||
def _parse_user_input(
|
||||
self,
|
||||
user_input: str,
|
||||
asset_type: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse minimal user input to extract entities.
|
||||
|
||||
Uses LLM with few-shot prompting to extract:
|
||||
- Product name
|
||||
- Product type
|
||||
- Use case (e-commerce, marketing, social media, etc.)
|
||||
- Platform hints (store, Instagram, Shopify, Amazon, etc.)
|
||||
- Style preferences
|
||||
"""
|
||||
try:
|
||||
# Build system prompt for entity extraction
|
||||
system_prompt = """You are an expert at parsing product marketing requests.
|
||||
Extract key information from user input and return structured JSON.
|
||||
|
||||
Extract:
|
||||
- product_name: The product name or description
|
||||
- product_type: Type of product (phone_case, clothing, electronics, food, etc.)
|
||||
- use_case: Primary use case (ecommerce, social_media, marketing_campaign, documentation, etc.)
|
||||
- platform_hints: Platforms mentioned (shopify, amazon, instagram, facebook, etc.)
|
||||
- style_hints: Style preferences mentioned (professional, casual, luxury, minimalist, etc.)
|
||||
- asset_type_hint: Type of asset needed (image, video, animation, avatar) if mentioned
|
||||
|
||||
Return JSON only, no explanations."""
|
||||
|
||||
# Few-shot examples
|
||||
examples = """
|
||||
Examples:
|
||||
Input: "iPhone case for my store"
|
||||
Output: {"product_name": "iPhone case", "product_type": "phone_case", "use_case": "ecommerce", "platform_hints": ["shopify"], "style_hints": [], "asset_type_hint": "image"}
|
||||
|
||||
Input: "Create a video for my new product launch on Instagram"
|
||||
Output: {"product_name": "new product", "product_type": "unknown", "use_case": "social_media", "platform_hints": ["instagram"], "style_hints": [], "asset_type_hint": "video"}
|
||||
|
||||
Input: "Luxury watch photoshoot"
|
||||
Output: {"product_name": "luxury watch", "product_type": "watch", "use_case": "marketing_campaign", "platform_hints": [], "style_hints": ["luxury"], "asset_type_hint": "image"}
|
||||
"""
|
||||
|
||||
prompt = f"{examples}\n\nInput: {user_input}\nOutput:"
|
||||
|
||||
# Call LLM for parsing
|
||||
json_struct = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"product_name": {"type": "string"},
|
||||
"product_type": {"type": "string"},
|
||||
"use_case": {"type": "string"},
|
||||
"platform_hints": {"type": "array", "items": {"type": "string"}},
|
||||
"style_hints": {"type": "array", "items": {"type": "string"}},
|
||||
"asset_type_hint": {"type": "string"}
|
||||
},
|
||||
"required": ["product_name", "use_case"]
|
||||
}
|
||||
|
||||
# Call LLM synchronously (llm_text_gen is synchronous)
|
||||
result_text = llm_text_gen(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
json_struct=json_struct,
|
||||
user_id=None # No user_id needed for parsing
|
||||
)
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
parsed = json.loads(result_text) if isinstance(result_text, str) else result_text
|
||||
except json.JSONDecodeError:
|
||||
# Fallback: try to extract JSON from text
|
||||
import re
|
||||
json_match = re.search(r'\{[^}]+\}', result_text)
|
||||
if json_match:
|
||||
parsed = json.loads(json_match.group())
|
||||
else:
|
||||
# Ultimate fallback: basic extraction
|
||||
parsed = {
|
||||
"product_name": user_input,
|
||||
"product_type": "unknown",
|
||||
"use_case": "marketing_campaign",
|
||||
"platform_hints": [],
|
||||
"style_hints": [],
|
||||
"asset_type_hint": asset_type or "image"
|
||||
}
|
||||
|
||||
# Override asset_type_hint if provided
|
||||
if asset_type:
|
||||
parsed["asset_type_hint"] = asset_type
|
||||
|
||||
logger.info(f"[Intelligent Prompt Builder] Parsed input: {parsed}")
|
||||
return parsed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Intelligent Prompt Builder] Error parsing input: {str(e)}")
|
||||
# Fallback: basic extraction
|
||||
return {
|
||||
"product_name": user_input,
|
||||
"product_type": "unknown",
|
||||
"use_case": "marketing_campaign",
|
||||
"platform_hints": [],
|
||||
"style_hints": [],
|
||||
"asset_type_hint": asset_type or "image"
|
||||
}
|
||||
|
||||
def _get_onboarding_data(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all onboarding data for user.
|
||||
|
||||
Returns:
|
||||
Dictionary with website_analysis, persona_data, competitor_analyses
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
|
||||
return {
|
||||
"website_analysis": website_analysis or {},
|
||||
"persona_data": persona_data or {},
|
||||
"competitor_analyses": competitor_analyses or [],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Intelligent Prompt Builder] Error getting onboarding data: {str(e)}")
|
||||
return {
|
||||
"website_analysis": {},
|
||||
"persona_data": {},
|
||||
"competitor_analyses": [],
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _infer_from_context(
|
||||
self,
|
||||
parsed_input: Dict[str, Any],
|
||||
onboarding_data: Dict[str, Any],
|
||||
asset_type: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Infer requirements from parsed input and onboarding context.
|
||||
|
||||
Uses onboarding data to fill in missing information:
|
||||
- Platform from onboarding (if user has e-commerce setup)
|
||||
- Style from brand DNA
|
||||
- Target audience from onboarding
|
||||
"""
|
||||
requirements = parsed_input.copy()
|
||||
|
||||
website_analysis = onboarding_data.get("website_analysis", {})
|
||||
persona_data = onboarding_data.get("persona_data", {})
|
||||
|
||||
# Infer platform from onboarding
|
||||
if not requirements.get("platform_hints"):
|
||||
# Check if user has e-commerce setup (from website analysis)
|
||||
brand_analysis = website_analysis.get("brand_analysis", {})
|
||||
# Try to infer platform from website URL or other hints
|
||||
# For now, default to e-commerce if no hints
|
||||
if requirements.get("use_case") == "ecommerce":
|
||||
requirements["platform_hints"] = ["shopify"] # Default e-commerce platform
|
||||
|
||||
# Infer style from brand DNA
|
||||
if not requirements.get("style_hints"):
|
||||
if brand_analysis:
|
||||
style_guidelines = brand_analysis.get("style_guidelines", {})
|
||||
aesthetic = style_guidelines.get("aesthetic", "")
|
||||
if aesthetic:
|
||||
requirements["style_hints"] = [aesthetic.lower()]
|
||||
|
||||
# Infer target audience from onboarding
|
||||
target_audience = website_analysis.get("target_audience", {})
|
||||
if target_audience:
|
||||
requirements["target_audience"] = target_audience
|
||||
|
||||
# Infer brand colors
|
||||
if brand_analysis:
|
||||
color_palette = brand_analysis.get("color_palette", [])
|
||||
if color_palette:
|
||||
requirements["brand_colors"] = color_palette[:5] # Top 5 colors
|
||||
|
||||
# Infer writing style
|
||||
writing_style = website_analysis.get("writing_style", {})
|
||||
if writing_style:
|
||||
requirements["tone"] = writing_style.get("tone", "professional")
|
||||
requirements["voice"] = writing_style.get("voice", "authoritative")
|
||||
|
||||
return requirements
|
||||
|
||||
def _match_template(
|
||||
self,
|
||||
requirements: Dict[str, Any],
|
||||
asset_type: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Match requirements to appropriate template.
|
||||
|
||||
Returns:
|
||||
Template dictionary or None
|
||||
"""
|
||||
asset_type_hint = asset_type or requirements.get("asset_type_hint", "image")
|
||||
use_case = requirements.get("use_case", "marketing_campaign")
|
||||
style_hints = requirements.get("style_hints", [])
|
||||
|
||||
if asset_type_hint == "image":
|
||||
templates = ProductMarketingTemplates.get_product_image_templates()
|
||||
|
||||
# Match by use case
|
||||
if use_case == "ecommerce":
|
||||
# Match e-commerce template
|
||||
for template in templates:
|
||||
if "ecommerce" in template.id.lower() or "e-commerce" in template.name.lower():
|
||||
return {
|
||||
"id": template.id,
|
||||
"name": template.name,
|
||||
"category": template.category.value,
|
||||
"environment": template.environment,
|
||||
"background_style": template.background_style,
|
||||
"lighting": template.lighting,
|
||||
"style": template.style,
|
||||
"angle": template.angle,
|
||||
"recommended_resolution": template.recommended_resolution,
|
||||
}
|
||||
|
||||
# Match by style
|
||||
if style_hints:
|
||||
style_lower = style_hints[0].lower()
|
||||
for template in templates:
|
||||
if style_lower in template.style.lower() or style_lower in template.name.lower():
|
||||
return {
|
||||
"id": template.id,
|
||||
"name": template.name,
|
||||
"category": template.category.value,
|
||||
"environment": template.environment,
|
||||
"background_style": template.background_style,
|
||||
"lighting": template.lighting,
|
||||
"style": template.style,
|
||||
"angle": template.angle,
|
||||
"recommended_resolution": template.recommended_resolution,
|
||||
}
|
||||
|
||||
# Default: e-commerce product shot
|
||||
default_template = templates[0] # ecommerce_product_shot
|
||||
return {
|
||||
"id": default_template.id,
|
||||
"name": default_template.name,
|
||||
"category": default_template.category.value,
|
||||
"environment": default_template.environment,
|
||||
"background_style": default_template.background_style,
|
||||
"lighting": default_template.lighting,
|
||||
"style": default_template.style,
|
||||
"angle": default_template.angle,
|
||||
"recommended_resolution": default_template.recommended_resolution,
|
||||
}
|
||||
|
||||
elif asset_type_hint == "video":
|
||||
templates = ProductMarketingTemplates.get_product_video_templates()
|
||||
# Default: product demo video
|
||||
default_template = templates[0]
|
||||
return {
|
||||
"id": default_template.id,
|
||||
"name": default_template.name,
|
||||
"category": default_template.category.value,
|
||||
"video_type": default_template.video_type,
|
||||
"resolution": default_template.resolution,
|
||||
"duration": default_template.duration,
|
||||
}
|
||||
|
||||
elif asset_type_hint == "avatar":
|
||||
templates = ProductMarketingTemplates.get_product_avatar_templates()
|
||||
# Default: product overview
|
||||
default_template = templates[0]
|
||||
return {
|
||||
"id": default_template.id,
|
||||
"name": default_template.name,
|
||||
"category": default_template.category.value,
|
||||
"explainer_type": default_template.explainer_type,
|
||||
"resolution": default_template.resolution,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _generate_defaults(
|
||||
self,
|
||||
requirements: Dict[str, Any],
|
||||
template: Optional[Dict[str, Any]],
|
||||
onboarding_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate complete configuration with smart defaults.
|
||||
|
||||
Combines:
|
||||
- Parsed requirements
|
||||
- Matched template
|
||||
- Onboarding data
|
||||
"""
|
||||
defaults = {}
|
||||
|
||||
# Product information
|
||||
defaults["product_name"] = requirements.get("product_name", "Product")
|
||||
defaults["product_description"] = requirements.get("product_description", f"Professional {requirements.get('product_name', 'product')}")
|
||||
|
||||
# Asset type
|
||||
asset_type = requirements.get("asset_type_hint", "image")
|
||||
defaults["asset_type"] = asset_type
|
||||
|
||||
# Template information
|
||||
if template:
|
||||
defaults["template_id"] = template.get("id")
|
||||
defaults["template_name"] = template.get("name")
|
||||
|
||||
# Image-specific defaults
|
||||
if asset_type == "image" and template:
|
||||
defaults["environment"] = template.get("environment", "studio")
|
||||
defaults["background_style"] = template.get("background_style", "white")
|
||||
defaults["lighting"] = template.get("lighting", "studio")
|
||||
defaults["style"] = template.get("style", "photorealistic")
|
||||
defaults["angle"] = template.get("angle", "front")
|
||||
defaults["resolution"] = template.get("recommended_resolution", "1024x1024")
|
||||
defaults["num_variations"] = 1
|
||||
|
||||
# Override with style hints if available
|
||||
if requirements.get("style_hints"):
|
||||
style_hint = requirements["style_hints"][0].lower()
|
||||
if "luxury" in style_hint:
|
||||
defaults["style"] = "luxury"
|
||||
defaults["lighting"] = "dramatic"
|
||||
elif "minimalist" in style_hint:
|
||||
defaults["style"] = "minimalist"
|
||||
defaults["background_style"] = "white"
|
||||
elif "lifestyle" in style_hint:
|
||||
defaults["environment"] = "lifestyle"
|
||||
defaults["background_style"] = "lifestyle"
|
||||
|
||||
# Video-specific defaults
|
||||
elif asset_type == "video" and template:
|
||||
defaults["video_type"] = template.get("video_type", "demo")
|
||||
defaults["resolution"] = template.get("resolution", "720p")
|
||||
defaults["duration"] = template.get("duration", 10)
|
||||
|
||||
# Avatar-specific defaults
|
||||
elif asset_type == "avatar" and template:
|
||||
defaults["explainer_type"] = template.get("explainer_type", "product_overview")
|
||||
defaults["resolution"] = template.get("resolution", "720p")
|
||||
|
||||
# Brand colors from onboarding
|
||||
if requirements.get("brand_colors"):
|
||||
defaults["brand_colors"] = requirements["brand_colors"]
|
||||
|
||||
# Additional context
|
||||
defaults["additional_context"] = requirements.get("additional_context", "")
|
||||
|
||||
# Confidence score (how well we matched)
|
||||
defaults["confidence"] = 0.8 if template else 0.5
|
||||
defaults["inferred_fields"] = list(defaults.keys())
|
||||
|
||||
return defaults
|
||||
|
||||
def _get_basic_defaults(
|
||||
self,
|
||||
user_input: str,
|
||||
asset_type: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get basic defaults when parsing fails."""
|
||||
return {
|
||||
"product_name": user_input,
|
||||
"product_description": f"Professional {user_input}",
|
||||
"asset_type": asset_type or "image",
|
||||
"environment": "studio",
|
||||
"background_style": "white",
|
||||
"lighting": "studio",
|
||||
"style": "photorealistic",
|
||||
"resolution": "1024x1024",
|
||||
"num_variations": 1,
|
||||
"confidence": 0.3,
|
||||
"inferred_fields": ["product_name", "product_description"],
|
||||
}
|
||||
413
backend/services/product_marketing/personalization_service.py
Normal file
413
backend/services/product_marketing/personalization_service.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Personalization Service
|
||||
Extracts ALL onboarding data and provides personalized defaults for forms and recommendations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class PersonalizationService:
|
||||
"""
|
||||
Service for extracting user preferences from onboarding data
|
||||
and providing personalized defaults and recommendations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Personalization Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Personalization Service] Initialized")
|
||||
|
||||
def get_user_preferences(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive user preferences from ALL onboarding data.
|
||||
|
||||
Returns:
|
||||
Dictionary with personalized preferences:
|
||||
- industry: User's industry
|
||||
- target_audience: Demographics, expertise level
|
||||
- platform_preferences: Preferred platforms from persona data
|
||||
- content_preferences: Preferred content types
|
||||
- style_preferences: Visual style, tone, voice
|
||||
- brand_colors: Brand color palette
|
||||
- templates: Recommended templates for user's industry
|
||||
- channels: Recommended channels based on platform personas
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
|
||||
preferences = {
|
||||
"industry": None,
|
||||
"target_audience": {},
|
||||
"platform_preferences": [],
|
||||
"content_preferences": [],
|
||||
"style_preferences": {},
|
||||
"brand_colors": [],
|
||||
"recommended_templates": [],
|
||||
"recommended_channels": [],
|
||||
"writing_style": {},
|
||||
"brand_values": [],
|
||||
}
|
||||
|
||||
# Extract from website_analysis
|
||||
if website_analysis:
|
||||
# Industry
|
||||
target_audience = website_analysis.get("target_audience", {})
|
||||
preferences["industry"] = target_audience.get("industry_focus")
|
||||
|
||||
# Target audience
|
||||
preferences["target_audience"] = {
|
||||
"demographics": target_audience.get("demographics", []),
|
||||
"expertise_level": target_audience.get("expertise_level", "intermediate"),
|
||||
"industry_focus": target_audience.get("industry_focus"),
|
||||
}
|
||||
|
||||
# Writing style
|
||||
writing_style = website_analysis.get("writing_style", {})
|
||||
preferences["writing_style"] = {
|
||||
"tone": writing_style.get("tone", "professional"),
|
||||
"voice": writing_style.get("voice", "authoritative"),
|
||||
"complexity": writing_style.get("complexity", "intermediate"),
|
||||
"engagement_level": writing_style.get("engagement_level", "moderate"),
|
||||
}
|
||||
|
||||
# Brand colors
|
||||
brand_analysis = website_analysis.get("brand_analysis", {})
|
||||
if brand_analysis:
|
||||
preferences["brand_colors"] = brand_analysis.get("color_palette", [])
|
||||
preferences["brand_values"] = brand_analysis.get("brand_values", [])
|
||||
|
||||
# Style preferences
|
||||
style_guidelines = website_analysis.get("style_guidelines", {})
|
||||
if style_guidelines:
|
||||
preferences["style_preferences"] = {
|
||||
"aesthetic": style_guidelines.get("aesthetic", "modern"),
|
||||
"visual_style": style_guidelines.get("visual_style", "clean"),
|
||||
}
|
||||
|
||||
# Extract from persona_data
|
||||
if persona_data:
|
||||
core_persona = persona_data.get("corePersona", {})
|
||||
platform_personas = persona_data.get("platformPersonas", {})
|
||||
selected_platforms = persona_data.get("selectedPlatforms", [])
|
||||
|
||||
# Platform preferences from selected platforms
|
||||
if selected_platforms:
|
||||
preferences["platform_preferences"] = selected_platforms
|
||||
elif platform_personas:
|
||||
# Extract platforms from platform personas
|
||||
preferences["platform_preferences"] = list(platform_personas.keys())
|
||||
|
||||
# Recommended channels based on platform personas
|
||||
if platform_personas:
|
||||
# Prioritize platforms with active personas
|
||||
preferences["recommended_channels"] = list(platform_personas.keys())[:5] # Top 5
|
||||
|
||||
# Content preferences from persona
|
||||
if core_persona:
|
||||
content_format_rules = core_persona.get("content_format_rules", {})
|
||||
if content_format_rules:
|
||||
preferred_formats = content_format_rules.get("preferred_formats", [])
|
||||
preferences["content_preferences"] = preferred_formats
|
||||
|
||||
# Infer content preferences from industry
|
||||
if preferences["industry"]:
|
||||
industry_content_map = {
|
||||
"ecommerce": ["product_images", "product_videos", "lifestyle_content"],
|
||||
"saas": ["feature_highlights", "tutorials", "demo_videos"],
|
||||
"education": ["tutorials", "educational_content", "explainer_videos"],
|
||||
"healthcare": ["informational_content", "patient_stories", "educational_videos"],
|
||||
"finance": ["informational_content", "trust_building", "expert_content"],
|
||||
"fashion": ["lifestyle_images", "fashion_shows", "style_guides"],
|
||||
"food": ["food_photography", "recipe_videos", "lifestyle_content"],
|
||||
}
|
||||
industry_lower = preferences["industry"].lower()
|
||||
for key, content_types in industry_content_map.items():
|
||||
if key in industry_lower:
|
||||
preferences["content_preferences"] = content_types
|
||||
break
|
||||
|
||||
# Recommend templates based on industry
|
||||
preferences["recommended_templates"] = self._get_recommended_templates(
|
||||
preferences.get("industry"),
|
||||
preferences.get("style_preferences", {}).get("aesthetic")
|
||||
)
|
||||
|
||||
# Recommend channels if not already set
|
||||
if not preferences["recommended_channels"]:
|
||||
preferences["recommended_channels"] = self._get_recommended_channels(
|
||||
preferences.get("industry"),
|
||||
preferences.get("target_audience", {}).get("demographics", [])
|
||||
)
|
||||
|
||||
logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}")
|
||||
return preferences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Personalization] Error getting user preferences: {str(e)}", exc_info=True)
|
||||
return self._get_default_preferences()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_personalized_defaults(
|
||||
self,
|
||||
user_id: str,
|
||||
form_type: str = "product_photoshoot"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get personalized defaults for a specific form.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
form_type: Type of form (product_photoshoot, campaign_creator, product_video, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary with pre-filled form values
|
||||
"""
|
||||
preferences = self.get_user_preferences(user_id)
|
||||
defaults = {}
|
||||
|
||||
if form_type == "product_photoshoot":
|
||||
defaults = {
|
||||
"environment": self._infer_environment(preferences),
|
||||
"background_style": self._infer_background_style(preferences),
|
||||
"lighting": self._infer_lighting(preferences),
|
||||
"style": self._infer_style(preferences),
|
||||
"resolution": "1024x1024",
|
||||
"num_variations": 1,
|
||||
"brand_colors": preferences.get("brand_colors", []),
|
||||
}
|
||||
|
||||
elif form_type == "campaign_creator":
|
||||
defaults = {
|
||||
"channels": preferences.get("recommended_channels", ["instagram", "linkedin"]),
|
||||
"goal": self._infer_campaign_goal(preferences),
|
||||
}
|
||||
|
||||
elif form_type == "product_video":
|
||||
defaults = {
|
||||
"video_type": self._infer_video_type(preferences),
|
||||
"resolution": "720p",
|
||||
"duration": 10,
|
||||
}
|
||||
|
||||
elif form_type == "product_avatar":
|
||||
defaults = {
|
||||
"explainer_type": self._infer_explainer_type(preferences),
|
||||
"resolution": "720p",
|
||||
}
|
||||
|
||||
return defaults
|
||||
|
||||
def get_recommendations(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get personalized recommendations for user.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- recommended_templates: Templates matching user's industry
|
||||
- recommended_channels: Channels matching user's platform personas
|
||||
- recommended_asset_types: Asset types matching user's content preferences
|
||||
"""
|
||||
preferences = self.get_user_preferences(user_id)
|
||||
|
||||
return {
|
||||
"templates": preferences.get("recommended_templates", []),
|
||||
"channels": preferences.get("recommended_channels", []),
|
||||
"asset_types": preferences.get("content_preferences", []),
|
||||
"industry": preferences.get("industry"),
|
||||
"reasoning": self._generate_recommendation_reasoning(preferences),
|
||||
}
|
||||
|
||||
def _get_recommended_templates(
|
||||
self,
|
||||
industry: Optional[str],
|
||||
aesthetic: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""Get recommended template IDs based on industry and aesthetic."""
|
||||
templates = []
|
||||
|
||||
if not industry:
|
||||
return ["ecommerce_product_shot", "lifestyle_product"]
|
||||
|
||||
industry_lower = industry.lower() if industry else ""
|
||||
|
||||
# Industry-based template recommendations
|
||||
if "ecommerce" in industry_lower or "retail" in industry_lower:
|
||||
templates.extend(["ecommerce_product_shot", "lifestyle_product"])
|
||||
elif "saas" in industry_lower or "tech" in industry_lower:
|
||||
templates.extend(["technical_product_detail", "lifestyle_product"])
|
||||
elif "luxury" in industry_lower or "premium" in industry_lower:
|
||||
templates.extend(["luxury_product_showcase", "lifestyle_product"])
|
||||
else:
|
||||
templates.extend(["ecommerce_product_shot", "lifestyle_product"])
|
||||
|
||||
# Aesthetic-based adjustments
|
||||
if aesthetic:
|
||||
aesthetic_lower = aesthetic.lower()
|
||||
if "luxury" in aesthetic_lower or "premium" in aesthetic_lower:
|
||||
templates.insert(0, "luxury_product_showcase")
|
||||
elif "minimalist" in aesthetic_lower or "clean" in aesthetic_lower:
|
||||
templates.insert(0, "ecommerce_product_shot")
|
||||
|
||||
return templates[:3] # Return top 3
|
||||
|
||||
def _get_recommended_channels(
|
||||
self,
|
||||
industry: Optional[str],
|
||||
demographics: List[str]
|
||||
) -> List[str]:
|
||||
"""Get recommended channels based on industry and demographics."""
|
||||
channels = []
|
||||
|
||||
if not industry:
|
||||
return ["instagram", "linkedin"]
|
||||
|
||||
industry_lower = industry.lower() if industry else ""
|
||||
|
||||
# Industry-based channel recommendations
|
||||
if "b2b" in industry_lower or "saas" in industry_lower or "enterprise" in industry_lower:
|
||||
channels.extend(["linkedin", "twitter", "youtube"])
|
||||
elif "b2c" in industry_lower or "ecommerce" in industry_lower or "retail" in industry_lower:
|
||||
channels.extend(["instagram", "facebook", "pinterest", "tiktok"])
|
||||
elif "fashion" in industry_lower or "lifestyle" in industry_lower:
|
||||
channels.extend(["instagram", "pinterest", "tiktok"])
|
||||
elif "education" in industry_lower:
|
||||
channels.extend(["youtube", "linkedin", "facebook"])
|
||||
else:
|
||||
channels.extend(["instagram", "linkedin", "facebook"])
|
||||
|
||||
# Demographics-based adjustments
|
||||
if demographics:
|
||||
demographics_str = " ".join(demographics).lower()
|
||||
if "young" in demographics_str or "millennial" in demographics_str or "gen z" in demographics_str:
|
||||
if "tiktok" not in channels:
|
||||
channels.insert(0, "tiktok")
|
||||
if "professional" in demographics_str or "business" in demographics_str:
|
||||
if "linkedin" not in channels:
|
||||
channels.insert(0, "linkedin")
|
||||
|
||||
return channels[:5] # Return top 5
|
||||
|
||||
def _infer_environment(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer environment setting from preferences."""
|
||||
industry = preferences.get("industry", "").lower() if preferences.get("industry") else ""
|
||||
aesthetic = preferences.get("style_preferences", {}).get("aesthetic", "").lower()
|
||||
|
||||
if "luxury" in aesthetic or "premium" in industry:
|
||||
return "studio"
|
||||
elif "ecommerce" in industry or "retail" in industry:
|
||||
return "studio"
|
||||
elif "lifestyle" in aesthetic:
|
||||
return "lifestyle"
|
||||
else:
|
||||
return "studio"
|
||||
|
||||
def _infer_background_style(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer background style from preferences."""
|
||||
industry = preferences.get("industry", "").lower() if preferences.get("industry") else ""
|
||||
aesthetic = preferences.get("style_preferences", {}).get("aesthetic", "").lower()
|
||||
|
||||
if "ecommerce" in industry or "retail" in industry:
|
||||
return "white"
|
||||
elif "luxury" in aesthetic:
|
||||
return "minimalist"
|
||||
elif "lifestyle" in aesthetic:
|
||||
return "lifestyle"
|
||||
else:
|
||||
return "white"
|
||||
|
||||
def _infer_lighting(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer lighting style from preferences."""
|
||||
aesthetic = preferences.get("style_preferences", {}).get("aesthetic", "").lower()
|
||||
|
||||
if "luxury" in aesthetic or "dramatic" in aesthetic:
|
||||
return "dramatic"
|
||||
elif "natural" in aesthetic:
|
||||
return "natural"
|
||||
else:
|
||||
return "studio"
|
||||
|
||||
def _infer_style(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer image style from preferences."""
|
||||
aesthetic = preferences.get("style_preferences", {}).get("aesthetic", "").lower()
|
||||
industry = preferences.get("industry", "").lower() if preferences.get("industry") else ""
|
||||
|
||||
if "luxury" in aesthetic or "premium" in industry:
|
||||
return "luxury"
|
||||
elif "minimalist" in aesthetic:
|
||||
return "minimalist"
|
||||
elif "technical" in industry or "saas" in industry:
|
||||
return "technical"
|
||||
else:
|
||||
return "photorealistic"
|
||||
|
||||
def _infer_campaign_goal(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer campaign goal from preferences."""
|
||||
industry = preferences.get("industry", "").lower() if preferences.get("industry") else ""
|
||||
|
||||
if "saas" in industry or "tech" in industry:
|
||||
return "conversion"
|
||||
elif "ecommerce" in industry or "retail" in industry:
|
||||
return "conversion"
|
||||
else:
|
||||
return "awareness"
|
||||
|
||||
def _infer_video_type(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer video type from preferences."""
|
||||
content_prefs = preferences.get("content_preferences", [])
|
||||
|
||||
if "demo" in str(content_prefs).lower():
|
||||
return "demo"
|
||||
elif "tutorial" in str(content_prefs).lower():
|
||||
return "feature_highlight"
|
||||
else:
|
||||
return "demo"
|
||||
|
||||
def _infer_explainer_type(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Infer explainer type from preferences."""
|
||||
content_prefs = preferences.get("content_preferences", [])
|
||||
|
||||
if "tutorial" in str(content_prefs).lower():
|
||||
return "tutorial"
|
||||
elif "feature" in str(content_prefs).lower():
|
||||
return "feature_explainer"
|
||||
else:
|
||||
return "product_overview"
|
||||
|
||||
def _generate_recommendation_reasoning(self, preferences: Dict[str, Any]) -> str:
|
||||
"""Generate human-readable reasoning for recommendations."""
|
||||
industry = preferences.get("industry", "your industry")
|
||||
channels = preferences.get("recommended_channels", [])
|
||||
|
||||
reasoning = f"Based on your {industry} industry"
|
||||
if channels:
|
||||
reasoning += f" and platform preferences, we recommend focusing on {', '.join(channels[:3])}"
|
||||
reasoning += "."
|
||||
|
||||
return reasoning
|
||||
|
||||
def _get_default_preferences(self) -> Dict[str, Any]:
|
||||
"""Get default preferences when onboarding data is unavailable."""
|
||||
return {
|
||||
"industry": None,
|
||||
"target_audience": {},
|
||||
"platform_preferences": ["instagram", "linkedin"],
|
||||
"content_preferences": [],
|
||||
"style_preferences": {},
|
||||
"brand_colors": [],
|
||||
"recommended_templates": ["ecommerce_product_shot", "lifestyle_product"],
|
||||
"recommended_channels": ["instagram", "linkedin", "facebook"],
|
||||
"writing_style": {
|
||||
"tone": "professional",
|
||||
"voice": "authoritative",
|
||||
},
|
||||
"brand_values": [],
|
||||
}
|
||||
@@ -10,6 +10,8 @@ from dataclasses import dataclass
|
||||
from services.image_studio.transform_service import TransformStudioService, TransformImageToVideoRequest
|
||||
from services.image_studio.studio_manager import ImageStudioManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
|
||||
logger = get_service_logger("product_marketing.animation")
|
||||
|
||||
@@ -141,6 +143,63 @@ class ProductAnimationService:
|
||||
result["animation_type"] = request.animation_type
|
||||
result["source_module"] = "product_marketing"
|
||||
|
||||
# Save to Asset Library
|
||||
if result.get("file_url") and result.get("filename"):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Build animation prompt for metadata
|
||||
animation_prompt = self._build_animation_prompt(
|
||||
animation_type=request.animation_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="product_marketing",
|
||||
filename=result.get("filename"),
|
||||
file_url=result.get("file_url"),
|
||||
file_path=result.get("file_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="video/mp4",
|
||||
title=f"{request.product_name} - {request.animation_type.title()} Animation",
|
||||
description=f"Product animation: {request.product_description or request.product_name}",
|
||||
prompt=animation_prompt,
|
||||
tags=["product_marketing", "product_animation", request.animation_type, request.resolution],
|
||||
provider=result.get("provider", "wavespeed"),
|
||||
model=result.get("model_name", "alibaba/wan-2.5/image-to-video"),
|
||||
cost=result.get("cost", 0.0),
|
||||
generation_time=result.get("generation_time"),
|
||||
asset_metadata={
|
||||
"product_name": request.product_name,
|
||||
"product_description": request.product_description,
|
||||
"animation_type": request.animation_type,
|
||||
"resolution": request.resolution,
|
||||
"duration": request.duration,
|
||||
"width": result.get("width"),
|
||||
"height": result.get("height"),
|
||||
},
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"[Product Animation] ✅ Saved animation to Asset Library: ID={asset_id}")
|
||||
else:
|
||||
logger.warning(f"[Product Animation] ⚠️ Asset Library save returned None")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[Product Animation] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
|
||||
# Video is saved, but database tracking failed - not critical
|
||||
finally:
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[Product Animation] ✅ Product animation completed: "
|
||||
f"cost=${result.get('cost', 0):.2f}, video_url={result.get('video_url', 'N/A')}"
|
||||
|
||||
@@ -14,6 +14,8 @@ import base64
|
||||
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
|
||||
logger = get_service_logger("product_marketing.avatar")
|
||||
|
||||
@@ -271,6 +273,65 @@ class ProductAvatarService:
|
||||
result["file_size"] = file_size
|
||||
result["duration"] = result.get("duration", 0.0)
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Build avatar prompt for metadata
|
||||
avatar_prompt = request.prompt
|
||||
if not avatar_prompt:
|
||||
avatar_prompt = self._build_avatar_prompt(
|
||||
explainer_type=request.explainer_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="product_marketing",
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=str(file_path),
|
||||
file_size=file_size,
|
||||
mime_type="video/mp4",
|
||||
title=f"{request.product_name} - {request.explainer_type.replace('_', ' ').title()} Explainer",
|
||||
description=f"Product explainer: {request.product_description or request.product_name}",
|
||||
prompt=avatar_prompt,
|
||||
tags=["product_marketing", "product_avatar", "explainer", request.explainer_type, request.resolution],
|
||||
provider=result.get("provider", "infinitetalk"),
|
||||
model=result.get("model_name", "infinitetalk"),
|
||||
cost=result.get("cost", 0.0),
|
||||
generation_time=result.get("generation_time"),
|
||||
asset_metadata={
|
||||
"product_name": request.product_name,
|
||||
"product_description": request.product_description,
|
||||
"explainer_type": request.explainer_type,
|
||||
"resolution": request.resolution,
|
||||
"duration": result.get("duration", 0.0),
|
||||
"script_text": request.script_text,
|
||||
"width": result.get("width"),
|
||||
"height": result.get("height"),
|
||||
},
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"[Product Avatar] ✅ Saved explainer video to Asset Library: ID={asset_id}")
|
||||
else:
|
||||
logger.warning(f"[Product Avatar] ⚠️ Asset Library save returned None")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[Product Avatar] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
|
||||
# Video is saved, but database tracking failed - not critical
|
||||
finally:
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[Product Avatar] ✅ Product explainer video generated successfully: "
|
||||
f"cost=${result.get('cost', 0):.2f}, duration={result.get('duration', 0):.1f}s, "
|
||||
|
||||
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Product Marketing Templates Library
|
||||
Pre-built templates for common product marketing use cases.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TemplateCategory(str, Enum):
|
||||
"""Template categories."""
|
||||
PRODUCT_IMAGE = "product_image"
|
||||
PRODUCT_VIDEO = "product_video"
|
||||
PRODUCT_AVATAR = "product_avatar"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductImageTemplate:
|
||||
"""Product image generation template."""
|
||||
id: str
|
||||
name: str
|
||||
category: TemplateCategory
|
||||
description: str
|
||||
environment: str # studio, lifestyle, outdoor, minimalist
|
||||
background_style: str # white, transparent, lifestyle, branded
|
||||
lighting: str # natural, studio, dramatic, soft
|
||||
style: str # photorealistic, minimalist, luxury, technical
|
||||
angle: str # front, side, top, 45_degree, 360
|
||||
use_cases: List[str]
|
||||
prompt_template: Optional[str] = None
|
||||
recommended_resolution: str = "1024x1024"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductVideoTemplate:
|
||||
"""Product video generation template."""
|
||||
id: str
|
||||
name: str
|
||||
category: TemplateCategory
|
||||
description: str
|
||||
video_type: str # demo, storytelling, feature_highlight, launch
|
||||
resolution: str # 480p, 720p, 1080p
|
||||
duration: int # 5 or 10 seconds
|
||||
use_cases: List[str]
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductAvatarTemplate:
|
||||
"""Product avatar/explainer video template."""
|
||||
id: str
|
||||
name: str
|
||||
category: TemplateCategory
|
||||
description: str
|
||||
explainer_type: str # product_overview, feature_explainer, tutorial, brand_message
|
||||
resolution: str # 480p, 720p
|
||||
use_cases: List[str]
|
||||
script_template: Optional[str] = None
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
|
||||
class ProductMarketingTemplates:
|
||||
"""Product Marketing template definitions."""
|
||||
|
||||
@classmethod
|
||||
def get_product_image_templates(cls) -> List[ProductImageTemplate]:
|
||||
"""Get all product image templates."""
|
||||
return [
|
||||
ProductImageTemplate(
|
||||
id="ecommerce_product_shot",
|
||||
name="E-commerce Product Shot",
|
||||
category=TemplateCategory.PRODUCT_IMAGE,
|
||||
description="Professional product photography for e-commerce listings. Clean white background, studio lighting, front angle.",
|
||||
environment="studio",
|
||||
background_style="white",
|
||||
lighting="studio",
|
||||
style="photorealistic",
|
||||
angle="front",
|
||||
use_cases=["E-commerce listings", "Product catalogs", "Amazon/Shopify"],
|
||||
prompt_template="{product_name} on white background, professional product photography, studio lighting, clean and minimalist, high quality, e-commerce style",
|
||||
recommended_resolution="1024x1024",
|
||||
),
|
||||
ProductImageTemplate(
|
||||
id="lifestyle_product",
|
||||
name="Lifestyle Product Image",
|
||||
category=TemplateCategory.PRODUCT_IMAGE,
|
||||
description="Product in realistic lifestyle setting. Natural environment, authentic use case.",
|
||||
environment="lifestyle",
|
||||
background_style="lifestyle",
|
||||
lighting="natural",
|
||||
style="photorealistic",
|
||||
angle="45_degree",
|
||||
use_cases=["Social media", "Marketing campaigns", "Brand storytelling"],
|
||||
prompt_template="{product_name} in realistic lifestyle setting, natural environment, authentic use case, relatable scenario, professional photography",
|
||||
recommended_resolution="1024x1024",
|
||||
),
|
||||
ProductImageTemplate(
|
||||
id="luxury_product_showcase",
|
||||
name="Luxury Product Showcase",
|
||||
category=TemplateCategory.PRODUCT_IMAGE,
|
||||
description="Premium product presentation. Dramatic lighting, elegant composition, luxury aesthetic.",
|
||||
environment="studio",
|
||||
background_style="minimalist",
|
||||
lighting="dramatic",
|
||||
style="luxury",
|
||||
angle="45_degree",
|
||||
use_cases=["Premium brands", "Luxury products", "High-end marketing"],
|
||||
prompt_template="{product_name} luxury product showcase, dramatic lighting, elegant composition, premium aesthetic, sophisticated, high-end",
|
||||
recommended_resolution="1024x1024",
|
||||
),
|
||||
ProductImageTemplate(
|
||||
id="technical_product_detail",
|
||||
name="Technical Product Detail",
|
||||
category=TemplateCategory.PRODUCT_IMAGE,
|
||||
description="Technical product photography. Focus on details, specifications, features.",
|
||||
environment="studio",
|
||||
background_style="white",
|
||||
lighting="studio",
|
||||
style="technical",
|
||||
angle="front",
|
||||
use_cases=["Technical products", "Specification sheets", "Product documentation"],
|
||||
prompt_template="{product_name} technical product photography, detailed features visible, clean background, professional technical documentation style",
|
||||
recommended_resolution="1024x1024",
|
||||
),
|
||||
ProductImageTemplate(
|
||||
id="social_media_product",
|
||||
name="Social Media Product Post",
|
||||
category=TemplateCategory.PRODUCT_IMAGE,
|
||||
description="Product image optimized for social media. Eye-catching, shareable, engaging.",
|
||||
environment="lifestyle",
|
||||
background_style="lifestyle",
|
||||
lighting="natural",
|
||||
style="photorealistic",
|
||||
angle="45_degree",
|
||||
use_cases=["Instagram", "Facebook", "TikTok", "Pinterest"],
|
||||
prompt_template="{product_name} social media product post, eye-catching, shareable, engaging, modern aesthetic, social media optimized",
|
||||
recommended_resolution="1024x1024",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_product_video_templates(cls) -> List[ProductVideoTemplate]:
|
||||
"""Get all product video templates."""
|
||||
return [
|
||||
ProductVideoTemplate(
|
||||
id="product_demo_video",
|
||||
name="Product Demo Video",
|
||||
category=TemplateCategory.PRODUCT_VIDEO,
|
||||
description="Product demonstration video showing product in use, showcasing key features and benefits.",
|
||||
video_type="demo",
|
||||
resolution="720p",
|
||||
duration=10,
|
||||
use_cases=["Product launches", "Feature showcases", "Marketing campaigns"],
|
||||
prompt_template="{product_name} being demonstrated in use, showcasing key features and benefits, professional product demonstration, dynamic camera movement, engaging presentation",
|
||||
),
|
||||
ProductVideoTemplate(
|
||||
id="product_storytelling",
|
||||
name="Product Storytelling Video",
|
||||
category=TemplateCategory.PRODUCT_VIDEO,
|
||||
description="Narrative-driven product showcase. Emotional connection, compelling visual story.",
|
||||
video_type="storytelling",
|
||||
resolution="1080p",
|
||||
duration=10,
|
||||
use_cases=["Brand storytelling", "Emotional marketing", "Campaign videos"],
|
||||
prompt_template="Story of {product_name}, narrative-driven product showcase, emotional connection, cinematic storytelling, compelling visual narrative",
|
||||
),
|
||||
ProductVideoTemplate(
|
||||
id="feature_highlight_video",
|
||||
name="Feature Highlight Video",
|
||||
category=TemplateCategory.PRODUCT_VIDEO,
|
||||
description="Close-up shots highlighting key product features. Feature-focused presentation.",
|
||||
video_type="feature_highlight",
|
||||
resolution="720p",
|
||||
duration=10,
|
||||
use_cases=["Feature announcements", "Product updates", "Technical showcases"],
|
||||
prompt_template="{product_name} highlighting key features, close-up shots of important details, feature-focused presentation, professional product photography",
|
||||
),
|
||||
ProductVideoTemplate(
|
||||
id="product_launch_video",
|
||||
name="Product Launch Video",
|
||||
category=TemplateCategory.PRODUCT_VIDEO,
|
||||
description="Exciting product launch reveal. Dynamic presentation, launch event aesthetic.",
|
||||
video_type="launch",
|
||||
resolution="1080p",
|
||||
duration=10,
|
||||
use_cases=["Product launches", "Announcements", "Launch events"],
|
||||
prompt_template="{product_name} product launch reveal, exciting unveiling, dynamic presentation, professional product showcase, launch event aesthetic",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_product_avatar_templates(cls) -> List[ProductAvatarTemplate]:
|
||||
"""Get all product avatar/explainer templates."""
|
||||
return [
|
||||
ProductAvatarTemplate(
|
||||
id="product_overview_explainer",
|
||||
name="Product Overview Explainer",
|
||||
category=TemplateCategory.PRODUCT_AVATAR,
|
||||
description="Comprehensive product overview. Engaging and informative presentation.",
|
||||
explainer_type="product_overview",
|
||||
resolution="720p",
|
||||
use_cases=["Product introductions", "Landing pages", "Sales presentations"],
|
||||
script_template="Welcome! Today I'm excited to introduce {product_name}. {product_description}. This innovative product offers [key benefits]. Let me show you what makes it special...",
|
||||
prompt_template="Professional product presentation of {product_name}, engaging and informative, clear communication, confident expression",
|
||||
),
|
||||
ProductAvatarTemplate(
|
||||
id="feature_explainer",
|
||||
name="Feature Explainer Video",
|
||||
category=TemplateCategory.PRODUCT_AVATAR,
|
||||
description="Detailed feature explanation. Pointing gestures, clear visual communication.",
|
||||
explainer_type="feature_explainer",
|
||||
resolution="720p",
|
||||
use_cases=["Feature announcements", "Product tutorials", "How-to guides"],
|
||||
script_template="Let me show you the key features of {product_name}. First, [feature 1] - this allows you to [benefit]. Next, [feature 2] - which enables [benefit]. Finally, [feature 3] - giving you [benefit]...",
|
||||
prompt_template="Demonstrating features of {product_name}, detailed explanation, pointing gestures, clear visual communication",
|
||||
),
|
||||
ProductAvatarTemplate(
|
||||
id="product_tutorial",
|
||||
name="Product Tutorial Video",
|
||||
category=TemplateCategory.PRODUCT_AVATAR,
|
||||
description="Step-by-step product tutorial. Instructional and clear, friendly approach.",
|
||||
explainer_type="tutorial",
|
||||
resolution="720p",
|
||||
use_cases=["User guides", "Onboarding", "Training materials"],
|
||||
script_template="Welcome to this tutorial on {product_name}. Today I'll walk you through how to use it. Step 1: [instruction]. Step 2: [instruction]. Step 3: [instruction]...",
|
||||
prompt_template="Tutorial presentation for {product_name}, step-by-step explanation, instructional and clear, friendly and approachable",
|
||||
),
|
||||
ProductAvatarTemplate(
|
||||
id="brand_message_video",
|
||||
name="Brand Message Video",
|
||||
category=TemplateCategory.PRODUCT_AVATAR,
|
||||
description="Brand message delivery. Authentic and compelling brand storytelling.",
|
||||
explainer_type="brand_message",
|
||||
resolution="720p",
|
||||
use_cases=["Brand campaigns", "Mission statements", "Company values"],
|
||||
script_template="At [Brand Name], we believe in {product_name} because [brand values]. Our mission is [mission statement]. This product represents [brand message]...",
|
||||
prompt_template="Brand message delivery for {product_name}, authentic and compelling, brand storytelling, emotional connection",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_template_by_id(cls, template_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific template by ID."""
|
||||
# Search in all template types
|
||||
for template in cls.get_product_image_templates():
|
||||
if template.id == template_id:
|
||||
return {
|
||||
"id": template.id,
|
||||
"name": template.name,
|
||||
"category": template.category.value,
|
||||
"description": template.description,
|
||||
"template_data": {
|
||||
"environment": template.environment,
|
||||
"background_style": template.background_style,
|
||||
"lighting": template.lighting,
|
||||
"style": template.style,
|
||||
"angle": template.angle,
|
||||
"recommended_resolution": template.recommended_resolution,
|
||||
},
|
||||
"use_cases": template.use_cases,
|
||||
"prompt_template": template.prompt_template,
|
||||
}
|
||||
|
||||
for template in cls.get_product_video_templates():
|
||||
if template.id == template_id:
|
||||
return {
|
||||
"id": template.id,
|
||||
"name": template.name,
|
||||
"category": template.category.value,
|
||||
"description": template.description,
|
||||
"template_data": {
|
||||
"video_type": template.video_type,
|
||||
"resolution": template.resolution,
|
||||
"duration": template.duration,
|
||||
},
|
||||
"use_cases": template.use_cases,
|
||||
"prompt_template": template.prompt_template,
|
||||
}
|
||||
|
||||
for template in cls.get_product_avatar_templates():
|
||||
if template.id == template_id:
|
||||
return {
|
||||
"id": template.id,
|
||||
"name": template.name,
|
||||
"category": template.category.value,
|
||||
"description": template.description,
|
||||
"template_data": {
|
||||
"explainer_type": template.explainer_type,
|
||||
"resolution": template.resolution,
|
||||
},
|
||||
"use_cases": template.use_cases,
|
||||
"script_template": template.script_template,
|
||||
"prompt_template": template.prompt_template,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_templates_by_category(cls, category: TemplateCategory) -> List[Dict[str, Any]]:
|
||||
"""Get all templates for a specific category."""
|
||||
if category == TemplateCategory.PRODUCT_IMAGE:
|
||||
return [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"environment": t.environment,
|
||||
"background_style": t.background_style,
|
||||
"lighting": t.lighting,
|
||||
"style": t.style,
|
||||
"angle": t.angle,
|
||||
"use_cases": t.use_cases,
|
||||
"prompt_template": t.prompt_template,
|
||||
"recommended_resolution": t.recommended_resolution,
|
||||
}
|
||||
for t in cls.get_product_image_templates()
|
||||
]
|
||||
elif category == TemplateCategory.PRODUCT_VIDEO:
|
||||
return [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"video_type": t.video_type,
|
||||
"resolution": t.resolution,
|
||||
"duration": t.duration,
|
||||
"use_cases": t.use_cases,
|
||||
"prompt_template": t.prompt_template,
|
||||
}
|
||||
for t in cls.get_product_video_templates()
|
||||
]
|
||||
elif category == TemplateCategory.PRODUCT_AVATAR:
|
||||
return [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"explainer_type": t.explainer_type,
|
||||
"resolution": t.resolution,
|
||||
"use_cases": t.use_cases,
|
||||
"script_template": t.script_template,
|
||||
"prompt_template": t.prompt_template,
|
||||
}
|
||||
for t in cls.get_product_avatar_templates()
|
||||
]
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def apply_template(
|
||||
cls,
|
||||
template_id: str,
|
||||
product_name: str,
|
||||
product_description: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Apply a template to product data.
|
||||
|
||||
Args:
|
||||
template_id: Template ID to apply
|
||||
product_name: Product name
|
||||
product_description: Product description (optional)
|
||||
**kwargs: Additional template-specific parameters
|
||||
|
||||
Returns:
|
||||
Template configuration ready for use
|
||||
"""
|
||||
template = cls.get_template_by_id(template_id)
|
||||
if not template:
|
||||
raise ValueError(f"Template not found: {template_id}")
|
||||
|
||||
# Format prompt/script templates with product data
|
||||
result = template.copy()
|
||||
|
||||
if result.get("prompt_template"):
|
||||
result["prompt"] = result["prompt_template"].format(
|
||||
product_name=product_name,
|
||||
product_description=product_description or product_name,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if result.get("script_template"):
|
||||
result["script"] = result["script_template"].format(
|
||||
product_name=product_name,
|
||||
product_description=product_description or product_name,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -9,6 +9,8 @@ from dataclasses import dataclass
|
||||
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
|
||||
logger = get_service_logger("product_marketing.video")
|
||||
|
||||
@@ -212,6 +214,62 @@ class ProductVideoService:
|
||||
result["file_url"] = file_url
|
||||
result["file_size"] = len(video_bytes)
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Build video prompt for metadata
|
||||
video_prompt = self._build_video_prompt(
|
||||
video_type=request.video_type,
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
brand_context=request.brand_context,
|
||||
additional_context=request.additional_context
|
||||
)
|
||||
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="product_marketing",
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=str(file_path),
|
||||
file_size=len(video_bytes),
|
||||
mime_type="video/mp4",
|
||||
title=f"{request.product_name} - {request.video_type.replace('_', ' ').title()} Video",
|
||||
description=f"Product video: {request.product_description or request.product_name}",
|
||||
prompt=video_prompt,
|
||||
tags=["product_marketing", "product_video", request.video_type, request.resolution],
|
||||
provider=result.get("provider", "wavespeed"),
|
||||
model=result.get("model_name", "alibaba/wan-2.5/text-to-video"),
|
||||
cost=result.get("cost", 0.0),
|
||||
generation_time=result.get("generation_time"),
|
||||
asset_metadata={
|
||||
"product_name": request.product_name,
|
||||
"product_description": request.product_description,
|
||||
"video_type": request.video_type,
|
||||
"resolution": request.resolution,
|
||||
"duration": request.duration,
|
||||
"width": result.get("width"),
|
||||
"height": result.get("height"),
|
||||
},
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"[Product Video] ✅ Saved video to Asset Library: ID={asset_id}")
|
||||
else:
|
||||
logger.warning(f"[Product Video] ⚠️ Asset Library save returned None")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[Product Video] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
|
||||
# Video is saved, but database tracking failed - not critical
|
||||
finally:
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[Product Video] ✅ Product video generated successfully: "
|
||||
f"cost=${result.get('cost', 0):.2f}, video_url={file_url}"
|
||||
|
||||
@@ -154,7 +154,17 @@ class IntentAwareAnalyzer:
|
||||
"primary_answer": {"type": "string"},
|
||||
"secondary_answers": {
|
||||
"type": "object",
|
||||
"additionalProperties": {"type": "string"}
|
||||
"additionalProperties": {"oneOf": [{"type": "string"}, {"type": "null"}]}
|
||||
},
|
||||
"focus_areas_coverage": {
|
||||
"type": "object",
|
||||
"additionalProperties": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"description": "Summary of what was found for each focus area, or null if not covered"
|
||||
},
|
||||
"also_answering_coverage": {
|
||||
"type": "object",
|
||||
"additionalProperties": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"description": "Information found about each 'also answering' topic, or null if not found"
|
||||
},
|
||||
"executive_summary": {"type": "string"},
|
||||
"key_takeaways": {
|
||||
@@ -469,10 +479,21 @@ class IntentAwareAnalyzer:
|
||||
if not sources:
|
||||
sources = self._extract_sources_from_raw(raw_results)
|
||||
|
||||
# Parse coverage fields (handle null values)
|
||||
focus_areas_coverage = {}
|
||||
for area, coverage in result.get("focus_areas_coverage", {}).items():
|
||||
focus_areas_coverage[area] = coverage if coverage else None
|
||||
|
||||
also_answering_coverage = {}
|
||||
for topic, coverage in result.get("also_answering_coverage", {}).items():
|
||||
also_answering_coverage[topic] = coverage if coverage else None
|
||||
|
||||
return IntentDrivenResearchResult(
|
||||
success=True,
|
||||
primary_answer=result.get("primary_answer", ""),
|
||||
secondary_answers=result.get("secondary_answers", {}),
|
||||
focus_areas_coverage=focus_areas_coverage,
|
||||
also_answering_coverage=also_answering_coverage,
|
||||
statistics=statistics,
|
||||
expert_quotes=expert_quotes,
|
||||
case_studies=case_studies,
|
||||
@@ -534,6 +555,8 @@ class IntentAwareAnalyzer:
|
||||
success=True,
|
||||
primary_answer=f"Research findings for: {intent.primary_question}",
|
||||
secondary_answers={},
|
||||
focus_areas_coverage={area: None for area in intent.focus_areas} if intent.focus_areas else {},
|
||||
also_answering_coverage={topic: None for topic in intent.also_answering} if intent.also_answering else {},
|
||||
executive_summary=content[:300] if content else "Research completed",
|
||||
key_takeaways=key_takeaways,
|
||||
sources=sources,
|
||||
|
||||
@@ -11,6 +11,7 @@ Version: 1.0
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
@@ -27,6 +28,14 @@ from models.research_persona_models import ResearchPersona
|
||||
class IntentPromptBuilder:
|
||||
"""Builds prompts for intent-driven research."""
|
||||
|
||||
def _get_current_date_context(self) -> str:
|
||||
"""Get current date/time context for prompts."""
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
current_month = now.strftime("%B") # Full month name
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
return f"CURRENT DATE: {current_date} ({current_month} {current_year})\nCURRENT YEAR: {current_year}"
|
||||
|
||||
# Purpose explanations for the AI
|
||||
PURPOSE_EXPLANATIONS = {
|
||||
ResearchPurpose.LEARN: "User wants to understand a topic for personal knowledge",
|
||||
@@ -74,6 +83,11 @@ class IntentPromptBuilder:
|
||||
- What specific deliverables they need
|
||||
"""
|
||||
|
||||
# Get current date context
|
||||
date_context = self._get_current_date_context()
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
|
||||
# Build persona context
|
||||
persona_context = self._build_persona_context(research_persona, industry, target_audience)
|
||||
|
||||
@@ -82,6 +96,11 @@ class IntentPromptBuilder:
|
||||
|
||||
prompt = f"""You are an expert research intent analyzer. Your job is to understand what a content creator REALLY needs from their research.
|
||||
|
||||
## CURRENT DATE/TIME CONTEXT
|
||||
{date_context}
|
||||
|
||||
**NOTE**: When user mentions time-sensitive terms (latest, current, recent, trends, predictions), prioritize {current_year} data.
|
||||
|
||||
## USER INPUT
|
||||
"{user_input}"
|
||||
|
||||
@@ -97,7 +116,7 @@ class IntentPromptBuilder:
|
||||
Analyze the user's input and infer their research intent. Determine:
|
||||
|
||||
1. **INPUT TYPE**: Is this:
|
||||
- "keywords": Simple topic keywords (e.g., "AI healthcare 2025")
|
||||
- "keywords": Simple topic keywords (e.g., "AI healthcare {current_year}")
|
||||
- "question": A specific question (e.g., "What are the best AI tools for healthcare?")
|
||||
- "goal": A goal statement (e.g., "I need to write a blog about AI in healthcare")
|
||||
- "mixed": Combination of above
|
||||
@@ -210,8 +229,25 @@ Return a JSON object:
|
||||
if research_persona and research_persona.suggested_keywords:
|
||||
persona_keywords = f"\nSUGGESTED KEYWORDS FROM PERSONA: {', '.join(research_persona.suggested_keywords[:10])}"
|
||||
|
||||
# Get current date context
|
||||
date_context = self._get_current_date_context()
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
next_year = current_year + 1
|
||||
current_month_year = now.strftime("%B %Y")
|
||||
|
||||
prompt = f"""You are a research query optimizer. Generate multiple targeted search queries based on the user's research intent.
|
||||
|
||||
## CURRENT DATE/TIME CONTEXT
|
||||
{date_context}
|
||||
|
||||
**CRITICAL**: When generating queries:
|
||||
- ALWAYS use the CURRENT YEAR ({current_year}) for time-sensitive queries
|
||||
- For trends, predictions, or future-looking queries, use {current_year} or {next_year}
|
||||
- For recent/real-time queries, use current month/year: {current_month_year}
|
||||
- NEVER use outdated years from training data (e.g., 2024, 2025 if we're past those dates)
|
||||
- When user mentions "latest", "current", "recent", or time-sensitive terms, prioritize {current_year} data
|
||||
|
||||
## RESEARCH INTENT
|
||||
|
||||
PRIMARY QUESTION: {intent.primary_question}
|
||||
@@ -256,14 +292,14 @@ Return a JSON object:
|
||||
{{
|
||||
"queries": [
|
||||
{{
|
||||
"query": "Healthcare AI adoption statistics 2025 hospitals implementation data",
|
||||
"query": "Healthcare AI adoption statistics {current_year} hospitals implementation data",
|
||||
"purpose": "key_statistics",
|
||||
"provider": "exa",
|
||||
"priority": 5,
|
||||
"expected_results": "Statistics on hospital AI adoption rates"
|
||||
}},
|
||||
{{
|
||||
"query": "AI healthcare trends predictions future outlook 2025 2026",
|
||||
"query": "AI healthcare trends predictions future outlook {current_year} {next_year}",
|
||||
"purpose": "trends",
|
||||
"provider": "tavily",
|
||||
"priority": 4,
|
||||
@@ -280,13 +316,14 @@ Return a JSON object:
|
||||
|
||||
## QUERY OPTIMIZATION RULES
|
||||
|
||||
1. For STATISTICS: Include words like "statistics", "data", "percentage", "report", "study"
|
||||
1. For STATISTICS: Include words like "statistics", "data", "percentage", "report", "study", and CURRENT YEAR ({current_year})
|
||||
2. For CASE STUDIES: Include "case study", "success story", "implementation", "example"
|
||||
3. For TRENDS: Include "trends", "future", "predictions", "emerging", year numbers
|
||||
3. For TRENDS: Include "trends", "future", "predictions", "emerging", and CURRENT YEAR ({current_year}) or {next_year}
|
||||
4. For EXPERT QUOTES: Include expert names if known, or "expert opinion", "interview"
|
||||
5. For COMPARISONS: Include "vs", "compare", "comparison", "alternative"
|
||||
6. For NEWS/REAL-TIME: Use Tavily, include recent year/month
|
||||
6. For NEWS/REAL-TIME: Use Tavily, include CURRENT YEAR ({current_year}) and current month/year ({current_month_year})
|
||||
7. For ACADEMIC/DEEP: Use Exa with neural search
|
||||
8. **CRITICAL**: Always use {current_year} (not outdated years) for time-sensitive queries
|
||||
"""
|
||||
|
||||
return prompt
|
||||
@@ -314,23 +351,43 @@ Return a JSON object:
|
||||
if intent.perspective:
|
||||
perspective_instruction = f"\n**PERSPECTIVE**: Analyze results from the viewpoint of: {intent.perspective}"
|
||||
|
||||
# Get current date context
|
||||
date_context = self._get_current_date_context()
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
|
||||
prompt = f"""You are a research analyst helping a content creator find exactly what they need. Your job is to analyze raw research results and extract precisely what the user is looking for.
|
||||
|
||||
## CURRENT DATE/TIME CONTEXT
|
||||
{date_context}
|
||||
|
||||
**CRITICAL**: When analyzing results:
|
||||
- Prioritize data from CURRENT YEAR ({current_year}) or recent dates
|
||||
- If statistics/quotes mention outdated years, note the recency in context
|
||||
- For trends/predictions, ensure timelines reference {current_year} or future years
|
||||
- NEVER present outdated data as "current" or "latest" - always check dates
|
||||
|
||||
## USER'S RESEARCH INTENT
|
||||
|
||||
PRIMARY QUESTION: {intent.primary_question}
|
||||
**PRIMARY QUESTION**: {intent.primary_question}
|
||||
|
||||
SECONDARY QUESTIONS:
|
||||
**SECONDARY QUESTIONS TO ANSWER**:
|
||||
{chr(10).join(f'- {q}' for q in intent.secondary_questions) if intent.secondary_questions else 'None specified'}
|
||||
|
||||
PURPOSE: {intent.purpose}
|
||||
**FOCUS AREAS** (prioritize information related to these):
|
||||
{', '.join(intent.focus_areas) if intent.focus_areas else 'General - no specific focus areas'}
|
||||
|
||||
**ALSO ANSWERING** (address these topics if found in results):
|
||||
{', '.join(intent.also_answering) if intent.also_answering else 'None specified'}
|
||||
|
||||
**PURPOSE**: {intent.purpose}
|
||||
→ {purpose_explanation}
|
||||
|
||||
CONTENT OUTPUT: {intent.content_output}
|
||||
**CONTENT OUTPUT**: {intent.content_output}
|
||||
|
||||
EXPECTED DELIVERABLES: {', '.join(intent.expected_deliverables)}
|
||||
**EXPECTED DELIVERABLES**: {', '.join(intent.expected_deliverables)}
|
||||
|
||||
FOCUS AREAS: {', '.join(intent.focus_areas) if intent.focus_areas else 'General'}
|
||||
**PERSPECTIVE**: {intent.perspective or 'General audience'}
|
||||
{perspective_instruction}
|
||||
|
||||
## RAW RESEARCH RESULTS
|
||||
@@ -339,7 +396,33 @@ FOCUS AREAS: {', '.join(intent.focus_areas) if intent.focus_areas else 'General'
|
||||
|
||||
## YOUR TASK
|
||||
|
||||
Analyze the raw research results and extract EXACTLY what the user needs.
|
||||
Analyze the raw research results and extract EXACTLY what the user needs. Use a **generalized approach** - don't over-optimize for specific fields, but ensure all intent aspects are considered naturally.
|
||||
|
||||
### ANALYSIS GUIDELINES:
|
||||
|
||||
1. **PRIMARY QUESTION**: Always provide a direct, clear answer to the primary question in 2-3 sentences.
|
||||
|
||||
2. **SECONDARY QUESTIONS**: For each secondary question, provide an answer if information is available in the results. If not available, note it in gaps_identified. Don't force answers - only include what's actually in the results.
|
||||
|
||||
3. **FOCUS AREAS**: When extracting deliverables, prioritize information that relates to the focus areas. If focus areas are specified:
|
||||
- Weight relevance scores higher for sources/content matching focus areas
|
||||
- Include focus area context in extracted statistics, quotes, case studies
|
||||
- If results don't address focus areas, note this in gaps_identified
|
||||
- Provide a brief summary of what was found for each focus area in focus_areas_coverage
|
||||
|
||||
4. **ALSO ANSWERING**: If results contain information about "also answering" topics, include it naturally in the analysis. Don't create separate sections unless the information is substantial. Provide a brief summary of what was found for each topic in also_answering_coverage.
|
||||
|
||||
5. **GENERALIZED EXTRACTION**:
|
||||
- Extract deliverables based on expected_deliverables
|
||||
- Use perspective to frame information appropriately
|
||||
- Consider content_output when structuring results
|
||||
- Don't over-optimize - let the results guide what's extracted
|
||||
|
||||
6. **CONTEXTUAL LINKING**: When extracting information, consider:
|
||||
- How it relates to the primary question
|
||||
- Which secondary questions it answers
|
||||
- Which focus areas it addresses
|
||||
- This helps create a cohesive research result
|
||||
|
||||
{deliverables_instructions}
|
||||
|
||||
@@ -351,8 +434,16 @@ Provide results in this JSON structure:
|
||||
{{
|
||||
"primary_answer": "Direct 2-3 sentence answer to the primary question",
|
||||
"secondary_answers": {{
|
||||
"Question 1?": "Answer to question 1",
|
||||
"Question 2?": "Answer to question 2"
|
||||
"Secondary Question 1?": "Answer if found in results, or null if not available",
|
||||
"Secondary Question 2?": "Answer if found in results, or null if not available"
|
||||
}},
|
||||
"focus_areas_coverage": {{
|
||||
"Focus Area 1": "Brief summary of what was found related to this focus area, or null if not covered",
|
||||
"Focus Area 2": "Brief summary of what was found related to this focus area, or null if not covered"
|
||||
}},
|
||||
"also_answering_coverage": {{
|
||||
"Topic 1": "Information found about this topic, or null if not found",
|
||||
"Topic 2": "Information found about this topic, or null if not found"
|
||||
}},
|
||||
"executive_summary": "2-3 sentence executive summary of all findings",
|
||||
"key_takeaways": [
|
||||
@@ -364,13 +455,13 @@ Provide results in this JSON structure:
|
||||
],
|
||||
"statistics": [
|
||||
{{
|
||||
"statistic": "72% of hospitals plan to adopt AI by 2025",
|
||||
"statistic": "72% of hospitals plan to adopt AI by {current_year}",
|
||||
"value": "72%",
|
||||
"context": "Survey of 500 US hospitals in 2024",
|
||||
"source": "Healthcare AI Report 2024",
|
||||
"context": "Survey of 500 US hospitals in {current_year}",
|
||||
"source": "Healthcare AI Report {current_year}",
|
||||
"url": "https://example.com/report",
|
||||
"credibility": 0.9,
|
||||
"recency": "2024"
|
||||
"recency": "{current_year}"
|
||||
}}
|
||||
],
|
||||
"expert_quotes": [
|
||||
@@ -401,7 +492,7 @@ Provide results in this JSON structure:
|
||||
"direction": "growing",
|
||||
"evidence": ["25% YoY growth", "Major hospital chains investing"],
|
||||
"impact": "Could reduce misdiagnosis by 30%",
|
||||
"timeline": "Expected mainstream by 2027",
|
||||
"timeline": "Expected mainstream by {current_year + 2}",
|
||||
"sources": ["url1", "url2"]
|
||||
}}
|
||||
],
|
||||
@@ -442,7 +533,7 @@ Provide results in this JSON structure:
|
||||
"Example: Hospital X reduced readmissions by 25% using predictive AI"
|
||||
],
|
||||
"predictions": [
|
||||
"By 2030, AI will assist in 80% of initial diagnoses"
|
||||
"By {current_year + 5}, AI will assist in 80% of initial diagnoses"
|
||||
],
|
||||
"suggested_outline": [
|
||||
"1. Introduction: The AI Healthcare Revolution",
|
||||
@@ -454,7 +545,7 @@ Provide results in this JSON structure:
|
||||
],
|
||||
"sources": [
|
||||
{{
|
||||
"title": "Healthcare AI Report 2024",
|
||||
"title": "Healthcare AI Report {current_year}",
|
||||
"url": "https://example.com",
|
||||
"relevance_score": 0.95,
|
||||
"relevance_reason": "Directly addresses adoption statistics",
|
||||
@@ -468,7 +559,7 @@ Provide results in this JSON structure:
|
||||
"Limited information on regulatory challenges"
|
||||
],
|
||||
"follow_up_queries": [
|
||||
"AI healthcare regulations FDA 2025",
|
||||
"AI healthcare regulations FDA {current_year}",
|
||||
"Small clinic AI implementation costs"
|
||||
]
|
||||
}}
|
||||
@@ -486,6 +577,8 @@ Provide results in this JSON structure:
|
||||
8. **Suggest follow_up_queries** for gaps or incomplete areas
|
||||
9. **Rate confidence** based on how well results match the user's intent
|
||||
10. **Include deliverables ONLY if they are in expected_deliverables** or critical to the question
|
||||
11. **Don't over-optimize** - use a natural, generalized approach that considers all intent fields without forcing connections
|
||||
12. **For focus_areas_coverage and also_answering_coverage**: Only include entries for focus areas/topics that actually have information in the results. Use null for areas/topics not covered.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -137,6 +137,11 @@ class IntentQueryGenerator:
|
||||
provider=q.get("provider", "exa"),
|
||||
priority=min(max(int(q.get("priority", 3)), 1), 5), # Clamp 1-5
|
||||
expected_results=q.get("expected_results", ""),
|
||||
addresses_primary_question=q.get("addresses_primary_question", False),
|
||||
addresses_secondary_questions=q.get("addresses_secondary_questions", []),
|
||||
targets_focus_areas=q.get("targets_focus_areas", []),
|
||||
covers_also_answering=q.get("covers_also_answering", []),
|
||||
justification=q.get("justification"),
|
||||
)
|
||||
queries.append(query)
|
||||
except Exception as e:
|
||||
@@ -266,6 +271,10 @@ class IntentQueryGenerator:
|
||||
provider=template["provider"],
|
||||
priority=template["priority"],
|
||||
expected_results=template["expected"],
|
||||
addresses_primary_question=False,
|
||||
addresses_secondary_questions=[],
|
||||
targets_focus_areas=[],
|
||||
covers_also_answering=[],
|
||||
)
|
||||
|
||||
def _create_fallback_queries(self, intent: ResearchIntent) -> Dict[str, Any]:
|
||||
@@ -287,6 +296,10 @@ class IntentQueryGenerator:
|
||||
provider="exa",
|
||||
priority=5,
|
||||
expected_results="General information and insights",
|
||||
addresses_primary_question=True,
|
||||
addresses_secondary_questions=[],
|
||||
targets_focus_areas=[],
|
||||
covers_also_answering=[],
|
||||
))
|
||||
|
||||
return {
|
||||
@@ -357,10 +370,17 @@ class QueryOptimizer:
|
||||
if ExpectedDeliverable.TRENDS.value in deliverables:
|
||||
topic = "news"
|
||||
|
||||
# Determine search depth
|
||||
search_depth = "basic"
|
||||
if intent.depth in ["detailed", "expert"]:
|
||||
search_depth = "advanced"
|
||||
# Determine search depth based on depth and time sensitivity
|
||||
# advanced = 2 credits (best quality), basic/fast/ultra-fast = 1 credit
|
||||
search_depth = "basic" # Default: balanced
|
||||
if intent.depth == "expert":
|
||||
search_depth = "advanced" # Best quality for expert research
|
||||
elif intent.depth == "detailed":
|
||||
search_depth = "advanced" # Better snippets for detailed research
|
||||
elif intent.time_sensitivity == "real_time":
|
||||
search_depth = "ultra-fast" # Minimize latency for real-time
|
||||
elif intent.time_sensitivity == "recent":
|
||||
search_depth = "fast" # Good balance for recent content
|
||||
|
||||
# Include answer for factual queries
|
||||
include_answer = False
|
||||
|
||||
121
backend/services/research/intent/query_deduplicator.py
Normal file
121
backend/services/research/intent/query_deduplicator.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Query deduplication logic for unified research analyzer.
|
||||
|
||||
Removes redundant queries that would return similar results
|
||||
and ensures queries are linked to intent fields.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
|
||||
from models.research_intent_models import ResearchIntent, ResearchQuery
|
||||
|
||||
|
||||
def deduplicate_queries(
|
||||
queries: List[ResearchQuery],
|
||||
intent: ResearchIntent
|
||||
) -> List[ResearchQuery]:
|
||||
"""
|
||||
Remove redundant queries that would return similar results.
|
||||
|
||||
Rules:
|
||||
1. If two queries are semantically very similar (same keywords, same purpose), merge them
|
||||
2. If a query can answer multiple secondary questions, combine them
|
||||
3. If focus areas overlap significantly, don't create separate queries
|
||||
4. Maximum 8 queries - prioritize by importance
|
||||
5. Always keep the primary query (addresses_primary_question=True)
|
||||
"""
|
||||
if len(queries) <= 8:
|
||||
# Still check for exact duplicates
|
||||
seen_queries = set()
|
||||
deduplicated = []
|
||||
for query in queries:
|
||||
query_key = (query.query.lower().strip(), query.provider)
|
||||
if query_key not in seen_queries:
|
||||
seen_queries.add(query_key)
|
||||
deduplicated.append(query)
|
||||
return deduplicated
|
||||
|
||||
# Sort by priority (highest first)
|
||||
queries.sort(key=lambda q: q.priority, reverse=True)
|
||||
|
||||
# Always keep primary query
|
||||
primary_queries = [q for q in queries if q.addresses_primary_question]
|
||||
other_queries = [q for q in queries if not q.addresses_primary_question]
|
||||
|
||||
deduplicated = []
|
||||
seen_keywords = set()
|
||||
|
||||
# Add primary queries first (should be only one, but handle multiple)
|
||||
for query in primary_queries:
|
||||
query_key = (query.query.lower().strip(), query.provider)
|
||||
if query_key not in seen_keywords:
|
||||
seen_keywords.add(query_key)
|
||||
deduplicated.append(query)
|
||||
|
||||
# Process other queries with similarity checking
|
||||
for query in other_queries:
|
||||
query_key = (query.query.lower().strip(), query.provider)
|
||||
|
||||
# Check for exact duplicate
|
||||
if query_key in seen_keywords:
|
||||
continue
|
||||
|
||||
# Check for semantic similarity with existing queries
|
||||
query_words = set(query.query.lower().split())
|
||||
is_duplicate = False
|
||||
|
||||
for existing in deduplicated:
|
||||
existing_words = set(existing.query.lower().split())
|
||||
|
||||
# Calculate Jaccard similarity (intersection over union)
|
||||
intersection = query_words & existing_words
|
||||
union = query_words | existing_words
|
||||
similarity = len(intersection) / len(union) if union else 0
|
||||
|
||||
# CRITICAL: Don't merge queries that target different focus areas or also_answering topics
|
||||
# These should remain separate even if they're similar
|
||||
query_focus_areas = set(query.targets_focus_areas)
|
||||
existing_focus_areas = set(existing.targets_focus_areas)
|
||||
query_also_answering = set(query.covers_also_answering)
|
||||
existing_also_answering = set(existing.covers_also_answering)
|
||||
|
||||
# If queries target different focus areas, keep them separate
|
||||
if query_focus_areas and existing_focus_areas and query_focus_areas != existing_focus_areas:
|
||||
continue # Keep separate - different focus areas
|
||||
|
||||
# If queries cover different also_answering topics, keep them separate
|
||||
if query_also_answering and existing_also_answering and query_also_answering != existing_also_answering:
|
||||
continue # Keep separate - different also_answering topics
|
||||
|
||||
# Only consider duplicate if >90% similarity (increased from 80%) AND same purpose/provider AND same focus/also_answering
|
||||
# This is more strict to avoid over-deduplication
|
||||
if similarity > 0.9 and query.purpose == existing.purpose and query.provider == existing.provider:
|
||||
# Only merge if they truly target the same things
|
||||
if query_focus_areas == existing_focus_areas and query_also_answering == existing_also_answering:
|
||||
is_duplicate = True
|
||||
# Merge: update existing query's linking arrays
|
||||
existing.addresses_secondary_questions = list(set(
|
||||
existing.addresses_secondary_questions + query.addresses_secondary_questions
|
||||
))
|
||||
existing.targets_focus_areas = list(set(
|
||||
existing.targets_focus_areas + query.targets_focus_areas
|
||||
))
|
||||
existing.covers_also_answering = list(set(
|
||||
existing.covers_also_answering + query.covers_also_answering
|
||||
))
|
||||
# Update expected_results to reflect merged coverage
|
||||
if query.expected_results and query.expected_results not in existing.expected_results:
|
||||
existing.expected_results += f" Also covers: {query.expected_results}"
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
deduplicated.append(query)
|
||||
seen_keywords.add(query_key)
|
||||
|
||||
# Limit to 8 queries total
|
||||
if len(deduplicated) >= 8:
|
||||
break
|
||||
|
||||
logger.info(f"Deduplicated queries: {len(queries)} -> {len(deduplicated)}")
|
||||
return deduplicated
|
||||
112
backend/services/research/intent/unified_analyzer_utils.py
Normal file
112
backend/services/research/intent/unified_analyzer_utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Utility functions for unified research analyzer.
|
||||
|
||||
Provides helper functions for date context, persona context,
|
||||
competitor context, and fallback response creation.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from models.research_intent_models import ResearchIntent, ResearchQuery
|
||||
from models.research_persona_models import ResearchPersona
|
||||
|
||||
|
||||
def get_current_date_context() -> str:
|
||||
"""Get current date/time context for prompts."""
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
current_month = now.strftime("%B") # Full month name
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
return f"CURRENT DATE: {current_date} ({current_month} {current_year})\nCURRENT YEAR: {current_year}"
|
||||
|
||||
|
||||
def build_persona_context(
|
||||
research_persona: Optional[ResearchPersona],
|
||||
industry: Optional[str],
|
||||
target_audience: Optional[str],
|
||||
) -> str:
|
||||
"""Build persona context section."""
|
||||
parts = []
|
||||
|
||||
if research_persona:
|
||||
if research_persona.default_industry:
|
||||
parts.append(f"Industry: {research_persona.default_industry}")
|
||||
if research_persona.default_target_audience:
|
||||
parts.append(f"Target Audience: {research_persona.default_target_audience}")
|
||||
if research_persona.research_angles:
|
||||
parts.append(f"Preferred Research Angles: {', '.join(research_persona.research_angles[:3])}")
|
||||
if research_persona.suggested_keywords:
|
||||
parts.append(f"Relevant Keywords: {', '.join(research_persona.suggested_keywords[:5])}")
|
||||
else:
|
||||
if industry:
|
||||
parts.append(f"Industry: {industry}")
|
||||
if target_audience:
|
||||
parts.append(f"Target Audience: {target_audience}")
|
||||
|
||||
if not parts:
|
||||
return "No specific user context available. Use general best practices."
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def build_competitor_context(competitor_data: Optional[List[Dict]]) -> str:
|
||||
"""Build competitor context section."""
|
||||
if not competitor_data:
|
||||
return ""
|
||||
|
||||
competitor_names = [c.get("name", c.get("url", "")) for c in competitor_data[:5]]
|
||||
if competitor_names:
|
||||
return f"\nKnown Competitors: {', '.join(competitor_names)}"
|
||||
return ""
|
||||
|
||||
|
||||
def create_fallback_response(user_input: str, keywords: List[str]) -> Dict[str, Any]:
|
||||
"""Create fallback response when analysis fails."""
|
||||
return {
|
||||
"success": False,
|
||||
"intent": ResearchIntent(
|
||||
primary_question=f"What are the key insights about: {user_input}?",
|
||||
purpose="learn",
|
||||
content_output="general",
|
||||
expected_deliverables=["key_statistics", "best_practices"],
|
||||
depth="detailed",
|
||||
focus_areas=[],
|
||||
also_answering=[],
|
||||
original_input=user_input,
|
||||
confidence=0.5,
|
||||
),
|
||||
"queries": [
|
||||
ResearchQuery(
|
||||
query=user_input,
|
||||
purpose="key_statistics",
|
||||
provider="exa",
|
||||
priority=5,
|
||||
expected_results="General research results",
|
||||
addresses_primary_question=True,
|
||||
addresses_secondary_questions=[],
|
||||
targets_focus_areas=[],
|
||||
covers_also_answering=[],
|
||||
)
|
||||
],
|
||||
"enhanced_keywords": keywords,
|
||||
"research_angles": [],
|
||||
"recommended_provider": "exa",
|
||||
"provider_justification": "Default fallback to Exa for semantic search",
|
||||
"exa_config": {
|
||||
"enabled": True,
|
||||
"type": "auto",
|
||||
"type_justification": "Auto mode for balanced results",
|
||||
"numResults": 10,
|
||||
"highlights": True,
|
||||
},
|
||||
"tavily_config": {
|
||||
"enabled": True,
|
||||
"topic": "general",
|
||||
"search_depth": "advanced",
|
||||
"include_answer": True,
|
||||
},
|
||||
"trends_config": {
|
||||
"enabled": False, # Disabled in fallback
|
||||
},
|
||||
}
|
||||
277
backend/services/research/intent/unified_prompt_builder.py
Normal file
277
backend/services/research/intent/unified_prompt_builder.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Prompt builder for unified research analyzer.
|
||||
|
||||
Builds the comprehensive LLM prompt that guides intent inference,
|
||||
query generation, and parameter optimization in a single call.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from models.research_persona_models import ResearchPersona
|
||||
from .unified_analyzer_utils import (
|
||||
get_current_date_context,
|
||||
build_persona_context,
|
||||
build_competitor_context,
|
||||
)
|
||||
|
||||
|
||||
def build_unified_prompt(
|
||||
user_input: str,
|
||||
keywords: List[str],
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
competitor_data: Optional[List[Dict]] = None,
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
user_provided_purpose: Optional[str] = None,
|
||||
user_provided_content_output: Optional[str] = None,
|
||||
user_provided_depth: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build the unified prompt for intent + queries + parameters.
|
||||
|
||||
This prompt guides the LLM to:
|
||||
1. Infer research intent (or use user-provided purpose/content_output/depth)
|
||||
2. Generate targeted queries linked to intent fields
|
||||
3. Optimize provider settings based on queries and intent
|
||||
"""
|
||||
# Get current date context
|
||||
date_context = get_current_date_context()
|
||||
now = datetime.now()
|
||||
current_year = now.year
|
||||
next_year = current_year + 1
|
||||
current_month_year = now.strftime("%B %Y")
|
||||
|
||||
# Build persona context
|
||||
persona_context = build_persona_context(research_persona, industry, target_audience)
|
||||
|
||||
# Build competitor context
|
||||
competitor_context = build_competitor_context(competitor_data)
|
||||
|
||||
prompt = f'''You are an expert AI research strategist. Analyze the user's research request and provide a complete research plan including intent understanding, search queries, and optimal API settings.
|
||||
|
||||
## CURRENT DATE/TIME CONTEXT
|
||||
{date_context}
|
||||
|
||||
**NOTE**: When user mentions time-sensitive terms (latest, current, recent, trends, predictions), prioritize {current_year} data.
|
||||
|
||||
## USER INPUT
|
||||
"{user_input}"
|
||||
{f"KEYWORDS: {', '.join(keywords)}" if keywords else ""}
|
||||
|
||||
## USER CONTEXT
|
||||
{persona_context}
|
||||
{competitor_context}
|
||||
{f'''
|
||||
## USER-PROVIDED INTENT SETTINGS
|
||||
The user has explicitly selected these settings - USE THESE VALUES, do NOT infer different ones:
|
||||
- purpose: {user_provided_purpose} (USE THIS EXACT VALUE)
|
||||
- content_output: {user_provided_content_output} (USE THIS EXACT VALUE)
|
||||
- depth: {user_provided_depth} (USE THIS EXACT VALUE)
|
||||
|
||||
IMPORTANT: Since the user has explicitly selected these, you should:
|
||||
1. Use the provided purpose, content_output, and depth values exactly as given
|
||||
2. Still infer secondary_questions, focus_areas, also_answering, and expected_deliverables based on the user input and these provided settings
|
||||
3. Generate queries that align with the user's explicit selections
|
||||
''' if (user_provided_purpose or user_provided_content_output or user_provided_depth) else ''}
|
||||
|
||||
## YOUR TASK: Provide a Complete Research Plan
|
||||
|
||||
### PART 1: INTENT ANALYSIS
|
||||
{f"Use the user-provided settings above. For fields not provided, infer what the user really wants from their research." if (user_provided_purpose or user_provided_content_output or user_provided_depth) else "Understand what the user really wants from their research."}
|
||||
|
||||
**CRITICAL: Use EXACT enum values - do NOT return descriptive strings.**
|
||||
- purpose: Must be one of: "learn", "create_content", "make_decision", "compare", "solve_problem", "find_data", "explore_trends", "validate", "generate_ideas"
|
||||
{f"**USER PROVIDED: {user_provided_purpose} - USE THIS EXACT VALUE**" if user_provided_purpose else "- Infer from user input"}
|
||||
- content_output: Must be one of: "blog", "podcast", "video", "social_post", "newsletter", "presentation", "report", "whitepaper", "email", "general"
|
||||
{f"**USER PROVIDED: {user_provided_content_output} - USE THIS EXACT VALUE**" if user_provided_content_output else "- Infer from user input"}
|
||||
- depth: Must be one of: "overview", "detailed", "expert"
|
||||
{f"**USER PROVIDED: {user_provided_depth} - USE THIS EXACT VALUE**" if user_provided_depth else "- Infer from user input"}
|
||||
- expected_deliverables: Must be an array of exact values: "key_statistics", "expert_quotes", "case_studies", "comparisons", "trends", "best_practices", "step_by_step", "pros_cons", "definitions", "citations", "examples", "predictions"
|
||||
- Infer based on purpose, content_output, and user input
|
||||
|
||||
**CRITICAL: ALWAYS generate focus_areas and also_answering fields:**
|
||||
- focus_areas: Generate 2-5 specific focus areas based on user input (e.g., "academic research", "industry trends", "company analysis", "practical applications", "safety considerations")
|
||||
- also_answering: Generate 2-4 related topics or questions that should also be addressed (e.g., "benefits and drawbacks", "alternatives", "implementation steps", "cost considerations")
|
||||
- These fields are REQUIRED and MUST be populated - do NOT leave them empty
|
||||
- Think about what additional aspects of the topic would be valuable to cover
|
||||
|
||||
### PART 2: SEARCH QUERIES
|
||||
Generate 4-8 targeted, diverse search queries optimized for semantic search.
|
||||
|
||||
**CRITICAL: Generate MULTIPLE DIVERSE queries (minimum 4, maximum 8). Do NOT generate just one query.**
|
||||
|
||||
**QUERY GENERATION RULES:**
|
||||
|
||||
1. **PRIMARY QUERY**: Generate 1 query that directly addresses the primary_question
|
||||
- This should be the highest priority (priority: 5)
|
||||
- Should comprehensively cover the main research goal
|
||||
- Set addresses_primary_question: true
|
||||
|
||||
2. **SECONDARY QUERY MAPPING**: For EACH secondary_question, generate a SEPARATE query that addresses it
|
||||
- Link each query to its corresponding secondary_question in addresses_secondary_questions array
|
||||
- Priority: 4 (high but secondary to primary)
|
||||
- **CRITICAL**: Create SEPARATE queries for each secondary question UNLESS they are extremely similar (same keywords, same search intent)
|
||||
- Only merge if queries would return identical results
|
||||
|
||||
3. **FOCUS AREA QUERIES**: Generate SEPARATE queries for EACH focus_area
|
||||
- **CRITICAL**: If focus_areas exist, generate AT LEAST one query per focus_area
|
||||
- Add each focus area to targets_focus_areas array for its corresponding query
|
||||
- Priority: 3-4 depending on importance
|
||||
- **CRITICAL**: Create SEPARATE queries for each focus_area UNLESS they are extremely similar (same search intent, same keywords)
|
||||
- Each focus area should have its own dedicated query to ensure comprehensive coverage
|
||||
|
||||
4. **ALSO ANSWERING QUERIES**: Generate queries for EACH also_answering topic
|
||||
- **CRITICAL**: Generate at least one query per also_answering topic that is NOT covered by primary/secondary queries
|
||||
- Lower priority (priority: 2-3)
|
||||
- Add each topic to covers_also_answering array for its corresponding query
|
||||
- Only skip if the topic is already fully covered by existing queries
|
||||
|
||||
5. **QUERY DIVERSITY RULES** (IMPORTANT):
|
||||
- **CRITICAL**: Ensure queries are DISTINCT and target DIFFERENT aspects
|
||||
- Vary search terms: use synonyms, related terms, different angles
|
||||
- Vary query structure: some specific, some broader
|
||||
- Vary providers: mix Exa and Tavily when appropriate
|
||||
- Target different content types: academic, news, practical guides, etc.
|
||||
- **DO NOT** create queries that are just slight variations of each other
|
||||
- **DO NOT** merge queries that target different focus areas or also_answering topics
|
||||
|
||||
6. **MINIMUM QUERY REQUIREMENTS**:
|
||||
- **ALWAYS generate at least 4 queries** (even for simple topics)
|
||||
- If you have: 1 primary + 1 secondary + 2 focus areas = generate at least 4 queries
|
||||
- If you have: 1 primary + 3 secondary + 2 focus areas + 2 also_answering = generate 6-8 queries
|
||||
- **If focus_areas or also_answering are empty, generate queries covering different angles/aspects of the primary question**
|
||||
|
||||
7. **QUERY-TO-INTENT LINKING**: For each query, specify:
|
||||
- addresses_primary_question: true/false (only one query should be true)
|
||||
- addresses_secondary_questions: array of secondary question strings (can be empty, or contain one/multiple)
|
||||
- targets_focus_areas: array of focus area strings (should match focus_areas when relevant)
|
||||
- covers_also_answering: array of also_answering topic strings (should match also_answering when relevant)
|
||||
- justification: brief explanation explaining how this query differs from others and what it will find
|
||||
|
||||
**OUTPUT FORMAT FOR QUERIES:**
|
||||
Each query must include these linking fields. Ensure queries are DIVERSE and target different aspects, not just variations of the same search.
|
||||
|
||||
### PART 3: PROVIDER SETTINGS
|
||||
Configure Exa and Tavily API parameters with justifications.
|
||||
|
||||
**Provider settings should be optimized based on:**
|
||||
1. **Primary query characteristics** (most important - this is what will be executed)
|
||||
2. **Secondary questions** (if they require different settings for comprehensive coverage)
|
||||
3. **Focus areas** (if they need specific content types or sources)
|
||||
4. **Also answering topics** (if they need different time ranges or sources)
|
||||
5. **Time sensitivity** from intent (real_time, recent, historical, evergreen)
|
||||
6. **Depth requirements** from intent (overview, detailed, expert)
|
||||
|
||||
**SETTING OPTIMIZATION RULES:**
|
||||
|
||||
1. **Time Sensitivity Based on Intent**:
|
||||
- If time_sensitivity = "real_time" OR any secondary_question/focus_area needs recent data:
|
||||
- Tavily: time_range = "day" or "week", topic = "news"
|
||||
- Exa: startPublishedDate = current year, type = "auto" or "fast"
|
||||
- If time_sensitivity = "historical":
|
||||
- Exa: No date filters, use historical content, type = "deep" or "neural"
|
||||
- Tavily: time_range = "year" or null, topic = "general"
|
||||
- If time_sensitivity = "recent":
|
||||
- Exa: startPublishedDate = current year or last 6 months
|
||||
- Tavily: time_range = "month" or "week"
|
||||
- If time_sensitivity = "evergreen":
|
||||
- Exa: No date filters, type = "deep" for comprehensive coverage
|
||||
- Tavily: time_range = null, topic = "general"
|
||||
|
||||
2. **Content Type Based on Focus Areas**:
|
||||
- If focus_areas include "academic" or "research" or "studies":
|
||||
- Exa: category = "research paper", includeDomains = ["arxiv.org", "nature.com", "pubmed.ncbi.nlm.nih.gov"]
|
||||
- Exa: type = "deep" or "neural" for comprehensive academic coverage
|
||||
- If focus_areas include "companies" or "competitors" or "business":
|
||||
- Exa: category = "company"
|
||||
- Exa: type = "auto" or "deep" for company research
|
||||
- If focus_areas include "news" or "trends" or "current events":
|
||||
- Tavily: topic = "news", search_depth = "advanced"
|
||||
- Exa: category = "news" (if using Exa for news)
|
||||
- If focus_areas include "social" or "twitter" or "social media":
|
||||
- Exa: category = "tweet"
|
||||
- If focus_areas include "github" or "code" or "technical":
|
||||
- Exa: category = "github"
|
||||
|
||||
3. **Depth Based on Intent Depth and Secondary Questions**:
|
||||
- If depth = "expert" OR secondary_questions require detailed analysis:
|
||||
- Exa: type = "deep", context = true, contextMaxCharacters = 15000+, numResults = 20-50
|
||||
- Tavily: search_depth = "advanced", chunks_per_source = 3, max_results = 15-20
|
||||
- If depth = "detailed":
|
||||
- Exa: type = "auto" or "deep", context = true, contextMaxCharacters = 10000+, numResults = 10-20
|
||||
- Tavily: search_depth = "advanced" or "basic", chunks_per_source = 3, max_results = 10-15
|
||||
- If depth = "overview":
|
||||
- Exa: type = "auto" or "fast", numResults = 5-10
|
||||
- Tavily: search_depth = "basic" or "fast", max_results = 5-10
|
||||
|
||||
4. **Query-Specific Settings (Primary Query Focus)**:
|
||||
- If primary query needs comprehensive results (addresses multiple secondary questions or focus areas):
|
||||
- Exa: type = "deep", context = true, contextMaxCharacters = 15000+
|
||||
- Tavily: search_depth = "advanced", chunks_per_source = 3
|
||||
- If primary query needs speed (simple factual answer):
|
||||
- Exa: type = "fast", numResults = 5-10
|
||||
- Tavily: search_depth = "ultra-fast", max_results = 5
|
||||
- If primary query targets specific content type:
|
||||
- Match Exa category or Tavily topic to content type
|
||||
- If primary query is time-sensitive:
|
||||
- Apply time filters based on urgency
|
||||
|
||||
5. **Also Answering Topics Considerations**:
|
||||
- If also_answering topics need different time ranges:
|
||||
- Use broader time_range in Tavily (e.g., "year" instead of "month")
|
||||
- Don't apply strict date filters in Exa
|
||||
- If also_answering topics need different sources:
|
||||
- Consider including additional domains in includeDomains
|
||||
- Use more comprehensive search (type = "deep" in Exa)
|
||||
|
||||
6. **Provider Selection Based on Intent**:
|
||||
- Use EXA when:
|
||||
* Primary query needs semantic understanding
|
||||
* Focus areas include "academic", "research", "companies"
|
||||
* Depth = "expert" or "detailed"
|
||||
* Need comprehensive context (context = true)
|
||||
- Use TAVILY when:
|
||||
* Time sensitivity = "real_time" or "recent"
|
||||
* Focus areas include "news", "trends", "current events"
|
||||
* Need quick AI-generated answers
|
||||
* Primary query is about recent developments
|
||||
|
||||
**NOTE**: Since we're executing only the PRIMARY query initially, optimize settings for the primary query, but ensure settings can accommodate secondary questions and focus areas in the results. The settings should be comprehensive enough to capture information relevant to all intent aspects.
|
||||
|
||||
### PART 4: GOOGLE TRENDS KEYWORDS (if trends in deliverables)
|
||||
If "trends" is in expected_deliverables OR purpose is "explore_trends":
|
||||
- Suggest 1-3 optimized keywords for Google Trends analysis
|
||||
- These may differ from research queries (trends need broader, searchable terms)
|
||||
- Consider: What keywords will show meaningful trends over time?
|
||||
- Consider: What timeframe will show relevant trends? (1 year, 12 months, etc.)
|
||||
- Consider: What geographic region is most relevant for the user?
|
||||
- Explain what insights trends will uncover for content generation:
|
||||
* Search interest trends over time (optimal publication timing)
|
||||
* Regional interest distribution (audience targeting)
|
||||
* Related topics for content expansion
|
||||
* Related queries for FAQ sections
|
||||
* Rising topics for timely content opportunities
|
||||
|
||||
---
|
||||
|
||||
## PROVIDER OPTIONS
|
||||
|
||||
**EXA**: type (auto/fast/deep/neural/keyword), category (company/research paper/news/etc), numResults (1-100), includeDomains, startPublishedDate, highlights, context (required for deep). Best for: academic, companies, deep analysis.
|
||||
|
||||
**TAVILY**: topic (general/news/finance), search_depth (advanced/basic/fast/ultra-fast), time_range, max_results (0-20), chunks_per_source (1-3). Best for: news, real-time, quick facts.
|
||||
|
||||
---
|
||||
|
||||
## OUTPUT FORMAT
|
||||
|
||||
Return JSON with: intent (all fields), queries (with linking fields), enhanced_keywords, research_angles, recommended_provider, provider_justification, exa_config (enabled, type, category, numResults, includeDomains, excludeDomains, startPublishedDate, highlights, context, contextMaxCharacters, and justifications), tavily_config (enabled, topic, search_depth, include_answer, time_range, max_results, chunks_per_source, and justifications), trends_config (if trends enabled).
|
||||
|
||||
**Key Requirements:**
|
||||
- Provide brief justifications (1 sentence) for all config parameters
|
||||
- Reference intent fields (depth, time_sensitivity, focus_areas) in justifications
|
||||
- Include current year ({current_year}) in time-sensitive queries
|
||||
- Use EXA for academic/companies/deep analysis, TAVILY for news/real-time
|
||||
'''
|
||||
|
||||
return prompt
|
||||
@@ -8,24 +8,17 @@ This reduces 2 LLM calls to 1, improves coherence, and provides
|
||||
user-friendly justifications for all settings.
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
Version: 2.0 (Refactored)
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.research_intent_models import (
|
||||
ResearchIntent,
|
||||
ResearchQuery,
|
||||
IntentInferenceResponse,
|
||||
ResearchPurpose,
|
||||
ContentOutput,
|
||||
ExpectedDeliverable,
|
||||
ResearchDepthLevel,
|
||||
InputType,
|
||||
)
|
||||
from models.research_persona_models import ResearchPersona
|
||||
from .unified_prompt_builder import build_unified_prompt
|
||||
from .unified_schema_builder import build_unified_schema
|
||||
from .unified_result_parser import parse_unified_result
|
||||
from .unified_analyzer_utils import create_fallback_response
|
||||
|
||||
|
||||
class UnifiedResearchAnalyzer:
|
||||
@@ -36,6 +29,13 @@ class UnifiedResearchAnalyzer:
|
||||
3. Parameter optimization (Exa/Tavily settings)
|
||||
|
||||
All in a single LLM call with justifications.
|
||||
|
||||
Refactored to use modular components for better maintainability:
|
||||
- unified_prompt_builder: Builds the comprehensive LLM prompt
|
||||
- unified_schema_builder: Defines the JSON schema for structured output
|
||||
- unified_result_parser: Parses LLM response into structured models
|
||||
- unified_analyzer_utils: Utility functions for context and fallback
|
||||
- query_deduplicator: Removes redundant queries (used by parser)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -51,36 +51,56 @@ class UnifiedResearchAnalyzer:
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
user_provided_purpose: Optional[str] = None,
|
||||
user_provided_content_output: Optional[str] = None,
|
||||
user_provided_depth: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform unified analysis of user research request.
|
||||
|
||||
Args:
|
||||
user_input: The user's research input (keywords, question, etc.)
|
||||
keywords: Optional list of keywords
|
||||
research_persona: Optional research persona for personalization
|
||||
competitor_data: Optional competitor analysis data
|
||||
industry: Optional industry context
|
||||
target_audience: Optional target audience context
|
||||
user_id: User ID for subscription checks (required)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- success: bool
|
||||
- intent: ResearchIntent
|
||||
- queries: List[ResearchQuery]
|
||||
- exa_config: Dict with settings and justifications
|
||||
- tavily_config: Dict with settings and justifications
|
||||
- recommended_provider: str
|
||||
- provider_justification: str
|
||||
- trends_config: Dict with Google Trends settings (optional)
|
||||
- enhanced_keywords: List[str]
|
||||
- research_angles: List[str]
|
||||
- analysis_summary: str
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Unified analysis for: {user_input[:100]}...")
|
||||
|
||||
keywords = keywords or []
|
||||
|
||||
# Build the unified prompt
|
||||
prompt = self._build_unified_prompt(
|
||||
# Build the unified prompt using the prompt builder module
|
||||
prompt = build_unified_prompt(
|
||||
user_input=user_input,
|
||||
keywords=keywords,
|
||||
research_persona=research_persona,
|
||||
competitor_data=competitor_data,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
user_provided_purpose=user_provided_purpose,
|
||||
user_provided_content_output=user_provided_content_output,
|
||||
user_provided_depth=user_provided_depth,
|
||||
)
|
||||
|
||||
# Define the comprehensive JSON schema
|
||||
unified_schema = self._build_unified_schema()
|
||||
# Define the comprehensive JSON schema using the schema builder module
|
||||
unified_schema = build_unified_schema()
|
||||
|
||||
# Call LLM (single call for everything)
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
@@ -93,467 +113,11 @@ class UnifiedResearchAnalyzer:
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
logger.error(f"Unified analysis failed: {result.get('error')}")
|
||||
return self._create_fallback_response(user_input, keywords)
|
||||
return create_fallback_response(user_input, keywords)
|
||||
|
||||
# Parse the unified result
|
||||
return self._parse_unified_result(result, user_input)
|
||||
# Parse the unified result using the result parser module
|
||||
return parse_unified_result(result, user_input)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in unified analysis: {e}")
|
||||
return self._create_fallback_response(user_input, keywords or [])
|
||||
|
||||
def _build_unified_prompt(
|
||||
self,
|
||||
user_input: str,
|
||||
keywords: List[str],
|
||||
research_persona: Optional[ResearchPersona] = None,
|
||||
competitor_data: Optional[List[Dict]] = None,
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build the unified prompt for intent + queries + parameters."""
|
||||
|
||||
# Build persona context
|
||||
persona_context = self._build_persona_context(research_persona, industry, target_audience)
|
||||
|
||||
# Build competitor context
|
||||
competitor_context = self._build_competitor_context(competitor_data)
|
||||
|
||||
prompt = f'''You are an expert AI research strategist. Analyze the user's research request and provide a complete research plan including intent understanding, search queries, and optimal API settings.
|
||||
|
||||
## USER INPUT
|
||||
"{user_input}"
|
||||
{f"KEYWORDS: {', '.join(keywords)}" if keywords else ""}
|
||||
|
||||
## USER CONTEXT
|
||||
{persona_context}
|
||||
{competitor_context}
|
||||
|
||||
## YOUR TASK: Provide a Complete Research Plan
|
||||
|
||||
### PART 1: INTENT ANALYSIS
|
||||
Understand what the user really wants from their research.
|
||||
|
||||
### PART 2: SEARCH QUERIES
|
||||
Generate 4-8 targeted search queries optimized for semantic search.
|
||||
|
||||
### PART 3: PROVIDER SETTINGS
|
||||
Configure Exa and Tavily API parameters with justifications.
|
||||
|
||||
### PART 4: GOOGLE TRENDS KEYWORDS (if trends in deliverables)
|
||||
If "trends" is in expected_deliverables OR purpose is "explore_trends":
|
||||
- Suggest 1-3 optimized keywords for Google Trends analysis
|
||||
- These may differ from research queries (trends need broader, searchable terms)
|
||||
- Consider: What keywords will show meaningful trends over time?
|
||||
- Consider: What timeframe will show relevant trends? (1 year, 12 months, etc.)
|
||||
- Consider: What geographic region is most relevant for the user?
|
||||
- Explain what insights trends will uncover for content generation:
|
||||
* Search interest trends over time (optimal publication timing)
|
||||
* Regional interest distribution (audience targeting)
|
||||
* Related topics for content expansion
|
||||
* Related queries for FAQ sections
|
||||
* Rising topics for timely content opportunities
|
||||
|
||||
---
|
||||
|
||||
## AVAILABLE PROVIDER OPTIONS
|
||||
|
||||
### EXA API OPTIONS (Semantic Search Engine)
|
||||
| Parameter | Options | Description |
|
||||
|-----------|---------|-------------|
|
||||
| type | "auto", "neural", "fast", "deep" | "neural" = semantic understanding, "deep" = comprehensive with query expansion |
|
||||
| category | "company", "research paper", "news", "github", "tweet", "personal site", "pdf", "financial report", "people" | Focus on specific content types |
|
||||
| numResults | 5-25 | Number of results (10 recommended) |
|
||||
| includeDomains | string[] | Domains to include (e.g., ["arxiv.org", "nature.com"]) |
|
||||
| excludeDomains | string[] | Domains to exclude |
|
||||
| startPublishedDate | ISO date | Filter by publish date (e.g., "2024-01-01T00:00:00.000Z") |
|
||||
| text | boolean | Include full text content |
|
||||
| highlights | boolean | Extract key highlights |
|
||||
| context | boolean | Return as single context string for RAG |
|
||||
|
||||
**WHEN TO USE EXA:**
|
||||
- Semantic understanding needed (finding similar content)
|
||||
- Academic/research papers
|
||||
- Company/competitor research
|
||||
- Deep, comprehensive results
|
||||
- Historical content
|
||||
|
||||
### TAVILY API OPTIONS (AI-Powered Search)
|
||||
| Parameter | Options | Description |
|
||||
|-----------|---------|-------------|
|
||||
| topic | "general", "news", "finance" | Search topic category |
|
||||
| search_depth | "basic", "advanced" | "advanced" = multiple semantic snippets per URL |
|
||||
| include_answer | false, true, "basic", "advanced" | AI-generated answer from results |
|
||||
| include_raw_content | false, true, "markdown", "text" | Raw page content format |
|
||||
| time_range | "day", "week", "month", "year" | Filter by recency |
|
||||
| max_results | 5-20 | Number of results |
|
||||
| include_domains | string[] | Domains to include |
|
||||
| exclude_domains | string[] | Domains to exclude |
|
||||
|
||||
**WHEN TO USE TAVILY:**
|
||||
- Real-time/current events
|
||||
- News and trending topics
|
||||
- Quick facts with AI answers
|
||||
- Financial data
|
||||
- Recent time-sensitive content
|
||||
|
||||
---
|
||||
|
||||
## OUTPUT FORMAT
|
||||
|
||||
Return a JSON object with this exact structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"intent": {{
|
||||
"input_type": "keywords|question|goal|mixed",
|
||||
"primary_question": "The main question to answer",
|
||||
"secondary_questions": ["question 1", "question 2"],
|
||||
"purpose": "learn|create_content|make_decision|compare|solve_problem|find_data|explore_trends|validate|generate_ideas",
|
||||
"content_output": "blog|podcast|video|social_post|newsletter|presentation|report|whitepaper|email|general",
|
||||
"expected_deliverables": ["key_statistics", "expert_quotes", "case_studies", "trends", "best_practices"],
|
||||
"depth": "overview|detailed|expert",
|
||||
"focus_areas": ["area1", "area2"],
|
||||
"perspective": "target perspective or null",
|
||||
"time_sensitivity": "real_time|recent|historical|evergreen",
|
||||
"confidence": 0.85,
|
||||
"confidence_reason": "Why this confidence level",
|
||||
"great_example": "Example of better input if confidence < 0.8",
|
||||
"needs_clarification": false,
|
||||
"clarifying_questions": [],
|
||||
"analysis_summary": "Brief summary of research plan"
|
||||
}},
|
||||
"queries": [
|
||||
{{
|
||||
"query": "Optimized search query string",
|
||||
"purpose": "key_statistics|expert_quotes|case_studies|trends|etc",
|
||||
"provider": "exa|tavily",
|
||||
"priority": 5,
|
||||
"expected_results": "What we expect to find",
|
||||
"justification": "Why this query and provider"
|
||||
}}
|
||||
],
|
||||
"enhanced_keywords": ["expanded", "related", "keywords"],
|
||||
"research_angles": ["Angle 1: ...", "Angle 2: ..."],
|
||||
"recommended_provider": "exa|tavily",
|
||||
"provider_justification": "Why this provider is best for this research",
|
||||
"exa_config": {{
|
||||
"enabled": true,
|
||||
"type": "auto|neural|fast|deep",
|
||||
"type_justification": "Why this search type",
|
||||
"category": "news|research paper|company|etc or null",
|
||||
"category_justification": "Why this category or null",
|
||||
"numResults": 10,
|
||||
"numResults_justification": "Why this number",
|
||||
"includeDomains": [],
|
||||
"includeDomains_justification": "Why these domains or empty",
|
||||
"startPublishedDate": "2024-01-01T00:00:00.000Z or null",
|
||||
"date_justification": "Why this date filter or null",
|
||||
"highlights": true,
|
||||
"highlights_justification": "Why enable/disable highlights",
|
||||
"context": true,
|
||||
"context_justification": "Why enable/disable context string"
|
||||
}},
|
||||
"tavily_config": {{
|
||||
"enabled": true,
|
||||
"topic": "general|news|finance",
|
||||
"topic_justification": "Why this topic",
|
||||
"search_depth": "basic|advanced",
|
||||
"search_depth_justification": "Why this depth",
|
||||
"include_answer": "true|false|basic|advanced",
|
||||
"include_answer_justification": "Why this answer mode",
|
||||
"time_range": "day|week|month|year|null",
|
||||
"time_range_justification": "Why this time range or null",
|
||||
"max_results": 10,
|
||||
"max_results_justification": "Why this number",
|
||||
"include_raw_content": "false|true|markdown|text",
|
||||
"include_raw_content_justification": "Why this content mode"
|
||||
}},
|
||||
"trends_config": {{
|
||||
"enabled": true|false,
|
||||
"keywords": ["keyword1", "keyword2"],
|
||||
"keywords_justification": "Why these keywords for trends analysis",
|
||||
"timeframe": "today 1-y|today 12-m|all",
|
||||
"timeframe_justification": "Why this timeframe",
|
||||
"geo": "US|GB|IN|etc",
|
||||
"geo_justification": "Why this geographic region",
|
||||
"expected_insights": [
|
||||
"Search interest trends over the past year",
|
||||
"Regional interest distribution",
|
||||
"Related topics for content expansion",
|
||||
"Related queries for FAQ sections",
|
||||
"Optimal publication timing based on interest peaks"
|
||||
]
|
||||
}}
|
||||
}}
|
||||
```
|
||||
|
||||
## DECISION RULES
|
||||
|
||||
1. **Provider Selection:**
|
||||
- Use EXA for: academic research, competitor analysis, deep understanding, finding similar content
|
||||
- Use TAVILY for: news, current events, quick facts, financial data, real-time info
|
||||
|
||||
2. **Query Optimization:**
|
||||
- Include relevant keywords for semantic matching
|
||||
- Add context words based on deliverables (e.g., "statistics 2024" for key_statistics)
|
||||
- Match query style to provider (natural language for Exa, keyword-rich for Tavily)
|
||||
|
||||
3. **Parameter Selection:**
|
||||
- ALWAYS provide justification for each parameter choice
|
||||
- Consider time sensitivity when setting date filters
|
||||
- Match category/topic to content type
|
||||
- Use "advanced" depth when quality matters more than speed
|
||||
|
||||
4. **Google Trends Keywords (if trends enabled):**
|
||||
- Suggest 1-3 keywords optimized for trends analysis
|
||||
- Keywords should be broader than research queries (e.g., "AI marketing" vs "AI marketing tools for small businesses")
|
||||
- Consider what will show meaningful search interest trends
|
||||
- Choose timeframe based on content type (12 months for blogs, 1 year for comprehensive)
|
||||
- Select geo based on user's target audience or industry
|
||||
- List specific insights trends will uncover
|
||||
|
||||
5. **Justifications:**
|
||||
- Keep justifications concise (1 sentence)
|
||||
- Explain the "why" not the "what"
|
||||
- Reference user's intent when relevant
|
||||
'''
|
||||
|
||||
return prompt
|
||||
|
||||
def _build_unified_schema(self) -> Dict[str, Any]:
|
||||
"""Build the JSON schema for unified response."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"intent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_type": {"type": "string", "enum": ["keywords", "question", "goal", "mixed"]},
|
||||
"primary_question": {"type": "string"},
|
||||
"secondary_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"purpose": {"type": "string"},
|
||||
"content_output": {"type": "string"},
|
||||
"expected_deliverables": {"type": "array", "items": {"type": "string"}},
|
||||
"depth": {"type": "string", "enum": ["overview", "detailed", "expert"]},
|
||||
"focus_areas": {"type": "array", "items": {"type": "string"}},
|
||||
"perspective": {"type": "string"},
|
||||
"time_sensitivity": {"type": "string"},
|
||||
"confidence": {"type": "number"},
|
||||
"confidence_reason": {"type": "string"},
|
||||
"great_example": {"type": "string"},
|
||||
"needs_clarification": {"type": "boolean"},
|
||||
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"analysis_summary": {"type": "string"}
|
||||
},
|
||||
"required": ["primary_question", "purpose", "expected_deliverables", "confidence"]
|
||||
},
|
||||
"queries": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"purpose": {"type": "string"},
|
||||
"provider": {"type": "string"},
|
||||
"priority": {"type": "integer"},
|
||||
"expected_results": {"type": "string"},
|
||||
"justification": {"type": "string"}
|
||||
},
|
||||
"required": ["query", "purpose", "provider", "priority"]
|
||||
}
|
||||
},
|
||||
"enhanced_keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"research_angles": {"type": "array", "items": {"type": "string"}},
|
||||
"recommended_provider": {"type": "string"},
|
||||
"provider_justification": {"type": "string"},
|
||||
"exa_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"type": {"type": "string"},
|
||||
"type_justification": {"type": "string"},
|
||||
"category": {"type": "string"},
|
||||
"category_justification": {"type": "string"},
|
||||
"numResults": {"type": "integer"},
|
||||
"numResults_justification": {"type": "string"},
|
||||
"includeDomains": {"type": "array", "items": {"type": "string"}},
|
||||
"includeDomains_justification": {"type": "string"},
|
||||
"startPublishedDate": {"type": "string"},
|
||||
"date_justification": {"type": "string"},
|
||||
"highlights": {"type": "boolean"},
|
||||
"highlights_justification": {"type": "string"},
|
||||
"context": {"type": "boolean"},
|
||||
"context_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"tavily_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"topic": {"type": "string"},
|
||||
"topic_justification": {"type": "string"},
|
||||
"search_depth": {"type": "string"},
|
||||
"search_depth_justification": {"type": "string"},
|
||||
"include_answer": {"type": "string"},
|
||||
"include_answer_justification": {"type": "string"},
|
||||
"time_range": {"type": "string"},
|
||||
"time_range_justification": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
"max_results_justification": {"type": "string"},
|
||||
"include_raw_content": {"type": "string"},
|
||||
"include_raw_content_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"trends_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"keywords_justification": {"type": "string"},
|
||||
"timeframe": {"type": "string"},
|
||||
"timeframe_justification": {"type": "string"},
|
||||
"geo": {"type": "string"},
|
||||
"geo_justification": {"type": "string"},
|
||||
"expected_insights": {"type": "array", "items": {"type": "string"}}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["intent", "queries", "recommended_provider", "exa_config", "tavily_config"]
|
||||
}
|
||||
|
||||
def _build_persona_context(
|
||||
self,
|
||||
research_persona: Optional[ResearchPersona],
|
||||
industry: Optional[str],
|
||||
target_audience: Optional[str],
|
||||
) -> str:
|
||||
"""Build persona context section."""
|
||||
parts = []
|
||||
|
||||
if research_persona:
|
||||
if research_persona.default_industry:
|
||||
parts.append(f"Industry: {research_persona.default_industry}")
|
||||
if research_persona.default_target_audience:
|
||||
parts.append(f"Target Audience: {research_persona.default_target_audience}")
|
||||
if research_persona.research_angles:
|
||||
parts.append(f"Preferred Research Angles: {', '.join(research_persona.research_angles[:3])}")
|
||||
if research_persona.suggested_keywords:
|
||||
parts.append(f"Relevant Keywords: {', '.join(research_persona.suggested_keywords[:5])}")
|
||||
else:
|
||||
if industry:
|
||||
parts.append(f"Industry: {industry}")
|
||||
if target_audience:
|
||||
parts.append(f"Target Audience: {target_audience}")
|
||||
|
||||
if not parts:
|
||||
return "No specific user context available. Use general best practices."
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_competitor_context(self, competitor_data: Optional[List[Dict]]) -> str:
|
||||
"""Build competitor context section."""
|
||||
if not competitor_data:
|
||||
return ""
|
||||
|
||||
competitor_names = [c.get("name", c.get("url", "")) for c in competitor_data[:5]]
|
||||
if competitor_names:
|
||||
return f"\nKnown Competitors: {', '.join(competitor_names)}"
|
||||
return ""
|
||||
|
||||
def _parse_unified_result(self, result: Dict[str, Any], user_input: str) -> Dict[str, Any]:
|
||||
"""Parse the unified LLM result into structured response."""
|
||||
|
||||
intent_data = result.get("intent", {})
|
||||
|
||||
# Build ResearchIntent
|
||||
intent = ResearchIntent(
|
||||
primary_question=intent_data.get("primary_question", user_input),
|
||||
secondary_questions=intent_data.get("secondary_questions", []),
|
||||
purpose=intent_data.get("purpose", "learn"),
|
||||
content_output=intent_data.get("content_output", "general"),
|
||||
expected_deliverables=intent_data.get("expected_deliverables", ["key_statistics"]),
|
||||
depth=intent_data.get("depth", "detailed"),
|
||||
focus_areas=intent_data.get("focus_areas", []),
|
||||
perspective=intent_data.get("perspective"),
|
||||
time_sensitivity=intent_data.get("time_sensitivity"),
|
||||
input_type=intent_data.get("input_type", "keywords"),
|
||||
original_input=user_input,
|
||||
confidence=float(intent_data.get("confidence", 0.7)),
|
||||
confidence_reason=intent_data.get("confidence_reason"),
|
||||
great_example=intent_data.get("great_example"),
|
||||
needs_clarification=intent_data.get("needs_clarification", False),
|
||||
clarifying_questions=intent_data.get("clarifying_questions", []),
|
||||
)
|
||||
|
||||
# Build queries
|
||||
queries = []
|
||||
for q in result.get("queries", []):
|
||||
try:
|
||||
queries.append(ResearchQuery(
|
||||
query=q.get("query", ""),
|
||||
purpose=q.get("purpose", "key_statistics"),
|
||||
provider=q.get("provider", "exa"),
|
||||
priority=int(q.get("priority", 3)),
|
||||
expected_results=q.get("expected_results", ""),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse query: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": intent,
|
||||
"queries": queries,
|
||||
"enhanced_keywords": result.get("enhanced_keywords", []),
|
||||
"research_angles": result.get("research_angles", []),
|
||||
"recommended_provider": result.get("recommended_provider", "exa"),
|
||||
"provider_justification": result.get("provider_justification", ""),
|
||||
"exa_config": result.get("exa_config", {}),
|
||||
"tavily_config": result.get("tavily_config", {}),
|
||||
"trends_config": result.get("trends_config", {}), # NEW: Google Trends configuration
|
||||
"analysis_summary": intent_data.get("analysis_summary", ""),
|
||||
}
|
||||
|
||||
def _create_fallback_response(self, user_input: str, keywords: List[str]) -> Dict[str, Any]:
|
||||
"""Create fallback response when analysis fails."""
|
||||
return {
|
||||
"success": False,
|
||||
"intent": ResearchIntent(
|
||||
primary_question=f"What are the key insights about: {user_input}?",
|
||||
purpose="learn",
|
||||
content_output="general",
|
||||
expected_deliverables=["key_statistics", "best_practices"],
|
||||
depth="detailed",
|
||||
original_input=user_input,
|
||||
confidence=0.5,
|
||||
),
|
||||
"queries": [
|
||||
ResearchQuery(
|
||||
query=user_input,
|
||||
purpose="key_statistics",
|
||||
provider="exa",
|
||||
priority=5,
|
||||
expected_results="General research results",
|
||||
)
|
||||
],
|
||||
"enhanced_keywords": keywords,
|
||||
"research_angles": [],
|
||||
"recommended_provider": "exa",
|
||||
"provider_justification": "Default fallback to Exa for semantic search",
|
||||
"exa_config": {
|
||||
"enabled": True,
|
||||
"type": "auto",
|
||||
"type_justification": "Auto mode for balanced results",
|
||||
"numResults": 10,
|
||||
"highlights": True,
|
||||
},
|
||||
"tavily_config": {
|
||||
"enabled": True,
|
||||
"topic": "general",
|
||||
"search_depth": "advanced",
|
||||
"include_answer": True,
|
||||
},
|
||||
"trends_config": {
|
||||
"enabled": False, # Disabled in fallback
|
||||
},
|
||||
}
|
||||
return create_fallback_response(user_input, keywords or [])
|
||||
|
||||
209
backend/services/research/intent/unified_result_parser.py
Normal file
209
backend/services/research/intent/unified_result_parser.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Result parsing logic for unified research analyzer.
|
||||
|
||||
Parses LLM response into structured ResearchIntent, ResearchQuery,
|
||||
and configuration dictionaries.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from models.research_intent_models import (
|
||||
ResearchIntent, ResearchQuery,
|
||||
ResearchPurpose, ContentOutput, ExpectedDeliverable,
|
||||
ResearchDepthLevel, InputType
|
||||
)
|
||||
from .query_deduplicator import deduplicate_queries
|
||||
|
||||
|
||||
def _normalize_purpose(value: str) -> str:
|
||||
"""Normalize purpose value to enum."""
|
||||
if not value or not isinstance(value, str):
|
||||
return "learn"
|
||||
value_lower = value.lower()
|
||||
# Check for exact match
|
||||
for purpose in ResearchPurpose:
|
||||
if value_lower == purpose.value or value_lower == purpose.name.lower():
|
||||
return purpose.value
|
||||
# Check for keywords in description
|
||||
if "content" in value_lower or "write" in value_lower or "create" in value_lower or "blog" in value_lower:
|
||||
return "create_content"
|
||||
elif "compare" in value_lower or "comparison" in value_lower:
|
||||
return "compare"
|
||||
elif "decision" in value_lower or "choose" in value_lower:
|
||||
return "make_decision"
|
||||
elif "problem" in value_lower or "solve" in value_lower:
|
||||
return "solve_problem"
|
||||
elif "data" in value_lower or "statistic" in value_lower or "fact" in value_lower:
|
||||
return "find_data"
|
||||
elif "trend" in value_lower:
|
||||
return "explore_trends"
|
||||
elif "validat" in value_lower or "verify" in value_lower:
|
||||
return "validate"
|
||||
elif "idea" in value_lower or "brainstorm" in value_lower:
|
||||
return "generate_ideas"
|
||||
return "learn"
|
||||
|
||||
|
||||
def _normalize_content_output(value: str) -> str:
|
||||
"""Normalize content_output value to enum."""
|
||||
if not value or not isinstance(value, str):
|
||||
return "general"
|
||||
value_lower = value.lower()
|
||||
# Check for exact match
|
||||
for output in ContentOutput:
|
||||
if value_lower == output.value or value_lower == output.name.lower():
|
||||
return output.value
|
||||
# Check for keywords
|
||||
if "blog" in value_lower or "article" in value_lower:
|
||||
return "blog"
|
||||
elif "podcast" in value_lower:
|
||||
return "podcast"
|
||||
elif "video" in value_lower:
|
||||
return "video"
|
||||
elif "social" in value_lower or "post" in value_lower:
|
||||
return "social_post"
|
||||
elif "newsletter" in value_lower:
|
||||
return "newsletter"
|
||||
elif "presentation" in value_lower or "slide" in value_lower:
|
||||
return "presentation"
|
||||
elif "report" in value_lower:
|
||||
return "report"
|
||||
elif "whitepaper" in value_lower or "white paper" in value_lower:
|
||||
return "whitepaper"
|
||||
elif "email" in value_lower:
|
||||
return "email"
|
||||
return "general"
|
||||
|
||||
|
||||
def _normalize_deliverable(value: str) -> str:
|
||||
"""Normalize deliverable value to enum."""
|
||||
if not value or not isinstance(value, str):
|
||||
return "key_statistics"
|
||||
value_lower = value.lower().strip()
|
||||
# Check for exact match first
|
||||
for deliverable in ExpectedDeliverable:
|
||||
if value_lower == deliverable.value or value_lower == deliverable.name.lower():
|
||||
return deliverable.value
|
||||
# Check for keywords (more aggressive matching)
|
||||
if "statistic" in value_lower or "data" in value_lower or "number" in value_lower or "metric" in value_lower or "report" in value_lower:
|
||||
return "key_statistics"
|
||||
elif "quote" in value_lower or "expert" in value_lower:
|
||||
return "expert_quotes"
|
||||
elif "case" in value_lower or "study" in value_lower:
|
||||
return "case_studies"
|
||||
elif "compar" in value_lower or "compare" in value_lower or "landscape" in value_lower or "matrix" in value_lower:
|
||||
return "comparisons"
|
||||
elif "trend" in value_lower or "keyword" in value_lower or "seo" in value_lower:
|
||||
return "trends"
|
||||
elif "practice" in value_lower or "best" in value_lower or "guideline" in value_lower or "recommendation" in value_lower or "calendar" in value_lower:
|
||||
return "best_practices"
|
||||
elif "step" in value_lower or "how" in value_lower or "process" in value_lower or "guide" in value_lower or "outline" in value_lower or "heading" in value_lower:
|
||||
return "step_by_step"
|
||||
elif ("pro" in value_lower and "con" in value_lower) or "advantage" in value_lower or "disadvantage" in value_lower:
|
||||
return "pros_cons"
|
||||
elif "defin" in value_lower or "explain" in value_lower:
|
||||
return "definitions"
|
||||
elif "citation" in value_lower or "source" in value_lower or "reference" in value_lower:
|
||||
return "citations"
|
||||
elif "example" in value_lower or "sample" in value_lower:
|
||||
return "examples"
|
||||
elif "prediction" in value_lower or "future" in value_lower or "outlook" in value_lower:
|
||||
return "predictions"
|
||||
# Default fallback
|
||||
return "key_statistics"
|
||||
|
||||
|
||||
def parse_unified_result(result: Dict[str, Any], user_input: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse the unified LLM result into structured response.
|
||||
|
||||
Args:
|
||||
result: Raw LLM response dictionary
|
||||
user_input: Original user input for fallback values
|
||||
|
||||
Returns:
|
||||
Structured response with intent, queries, configs, etc.
|
||||
"""
|
||||
intent_data = result.get("intent", {})
|
||||
|
||||
# Normalize enum values
|
||||
purpose_value = _normalize_purpose(intent_data.get("purpose", "learn"))
|
||||
content_output_value = _normalize_content_output(intent_data.get("content_output", "general"))
|
||||
|
||||
# Normalize deliverables list
|
||||
deliverables_raw = intent_data.get("expected_deliverables", ["key_statistics"])
|
||||
if not isinstance(deliverables_raw, list):
|
||||
deliverables_raw = [deliverables_raw] if deliverables_raw else ["key_statistics"]
|
||||
normalized_deliverables = [_normalize_deliverable(d) for d in deliverables_raw if d]
|
||||
if not normalized_deliverables:
|
||||
normalized_deliverables = ["key_statistics"]
|
||||
|
||||
# Build ResearchIntent
|
||||
try:
|
||||
intent = ResearchIntent(
|
||||
primary_question=intent_data.get("primary_question", user_input),
|
||||
secondary_questions=intent_data.get("secondary_questions", []),
|
||||
purpose=purpose_value,
|
||||
content_output=content_output_value,
|
||||
expected_deliverables=normalized_deliverables,
|
||||
depth=intent_data.get("depth", "detailed"),
|
||||
focus_areas=intent_data.get("focus_areas", []),
|
||||
also_answering=intent_data.get("also_answering", []),
|
||||
perspective=intent_data.get("perspective"),
|
||||
time_sensitivity=intent_data.get("time_sensitivity"),
|
||||
input_type=intent_data.get("input_type", "keywords"),
|
||||
original_input=user_input,
|
||||
confidence=float(intent_data.get("confidence", 0.7)),
|
||||
confidence_reason=intent_data.get("confidence_reason"),
|
||||
great_example=intent_data.get("great_example"),
|
||||
needs_clarification=intent_data.get("needs_clarification", False),
|
||||
clarifying_questions=intent_data.get("clarifying_questions", []),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse intent: {e}, intent_data: {intent_data}")
|
||||
# Return fallback intent
|
||||
from .unified_analyzer_utils import create_fallback_response
|
||||
return create_fallback_response(user_input, [])
|
||||
|
||||
# Build queries
|
||||
queries = []
|
||||
for q in result.get("queries", []):
|
||||
try:
|
||||
# Normalize query purpose
|
||||
query_purpose = _normalize_deliverable(q.get("purpose", "key_statistics"))
|
||||
queries.append(ResearchQuery(
|
||||
query=q.get("query", ""),
|
||||
purpose=query_purpose,
|
||||
provider=q.get("provider", "exa"),
|
||||
priority=int(q.get("priority", 3)),
|
||||
expected_results=q.get("expected_results", ""),
|
||||
addresses_primary_question=q.get("addresses_primary_question", False),
|
||||
addresses_secondary_questions=q.get("addresses_secondary_questions", []),
|
||||
targets_focus_areas=q.get("targets_focus_areas", []),
|
||||
covers_also_answering=q.get("covers_also_answering", []),
|
||||
justification=q.get("justification"),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse query: {e}, query: {q}")
|
||||
|
||||
# Deduplicate queries to avoid redundant API calls
|
||||
queries = deduplicate_queries(queries, intent)
|
||||
|
||||
# Log warning if no queries after parsing
|
||||
if not queries:
|
||||
logger.warning("No valid queries parsed from LLM response")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": intent,
|
||||
"queries": queries,
|
||||
"enhanced_keywords": result.get("enhanced_keywords", []),
|
||||
"research_angles": result.get("research_angles", []),
|
||||
"recommended_provider": result.get("recommended_provider", "exa"),
|
||||
"provider_justification": result.get("provider_justification", ""),
|
||||
"exa_config": result.get("exa_config", {}),
|
||||
"tavily_config": result.get("tavily_config", {}),
|
||||
"trends_config": result.get("trends_config", {}), # Google Trends configuration
|
||||
"analysis_summary": intent_data.get("analysis_summary", ""),
|
||||
}
|
||||
140
backend/services/research/intent/unified_schema_builder.py
Normal file
140
backend/services/research/intent/unified_schema_builder.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
JSON schema builder for unified research analyzer.
|
||||
|
||||
Defines the structured JSON schema that the LLM must return
|
||||
for intent analysis, query generation, and parameter optimization.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def build_unified_schema() -> Dict[str, Any]:
|
||||
"""
|
||||
Build the JSON schema for unified response.
|
||||
|
||||
This schema defines the structure expected from the LLM
|
||||
for intent + queries + provider settings.
|
||||
"""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"intent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_type": {"type": "string", "enum": ["keywords", "question", "goal", "mixed"]},
|
||||
"primary_question": {"type": "string"},
|
||||
"secondary_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"purpose": {"type": "string"},
|
||||
"content_output": {"type": "string"},
|
||||
"expected_deliverables": {"type": "array", "items": {"type": "string"}},
|
||||
"depth": {"type": "string", "enum": ["overview", "detailed", "expert"]},
|
||||
"focus_areas": {"type": "array", "items": {"type": "string"}},
|
||||
"also_answering": {"type": "array", "items": {"type": "string"}},
|
||||
"perspective": {"type": "string"},
|
||||
"time_sensitivity": {"type": "string"},
|
||||
"confidence": {"type": "number"},
|
||||
"confidence_reason": {"type": "string"},
|
||||
"great_example": {"type": "string"},
|
||||
"needs_clarification": {"type": "boolean"},
|
||||
"clarifying_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"analysis_summary": {"type": "string"}
|
||||
},
|
||||
"required": ["primary_question", "purpose", "expected_deliverables", "confidence"]
|
||||
},
|
||||
"queries": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"purpose": {"type": "string"},
|
||||
"provider": {"type": "string"},
|
||||
"priority": {"type": "integer"},
|
||||
"expected_results": {"type": "string"},
|
||||
"justification": {"type": "string"},
|
||||
"addresses_primary_question": {"type": "boolean"},
|
||||
"addresses_secondary_questions": {"type": "array", "items": {"type": "string"}},
|
||||
"targets_focus_areas": {"type": "array", "items": {"type": "string"}},
|
||||
"covers_also_answering": {"type": "array", "items": {"type": "string"}}
|
||||
},
|
||||
"required": ["query", "purpose", "provider", "priority", "addresses_primary_question"]
|
||||
}
|
||||
},
|
||||
"enhanced_keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"research_angles": {"type": "array", "items": {"type": "string"}},
|
||||
"recommended_provider": {"type": "string"},
|
||||
"provider_justification": {"type": "string"},
|
||||
"exa_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"type": {"type": "string"},
|
||||
"type_justification": {"type": "string"},
|
||||
"category": {"type": "string"},
|
||||
"category_justification": {"type": "string"},
|
||||
"numResults": {"type": "integer"},
|
||||
"numResults_justification": {"type": "string"},
|
||||
"includeDomains": {"type": "array", "items": {"type": "string"}},
|
||||
"includeDomains_justification": {"type": "string"},
|
||||
"startPublishedDate": {"type": "string"},
|
||||
"date_justification": {"type": "string"},
|
||||
"highlights": {"type": "boolean"},
|
||||
"highlights_justification": {"type": "string"},
|
||||
"context": {"type": "boolean"},
|
||||
"context_justification": {"type": "string"},
|
||||
"additionalQueries": {"type": "array", "items": {"type": "string"}},
|
||||
"additionalQueries_justification": {"type": "string"},
|
||||
"livecrawl": {"type": "string"},
|
||||
"livecrawl_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"tavily_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"topic": {"type": "string"},
|
||||
"topic_justification": {"type": "string"},
|
||||
"search_depth": {"type": "string"},
|
||||
"search_depth_justification": {"type": "string"},
|
||||
"include_answer": {"oneOf": [{"type": "string"}, {"type": "boolean"}]},
|
||||
"include_answer_justification": {"type": "string"},
|
||||
"time_range": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"time_range_justification": {"type": "string"},
|
||||
"start_date": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"start_date_justification": {"type": "string"},
|
||||
"end_date": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"end_date_justification": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
"max_results_justification": {"type": "string"},
|
||||
"chunks_per_source": {"type": "integer"},
|
||||
"chunks_per_source_justification": {"type": "string"},
|
||||
"include_raw_content": {"oneOf": [{"type": "string"}, {"type": "boolean"}]},
|
||||
"include_raw_content_justification": {"type": "string"},
|
||||
"country": {"oneOf": [{"type": "string"}, {"type": "null"}]},
|
||||
"country_justification": {"type": "string"},
|
||||
"include_images": {"type": "boolean"},
|
||||
"include_images_justification": {"type": "string"},
|
||||
"include_image_descriptions": {"type": "boolean"},
|
||||
"include_image_descriptions_justification": {"type": "string"},
|
||||
"include_favicon": {"type": "boolean"},
|
||||
"include_favicon_justification": {"type": "string"},
|
||||
"auto_parameters": {"type": "boolean"},
|
||||
"auto_parameters_justification": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"trends_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"keywords_justification": {"type": "string"},
|
||||
"timeframe": {"type": "string"},
|
||||
"timeframe_justification": {"type": "string"},
|
||||
"geo": {"type": "string"},
|
||||
"geo_justification": {"type": "string"},
|
||||
"expected_insights": {"type": "array", "items": {"type": "string"}}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["intent", "queries", "recommended_provider", "exa_config", "tavily_config"]
|
||||
}
|
||||
@@ -92,21 +92,21 @@ class TavilyService:
|
||||
Args:
|
||||
query: The search query to execute
|
||||
topic: Category of search (general, news, finance)
|
||||
search_depth: Depth of search (basic, advanced) - basic costs 1 credit, advanced costs 2
|
||||
max_results: Maximum number of results to return (0-20)
|
||||
include_domains: List of domains to specifically include
|
||||
exclude_domains: List of domains to specifically exclude
|
||||
search_depth: Depth of search (advanced=2 credits, basic/fast/ultra-fast=1 credit)
|
||||
max_results: Maximum number of results to return (0-20, default: 5)
|
||||
include_domains: List of domains to specifically include (max 300)
|
||||
exclude_domains: List of domains to specifically exclude (max 150)
|
||||
include_answer: Include LLM-generated answer (basic/advanced/true/false)
|
||||
include_raw_content: Include raw HTML content (markdown/text/true/false)
|
||||
include_images: Include image search results
|
||||
include_image_descriptions: Include image descriptions
|
||||
include_image_descriptions: Include image descriptions (requires include_images)
|
||||
include_favicon: Include favicon URLs
|
||||
time_range: Time range filter (day, week, month, year, d, w, m, y)
|
||||
start_date: Start date filter (YYYY-MM-DD)
|
||||
end_date: End date filter (YYYY-MM-DD)
|
||||
country: Country filter (boost results from specific country)
|
||||
chunks_per_source: Maximum chunks per source (1-3, only for advanced search)
|
||||
auto_parameters: Auto-configure parameters based on query
|
||||
country: Country filter (lowercase full country name, e.g., "united states" not "US")
|
||||
chunks_per_source: Maximum chunks per source (1-3, only for advanced/fast search, default: 3)
|
||||
auto_parameters: Auto-configure parameters based on query (costs 2 credits)
|
||||
|
||||
Returns:
|
||||
Dictionary containing search results
|
||||
@@ -159,7 +159,8 @@ class TavilyService:
|
||||
if country and topic == "general":
|
||||
payload["country"] = country
|
||||
|
||||
if search_depth == "advanced" and 1 <= chunks_per_source <= 3:
|
||||
# chunks_per_source only available for advanced and fast search_depth
|
||||
if search_depth in ["advanced", "fast"] and 1 <= chunks_per_source <= 3:
|
||||
payload["chunks_per_source"] = chunks_per_source
|
||||
|
||||
if auto_parameters:
|
||||
|
||||
113
backend/services/research_service.py
Normal file
113
backend/services/research_service.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Research Service
|
||||
|
||||
Service layer for managing research project persistence.
|
||||
Similar to PodcastService, but for research projects.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_, or_
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from models.research_models import ResearchProject
|
||||
|
||||
|
||||
class ResearchService:
|
||||
"""Service for managing research projects."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
keywords: List[str],
|
||||
industry: Optional[str] = None,
|
||||
target_audience: Optional[str] = None,
|
||||
research_mode: Optional[str] = "comprehensive",
|
||||
**kwargs
|
||||
) -> ResearchProject:
|
||||
"""Create a new research project."""
|
||||
# Extract current_step and status from kwargs to avoid conflicts
|
||||
current_step = kwargs.pop("current_step", 1)
|
||||
status = kwargs.pop("status", "draft")
|
||||
|
||||
project = ResearchProject(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
keywords=keywords,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
research_mode=research_mode,
|
||||
status=status,
|
||||
current_step=current_step,
|
||||
**kwargs
|
||||
)
|
||||
self.db.add(project)
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def get_project(self, user_id: str, project_id: str) -> Optional[ResearchProject]:
|
||||
"""Get a project by ID, ensuring user ownership."""
|
||||
return self.db.query(ResearchProject).filter(
|
||||
and_(
|
||||
ResearchProject.project_id == project_id,
|
||||
ResearchProject.user_id == user_id
|
||||
)
|
||||
).first()
|
||||
|
||||
def update_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
**updates
|
||||
) -> Optional[ResearchProject]:
|
||||
"""Update a project's state."""
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Update fields
|
||||
for key, value in updates.items():
|
||||
if hasattr(project, key):
|
||||
setattr(project, key, value)
|
||||
|
||||
project.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def list_projects(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[ResearchProject]:
|
||||
"""List projects for a user."""
|
||||
query = self.db.query(ResearchProject).filter(
|
||||
ResearchProject.user_id == user_id
|
||||
)
|
||||
|
||||
if status:
|
||||
query = query.filter(ResearchProject.status == status)
|
||||
|
||||
if is_favorite is not None:
|
||||
query = query.filter(ResearchProject.is_favorite == is_favorite)
|
||||
|
||||
return query.order_by(desc(ResearchProject.updated_at)).offset(offset).limit(limit).all()
|
||||
|
||||
def delete_project(self, user_id: str, project_id: str) -> bool:
|
||||
"""Delete a project."""
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return False
|
||||
|
||||
self.db.delete(project)
|
||||
self.db.commit()
|
||||
return True
|
||||
@@ -182,4 +182,4 @@ This package consolidates the following previously scattered files:
|
||||
|
||||
- `services.onboarding` - Onboarding and user setup
|
||||
- `models.subscription_models` - Database models
|
||||
- `api.subscription_api` - API endpoints
|
||||
- `api.subscription` - API endpoints (modular structure with routes in `api/subscription/routes/`)
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""
|
||||
Log Wrapping Service
|
||||
Intelligently wraps API usage logs when they exceed 5000 records.
|
||||
Intelligently wraps API usage logs when they exceed limits (count or time-based).
|
||||
Aggregates old logs into cumulative records while preserving historical data.
|
||||
|
||||
Features:
|
||||
- Count-based retention: Keeps 4,000 most recent detailed logs
|
||||
- Time-based retention: Aggregates logs older than 90 days
|
||||
- Automatic aggregation: Triggered on log queries
|
||||
- Context preservation: Maintains costs, tokens, counts, success rates
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
@@ -18,13 +24,18 @@ class LogWrappingService:
|
||||
|
||||
MAX_LOGS_PER_USER = 5000
|
||||
AGGREGATION_THRESHOLD_DAYS = 30 # Aggregate logs older than 30 days
|
||||
RETENTION_DAYS = 90 # Time-based retention: aggregate logs older than 90 days
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def check_and_wrap_logs(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if user has exceeded log limit and wrap if necessary.
|
||||
Check if user has exceeded log limit (count or time-based) and wrap if necessary.
|
||||
|
||||
Checks both:
|
||||
1. Count-based: If user has more than MAX_LOGS_PER_USER logs
|
||||
2. Time-based: If user has logs older than RETENTION_DAYS
|
||||
|
||||
Returns:
|
||||
Dict with wrapping status and statistics
|
||||
@@ -35,18 +46,42 @@ class LogWrappingService:
|
||||
APIUsageLog.user_id == user_id
|
||||
).scalar() or 0
|
||||
|
||||
if total_count <= self.MAX_LOGS_PER_USER:
|
||||
# Check for logs older than retention period
|
||||
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
|
||||
old_logs_count = self.db.query(func.count(APIUsageLog.id)).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate already aggregated logs
|
||||
).scalar() or 0
|
||||
|
||||
# Determine if wrapping is needed
|
||||
count_based_trigger = total_count > self.MAX_LOGS_PER_USER
|
||||
time_based_trigger = old_logs_count > 0
|
||||
|
||||
if not count_based_trigger and not time_based_trigger:
|
||||
return {
|
||||
'wrapped': False,
|
||||
'total_logs': total_count,
|
||||
'old_logs': old_logs_count,
|
||||
'max_logs': self.MAX_LOGS_PER_USER,
|
||||
'message': f'Log count ({total_count}) is within limit ({self.MAX_LOGS_PER_USER})'
|
||||
'retention_days': self.RETENTION_DAYS,
|
||||
'message': f'Log count ({total_count}) and age are within limits'
|
||||
}
|
||||
|
||||
# Need to wrap logs - aggregate old logs
|
||||
logger.info(f"[LogWrapping] User {user_id} has {total_count} logs, exceeding limit of {self.MAX_LOGS_PER_USER}. Starting wrap...")
|
||||
# Determine trigger reason
|
||||
trigger_reasons = []
|
||||
if count_based_trigger:
|
||||
trigger_reasons.append(f'count limit ({total_count} > {self.MAX_LOGS_PER_USER})')
|
||||
if time_based_trigger:
|
||||
trigger_reasons.append(f'time-based retention ({old_logs_count} logs older than {self.RETENTION_DAYS} days)')
|
||||
|
||||
wrap_result = self._wrap_old_logs(user_id, total_count)
|
||||
logger.info(
|
||||
f"[LogWrapping] User {user_id} needs log wrapping. "
|
||||
f"Total: {total_count}, Old logs: {old_logs_count}. "
|
||||
f"Triggers: {', '.join(trigger_reasons)}"
|
||||
)
|
||||
|
||||
wrap_result = self._wrap_old_logs(user_id, total_count, time_based=time_based_trigger)
|
||||
|
||||
return {
|
||||
'wrapped': True,
|
||||
@@ -54,6 +89,8 @@ class LogWrappingService:
|
||||
'total_logs_after': wrap_result['logs_remaining'],
|
||||
'aggregated_logs': wrap_result['aggregated_count'],
|
||||
'aggregated_periods': wrap_result['periods'],
|
||||
'trigger_reasons': trigger_reasons,
|
||||
'old_logs_aggregated': wrap_result.get('old_logs_aggregated', 0),
|
||||
'message': f'Wrapped {wrap_result["aggregated_count"]} logs into {len(wrap_result["periods"])} aggregated records'
|
||||
}
|
||||
|
||||
@@ -65,30 +102,76 @@ class LogWrappingService:
|
||||
'message': f'Error wrapping logs: {str(e)}'
|
||||
}
|
||||
|
||||
def _wrap_old_logs(self, user_id: str, total_count: int) -> Dict[str, Any]:
|
||||
def _wrap_old_logs(self, user_id: str, total_count: int, time_based: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Aggregate old logs into cumulative records.
|
||||
|
||||
Strategy:
|
||||
1. Keep most recent 4000 logs (detailed)
|
||||
2. Aggregate logs older than 30 days or oldest logs beyond 4000
|
||||
3. Create aggregated records grouped by provider and billing period
|
||||
4. Delete individual logs that were aggregated
|
||||
1. Keep most recent 4000 logs (detailed) - count-based
|
||||
2. Aggregate logs older than RETENTION_DAYS - time-based
|
||||
3. Aggregate oldest logs beyond 4000 limit - count-based
|
||||
4. Create aggregated records grouped by provider and billing period
|
||||
5. Delete individual logs that were aggregated
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
total_count: Total number of logs for user
|
||||
time_based: If True, prioritize time-based retention over count-based
|
||||
"""
|
||||
try:
|
||||
# Calculate how many logs to keep (4000 detailed, rest aggregated)
|
||||
# Calculate retention cutoff date
|
||||
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
|
||||
aggregation_cutoff = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
|
||||
|
||||
# Determine which logs to aggregate
|
||||
logs_to_keep = 4000
|
||||
logs_to_aggregate = total_count - logs_to_keep
|
||||
logs_to_aggregate_count = max(0, total_count - logs_to_keep)
|
||||
|
||||
# Get cutoff date (30 days ago)
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
|
||||
if time_based:
|
||||
# Time-based: Aggregate all logs older than retention period
|
||||
# (excluding already aggregated logs)
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
|
||||
).order_by(APIUsageLog.timestamp.asc()).all()
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Time-based aggregation: Found {len(logs_to_process)} logs "
|
||||
f"older than {self.RETENTION_DAYS} days"
|
||||
)
|
||||
else:
|
||||
# Count-based: Aggregate oldest logs beyond the keep limit
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
|
||||
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate_count).all()
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Count-based aggregation: Processing {len(logs_to_process)} "
|
||||
f"oldest logs beyond {logs_to_keep} limit"
|
||||
)
|
||||
|
||||
# Get logs to aggregate: oldest logs beyond the keep limit
|
||||
# Order by timestamp ascending to get oldest first
|
||||
# We'll keep the most recent logs_to_keep logs, aggregate the rest
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate).all()
|
||||
# Also check for time-based logs even if count-based is primary
|
||||
# This ensures we don't keep very old logs just because they're within the count limit
|
||||
if not time_based and logs_to_aggregate_count > 0:
|
||||
# Get logs that are both old AND beyond count limit
|
||||
old_logs_beyond_limit = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]'
|
||||
).order_by(APIUsageLog.timestamp.asc()).all()
|
||||
|
||||
# Merge with count-based logs, prioritizing old logs
|
||||
existing_ids = {log.id for log in logs_to_process}
|
||||
for old_log in old_logs_beyond_limit:
|
||||
if old_log.id not in existing_ids:
|
||||
logs_to_process.append(old_log)
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Combined aggregation: {len(logs_to_process)} logs to process "
|
||||
f"({logs_to_aggregate_count} count-based + {len(old_logs_beyond_limit)} time-based)"
|
||||
)
|
||||
|
||||
if not logs_to_process:
|
||||
return {
|
||||
@@ -218,10 +301,18 @@ class LogWrappingService:
|
||||
f"Remaining logs: {remaining_count}"
|
||||
)
|
||||
|
||||
# Count how many old logs were aggregated (for reporting)
|
||||
# Count logs that were aggregated based on time (not just count)
|
||||
old_logs_aggregated = 0
|
||||
for log in logs_to_process:
|
||||
if log.timestamp and log.timestamp < retention_cutoff:
|
||||
old_logs_aggregated += 1
|
||||
|
||||
return {
|
||||
'aggregated_count': aggregated_count,
|
||||
'logs_remaining': remaining_count,
|
||||
'periods': periods_created
|
||||
'periods': periods_created,
|
||||
'old_logs_aggregated': old_logs_aggregated
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,7 +14,7 @@ from collections import defaultdict, deque
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy import and_, func, case
|
||||
import re
|
||||
|
||||
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
|
||||
@@ -369,19 +369,64 @@ async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
|
||||
db.close()
|
||||
|
||||
async def get_lightweight_stats() -> Dict[str, Any]:
|
||||
"""Get lightweight stats for dashboard header."""
|
||||
"""Get lightweight stats for dashboard header.
|
||||
|
||||
Optimized single-query approach using conditional aggregation for better performance.
|
||||
"""
|
||||
db = None
|
||||
try:
|
||||
db = _get_db_session()
|
||||
# Minimal viable placeholder values
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get stats from last 5 minutes
|
||||
five_minutes_ago = now - timedelta(minutes=5)
|
||||
|
||||
# Optimized: Single query with conditional aggregation instead of two separate queries
|
||||
# This is much faster as it only scans the table once
|
||||
stats = db.query(
|
||||
func.count(APIRequest.id).label('total_requests'),
|
||||
func.sum(
|
||||
case((APIRequest.status_code >= 400, 1), else_=0)
|
||||
).label('total_errors')
|
||||
).filter(
|
||||
APIRequest.timestamp >= five_minutes_ago
|
||||
).first()
|
||||
|
||||
recent_requests = stats.total_requests or 0 if stats else 0
|
||||
recent_errors = int(stats.total_errors or 0) if stats else 0
|
||||
|
||||
# Calculate error rate
|
||||
error_rate = (recent_errors / recent_requests * 100) if recent_requests > 0 else 0.0
|
||||
|
||||
# Determine status based on error rate
|
||||
if error_rate > 10:
|
||||
status = 'critical'
|
||||
icon = '🔴'
|
||||
elif error_rate > 5:
|
||||
status = 'warning'
|
||||
icon = '🟡'
|
||||
else:
|
||||
status = 'healthy'
|
||||
icon = '🟢'
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'icon': icon,
|
||||
'recent_requests': recent_requests,
|
||||
'recent_errors': recent_errors,
|
||||
'error_rate': round(error_rate, 2),
|
||||
'timestamp': now.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lightweight stats: {e}", exc_info=True)
|
||||
# Return default healthy state on error
|
||||
return {
|
||||
'status': 'healthy',
|
||||
'icon': '🟢',
|
||||
'recent_requests': 0,
|
||||
'recent_errors': 0,
|
||||
'error_rate': 0.0,
|
||||
'timestamp': now.isoformat()
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
finally:
|
||||
if db is not None:
|
||||
|
||||
@@ -290,6 +290,40 @@ class PricingService:
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"description": "Stability AI Image Generation"
|
||||
},
|
||||
# WaveSpeed OSS Image Generation Models
|
||||
{
|
||||
"provider": APIProvider.STABILITY, # Using STABILITY provider for image generation
|
||||
"model_name": "qwen-image",
|
||||
"cost_per_image": 0.03, # $0.03 per image (OSS model via WaveSpeed)
|
||||
"cost_per_request": 0.03, # Also support cost_per_request
|
||||
"description": "WaveSpeed Qwen Image (OSS) - Fast generation, cost-effective"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.STABILITY,
|
||||
"model_name": "ideogram-v3-turbo",
|
||||
"cost_per_image": 0.05, # $0.05 per image (OSS model via WaveSpeed)
|
||||
"cost_per_request": 0.05, # Also support cost_per_request
|
||||
"description": "WaveSpeed Ideogram V3 Turbo (OSS) - Photorealistic, text rendering"
|
||||
},
|
||||
# WaveSpeed OSS Image Editing Models
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "qwen-edit",
|
||||
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Qwen Image Edit (OSS) - Budget editing, bilingual"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "qwen-edit-plus",
|
||||
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Qwen Image Edit Plus (OSS) - Multi-image editing"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "flux-kontext-pro",
|
||||
"cost_per_request": 0.04, # $0.04 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed FLUX Kontext Pro (OSS) - Professional editing, typography"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.EXA,
|
||||
"model_name": "exa-search",
|
||||
@@ -305,8 +339,8 @@ class PricingService:
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "default",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "AI Video Generation default pricing"
|
||||
"cost_per_request": 0.25, # UPDATED: Default to WAN 2.5 OSS model ($0.25)
|
||||
"description": "AI Video Generation default pricing (OSS: WAN 2.5)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
@@ -326,6 +360,25 @@ class PricingService:
|
||||
"cost_per_request": 0.30,
|
||||
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
|
||||
},
|
||||
# WaveSpeed OSS Video Generation Models
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "wan-2.5",
|
||||
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed WAN 2.5 (OSS) - Text-to-Video, Image-to-Video, cost-effective"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "alibaba/wan-2.5",
|
||||
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed WAN 2.5 (OSS) - Alternative path, same model"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "seedance-1.5-pro",
|
||||
"cost_per_request": 0.40, # $0.40 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Seedance 1.5 Pro (OSS) - Longer duration videos (10-30 sec)"
|
||||
},
|
||||
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
@@ -404,7 +457,7 @@ class PricingService:
|
||||
"tier": SubscriptionTier.BASIC,
|
||||
"price_monthly": 29.0,
|
||||
"price_yearly": 290.0,
|
||||
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
|
||||
"ai_text_generation_calls_limit": 50, # INCREASED: Unified limit for all LLM providers (OSS-focused strategy)
|
||||
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
|
||||
"openai_calls_limit": 500,
|
||||
"anthropic_calls_limit": 200,
|
||||
@@ -413,18 +466,18 @@ class PricingService:
|
||||
"serper_calls_limit": 200,
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"stability_calls_limit": 50, # INCREASED: Now includes WaveSpeed OSS models (Qwen Image $0.03)
|
||||
"exa_calls_limit": 500,
|
||||
"video_calls_limit": 20, # 20 videos/month for basic plan
|
||||
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
|
||||
"audio_calls_limit": 50, # 50 AI audio generation calls/month
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
"video_calls_limit": 30, # INCREASED: 30 videos/month (WAN 2.5 OSS $0.25)
|
||||
"image_edit_calls_limit": 50, # INCREASED: 50 AI image editing calls/month (Qwen Edit OSS $0.02)
|
||||
"audio_calls_limit": 100, # INCREASED: 100 AI audio generation calls/month (Minimax Speech OSS)
|
||||
"gemini_tokens_limit": 100000, # INCREASED: 100K tokens per provider (OSS-focused strategy)
|
||||
"openai_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"anthropic_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"mistral_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"monthly_cost_limit": 45.0, # ADJUSTED: $45 cap (aligns with $40-50 hard limit target)
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics", "all_tools_access", "oss_models_priority"],
|
||||
"description": "Perfect for individuals and small teams. Access all ALwrity features with generous limits powered by OSS AI models."
|
||||
},
|
||||
{
|
||||
"name": "Pro",
|
||||
|
||||
156
backend/services/subscription/provider_detection.py
Normal file
156
backend/services/subscription/provider_detection.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Provider Detection Utility
|
||||
Detects the actual provider (WaveSpeed, Google, HuggingFace, etc.) from model names and endpoints.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from models.subscription_models import APIProvider
|
||||
from loguru import logger
|
||||
|
||||
def detect_actual_provider(provider_enum: APIProvider, model_name: Optional[str] = None, endpoint: Optional[str] = None) -> str:
|
||||
"""
|
||||
Detect the actual provider name from provider enum, model name, and endpoint.
|
||||
|
||||
Args:
|
||||
provider_enum: The APIProvider enum value (e.g., APIProvider.VIDEO, APIProvider.GEMINI)
|
||||
model_name: The model name (e.g., "alibaba/wan-2.5/text-to-video", "gemini-2.5-flash")
|
||||
endpoint: The API endpoint (e.g., "/video-generation/wavespeed", "/image-generation/stability")
|
||||
|
||||
Returns:
|
||||
Actual provider name: "wavespeed", "google", "huggingface", "stability", "openai", "anthropic", etc.
|
||||
"""
|
||||
|
||||
# For LLM providers, use the enum value directly
|
||||
if provider_enum in [APIProvider.GEMINI]:
|
||||
return "google"
|
||||
elif provider_enum == APIProvider.OPENAI:
|
||||
return "openai"
|
||||
elif provider_enum == APIProvider.ANTHROPIC:
|
||||
return "anthropic"
|
||||
elif provider_enum == APIProvider.MISTRAL:
|
||||
# MISTRAL enum is used for HuggingFace models
|
||||
return "huggingface"
|
||||
|
||||
# For search APIs, use the enum value
|
||||
elif provider_enum in [APIProvider.TAVILY, APIProvider.SERPER, APIProvider.METAPHOR, APIProvider.FIRECRAWL, APIProvider.EXA]:
|
||||
return provider_enum.value
|
||||
|
||||
# For media generation, detect from model name or endpoint
|
||||
elif provider_enum == APIProvider.VIDEO:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed models
|
||||
if any(x in model_lower for x in ["wan-2.5", "seedance", "infinitetalk", "wavespeed", "alibaba"]):
|
||||
return "wavespeed"
|
||||
# HuggingFace models
|
||||
elif any(x in model_lower for x in ["huggingface", "hf", "tencent", "hunyuan"]):
|
||||
return "huggingface"
|
||||
# Google models (future)
|
||||
elif any(x in model_lower for x in ["veo", "gemini"]):
|
||||
return "google"
|
||||
# OpenAI models (future)
|
||||
elif any(x in model_lower for x in ["sora", "openai"]):
|
||||
return "openai"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
|
||||
return "huggingface"
|
||||
elif "google" in endpoint_lower or "gemini" in endpoint_lower:
|
||||
return "google"
|
||||
elif "openai" in endpoint_lower:
|
||||
return "openai"
|
||||
|
||||
# Default for video: WaveSpeed (most common)
|
||||
return "wavespeed"
|
||||
|
||||
elif provider_enum == APIProvider.AUDIO:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed models
|
||||
if any(x in model_lower for x in ["minimax", "speech-02", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Google models
|
||||
elif any(x in model_lower for x in ["google", "gemini", "tts"]):
|
||||
return "google"
|
||||
# OpenAI models
|
||||
elif any(x in model_lower for x in ["openai", "tts-1"]):
|
||||
return "openai"
|
||||
# ElevenLabs (future)
|
||||
elif "elevenlabs" in model_lower:
|
||||
return "elevenlabs"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "google" in endpoint_lower:
|
||||
return "google"
|
||||
elif "openai" in endpoint_lower:
|
||||
return "openai"
|
||||
|
||||
# Default for audio: WaveSpeed (most common)
|
||||
return "wavespeed"
|
||||
|
||||
elif provider_enum == APIProvider.STABILITY:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed OSS models
|
||||
if any(x in model_lower for x in ["qwen", "ideogram", "flux", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Stability AI models
|
||||
elif any(x in model_lower for x in ["stability", "stable-diffusion", "sd-"]):
|
||||
return "stability"
|
||||
# HuggingFace models
|
||||
elif any(x in model_lower for x in ["huggingface", "hf", "runway"]):
|
||||
return "huggingface"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "stability" in endpoint_lower:
|
||||
return "stability"
|
||||
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
|
||||
return "huggingface"
|
||||
|
||||
# Default: check if it's actually WaveSpeed based on common OSS models
|
||||
if model_name and any(x in model_name.lower() for x in ["qwen", "ideogram", "flux"]):
|
||||
return "wavespeed"
|
||||
|
||||
# Default for image generation: Stability (legacy)
|
||||
return "stability"
|
||||
|
||||
elif provider_enum == APIProvider.IMAGE_EDIT:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed OSS models
|
||||
if any(x in model_lower for x in ["qwen", "flux", "kontext", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Stability AI models
|
||||
elif any(x in model_lower for x in ["stability", "stable-diffusion"]):
|
||||
return "stability"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "stability" in endpoint_lower:
|
||||
return "stability"
|
||||
|
||||
# Default for image editing: WaveSpeed (OSS-first strategy)
|
||||
return "wavespeed"
|
||||
|
||||
# Fallback: use enum value
|
||||
logger.warning(f"Could not detect actual provider for {provider_enum.value}, using enum value")
|
||||
return provider_enum.value
|
||||
264
backend/services/subscription/renewal_history_retention.py
Normal file
264
backend/services/subscription/renewal_history_retention.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Renewal History Retention Service
|
||||
Manages retention policies for subscription renewal history records.
|
||||
|
||||
Retention Policy:
|
||||
- 0-12 months: Full records with usage snapshots
|
||||
- 12-24 months: Full records (compressed/removed usage snapshots)
|
||||
- 24-84 months: Summary records (no usage snapshots, payment data only)
|
||||
- 84+ months: Mark for archive (payment data preserved indefinitely)
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
from models.subscription_models import SubscriptionRenewalHistory
|
||||
|
||||
|
||||
class RenewalHistoryRetentionService:
|
||||
"""Service for managing renewal history retention policies."""
|
||||
|
||||
# Retention periods (in days)
|
||||
COMPRESS_SNAPSHOT_DAYS = 365 # 12 months - compress/remove usage snapshots
|
||||
SUMMARY_RECORDS_DAYS = 730 # 24 months - create summary records
|
||||
ARCHIVE_DAYS = 2555 # 84 months (7 years) - mark for archive
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def check_and_apply_retention(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check and apply retention policies for renewal history.
|
||||
|
||||
Applies retention in stages:
|
||||
1. Compress usage snapshots for records 12-24 months old
|
||||
2. Create summary records for records 24-84 months old
|
||||
3. Mark records older than 84 months for archive
|
||||
|
||||
Returns:
|
||||
Dict with retention status and statistics
|
||||
"""
|
||||
try:
|
||||
now = datetime.utcnow()
|
||||
compress_cutoff = now - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
|
||||
summary_cutoff = now - timedelta(days=self.SUMMARY_RECORDS_DAYS)
|
||||
archive_cutoff = now - timedelta(days=self.ARCHIVE_DAYS)
|
||||
|
||||
# Count records in each retention tier
|
||||
total_count = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).scalar() or 0
|
||||
|
||||
records_to_compress = self.db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < compress_cutoff,
|
||||
SubscriptionRenewalHistory.created_at >= summary_cutoff,
|
||||
SubscriptionRenewalHistory.usage_before_renewal.isnot(None) # Has snapshot to compress
|
||||
).all()
|
||||
|
||||
records_to_summarize = self.db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < summary_cutoff,
|
||||
SubscriptionRenewalHistory.created_at >= archive_cutoff,
|
||||
SubscriptionRenewalHistory.usage_before_renewal.isnot(None) # Has snapshot to remove
|
||||
).all()
|
||||
|
||||
records_to_archive = self.db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < archive_cutoff
|
||||
).all()
|
||||
|
||||
# Apply retention policies
|
||||
compressed_count = self._compress_usage_snapshots(records_to_compress)
|
||||
summarized_count = self._create_summary_records(records_to_summarize)
|
||||
archived_count = self._mark_for_archive(records_to_archive)
|
||||
|
||||
total_processed = compressed_count + summarized_count + archived_count
|
||||
|
||||
if total_processed == 0:
|
||||
return {
|
||||
'retention_applied': False,
|
||||
'total_records': total_count,
|
||||
'records_to_compress': len(records_to_compress),
|
||||
'records_to_summarize': len(records_to_summarize),
|
||||
'records_to_archive': len(records_to_archive),
|
||||
'message': 'No records require retention processing'
|
||||
}
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"[RenewalRetention] Applied retention for user {user_id}: "
|
||||
f"{compressed_count} compressed, {summarized_count} summarized, "
|
||||
f"{archived_count} archived"
|
||||
)
|
||||
|
||||
return {
|
||||
'retention_applied': True,
|
||||
'total_records': total_count,
|
||||
'compressed_count': compressed_count,
|
||||
'summarized_count': summarized_count,
|
||||
'archived_count': archived_count,
|
||||
'total_processed': total_processed,
|
||||
'message': f'Processed {total_processed} records: {compressed_count} compressed, {summarized_count} summarized, {archived_count} archived'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"[RenewalRetention] Error applying retention for user {user_id}: {e}", exc_info=True)
|
||||
return {
|
||||
'retention_applied': False,
|
||||
'error': str(e),
|
||||
'message': f'Error applying retention: {str(e)}'
|
||||
}
|
||||
|
||||
def _compress_usage_snapshots(self, records: List[SubscriptionRenewalHistory]) -> int:
|
||||
"""
|
||||
Compress usage snapshots for records 12-24 months old.
|
||||
|
||||
Strategy: Replace detailed JSON snapshot with summary statistics only.
|
||||
Keeps only essential metrics: total_calls, total_tokens, total_cost.
|
||||
"""
|
||||
compressed = 0
|
||||
for record in records:
|
||||
if record.usage_before_renewal:
|
||||
try:
|
||||
usage_data = record.usage_before_renewal
|
||||
|
||||
# Handle both dict (SQLAlchemy JSON) and string formats
|
||||
if isinstance(usage_data, str):
|
||||
try:
|
||||
usage_data = json.loads(usage_data)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON, remove it
|
||||
record.usage_before_renewal = None
|
||||
compressed += 1
|
||||
continue
|
||||
elif not isinstance(usage_data, dict):
|
||||
# If it's not a dict or string, remove it
|
||||
record.usage_before_renewal = None
|
||||
compressed += 1
|
||||
continue
|
||||
|
||||
# Check if already compressed (has 'compressed_at' key)
|
||||
if isinstance(usage_data, dict) and 'compressed_at' in usage_data:
|
||||
# Already compressed, skip
|
||||
continue
|
||||
|
||||
# Create compressed summary (keep only key metrics)
|
||||
compressed_summary = {
|
||||
'total_calls': usage_data.get('total_calls', 0),
|
||||
'total_tokens': usage_data.get('total_tokens', 0),
|
||||
'total_cost': usage_data.get('total_cost', 0.0),
|
||||
'compressed_at': datetime.utcnow().isoformat(),
|
||||
'note': 'Usage snapshot compressed after 12 months'
|
||||
}
|
||||
|
||||
record.usage_before_renewal = compressed_summary
|
||||
compressed += 1
|
||||
|
||||
except (TypeError, AttributeError, KeyError) as e:
|
||||
logger.warning(f"[RenewalRetention] Failed to compress snapshot for record {record.id}: {e}")
|
||||
# If compression fails, remove snapshot entirely
|
||||
record.usage_before_renewal = None
|
||||
compressed += 1
|
||||
|
||||
return compressed
|
||||
|
||||
def _create_summary_records(self, records: List[SubscriptionRenewalHistory]) -> int:
|
||||
"""
|
||||
Create summary records for records 24-84 months old.
|
||||
|
||||
Strategy: Remove usage snapshots, keep only payment and subscription data.
|
||||
"""
|
||||
summarized = 0
|
||||
for record in records:
|
||||
if record.usage_before_renewal is not None:
|
||||
# Remove usage snapshot, keep payment and subscription data
|
||||
record.usage_before_renewal = None
|
||||
summarized += 1
|
||||
|
||||
return summarized
|
||||
|
||||
def _mark_for_archive(self, records: List[SubscriptionRenewalHistory]) -> int:
|
||||
"""
|
||||
Mark records older than 84 months for archive.
|
||||
|
||||
Strategy: Ensure usage snapshots are removed, payment data is preserved.
|
||||
Note: In future, these could be moved to an archive table.
|
||||
"""
|
||||
archived = 0
|
||||
for record in records:
|
||||
# Ensure usage snapshot is removed (should already be done)
|
||||
if record.usage_before_renewal is not None:
|
||||
record.usage_before_renewal = None
|
||||
archived += 1
|
||||
else:
|
||||
# Already processed, just count
|
||||
archived += 1
|
||||
|
||||
return archived
|
||||
|
||||
def get_retention_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get retention statistics for a user's renewal history.
|
||||
|
||||
Returns breakdown by retention tier.
|
||||
"""
|
||||
try:
|
||||
now = datetime.utcnow()
|
||||
compress_cutoff = now - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
|
||||
summary_cutoff = now - timedelta(days=self.SUMMARY_RECORDS_DAYS)
|
||||
archive_cutoff = now - timedelta(days=self.ARCHIVE_DAYS)
|
||||
|
||||
total = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).scalar() or 0
|
||||
|
||||
recent = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at >= compress_cutoff
|
||||
).scalar() or 0
|
||||
|
||||
to_compress = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < compress_cutoff,
|
||||
SubscriptionRenewalHistory.created_at >= summary_cutoff,
|
||||
SubscriptionRenewalHistory.usage_before_renewal.isnot(None)
|
||||
).scalar() or 0
|
||||
|
||||
to_summarize = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < summary_cutoff,
|
||||
SubscriptionRenewalHistory.created_at >= archive_cutoff,
|
||||
SubscriptionRenewalHistory.usage_before_renewal.isnot(None)
|
||||
).scalar() or 0
|
||||
|
||||
to_archive = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id,
|
||||
SubscriptionRenewalHistory.created_at < archive_cutoff
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
'total_records': total,
|
||||
'recent_records': recent, # 0-12 months
|
||||
'records_to_compress': to_compress, # 12-24 months
|
||||
'records_to_summarize': to_summarize, # 24-84 months
|
||||
'records_to_archive': to_archive, # 84+ months
|
||||
'retention_policy': {
|
||||
'compress_after_days': self.COMPRESS_SNAPSHOT_DAYS,
|
||||
'summarize_after_days': self.SUMMARY_RECORDS_DAYS,
|
||||
'archive_after_days': self.ARCHIVE_DAYS
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[RenewalRetention] Error getting retention stats for user {user_id}: {e}", exc_info=True)
|
||||
return {
|
||||
'error': str(e),
|
||||
'total_records': 0
|
||||
}
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
_checked_subscription_plan_columns: bool = False
|
||||
_checked_usage_summaries_columns: bool = False
|
||||
_checked_api_usage_logs_columns: bool = False
|
||||
|
||||
|
||||
def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
@@ -114,9 +115,58 @@ def ensure_usage_summaries_columns(db: Session) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def ensure_api_usage_logs_columns(db: Session) -> None:
|
||||
"""Ensure required columns exist on api_usage_logs for runtime safety.
|
||||
|
||||
This is a defensive guard for environments where migrations have not yet
|
||||
been applied. If columns are missing (e.g., actual_provider_name), we add them
|
||||
with a safe default so ORM queries do not fail.
|
||||
"""
|
||||
global _checked_api_usage_logs_columns
|
||||
if _checked_api_usage_logs_columns:
|
||||
return
|
||||
|
||||
try:
|
||||
# Discover existing columns using PRAGMA
|
||||
result = db.execute(text("PRAGMA table_info(api_usage_logs)"))
|
||||
cols: Set[str] = {row[1] for row in result}
|
||||
|
||||
logger.debug(f"Schema check: Found {len(cols)} columns in api_usage_logs table")
|
||||
|
||||
# Columns we may reference in models but might be missing in older DBs
|
||||
required_columns = {
|
||||
"actual_provider_name": "VARCHAR(50) NULL",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
if col_name not in cols:
|
||||
logger.info(f"Adding missing column {col_name} to api_usage_logs table")
|
||||
try:
|
||||
db.execute(text(f"ALTER TABLE api_usage_logs ADD COLUMN {col_name} {ddl}"))
|
||||
db.commit()
|
||||
logger.info(f"Successfully added column {col_name}")
|
||||
except Exception as alter_err:
|
||||
logger.error(f"Failed to add column {col_name}: {alter_err}")
|
||||
db.rollback()
|
||||
# Don't set flag on error - allow retry
|
||||
raise
|
||||
else:
|
||||
logger.debug(f"Column {col_name} already exists")
|
||||
|
||||
# Only set flag if we successfully completed the check
|
||||
_checked_api_usage_logs_columns = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring api_usage_logs columns: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
# Don't set the flag if there was an error, so we retry next time
|
||||
_checked_api_usage_logs_columns = False
|
||||
raise
|
||||
|
||||
|
||||
def ensure_all_schema_columns(db: Session) -> None:
|
||||
"""Ensure all required columns exist in subscription-related tables."""
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
ensure_api_usage_logs_columns(db)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from models.subscription_models import (
|
||||
UserSubscription, UsageStatus
|
||||
)
|
||||
from .pricing_service import PricingService
|
||||
from .provider_detection import detect_actual_provider
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
@@ -67,12 +68,21 @@ class UsageTrackingService:
|
||||
|
||||
# Create usage log entry
|
||||
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
|
||||
actual_provider_name = detect_actual_provider(
|
||||
provider_enum=provider,
|
||||
model_name=model_used,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
endpoint=endpoint,
|
||||
method=method,
|
||||
model_used=model_used,
|
||||
actual_provider_name=actual_provider_name, # Track actual provider
|
||||
tokens_input=tokens_input,
|
||||
tokens_output=tokens_output,
|
||||
tokens_total=(tokens_input or 0) + (tokens_output or 0),
|
||||
@@ -404,18 +414,128 @@ class UsageTrackingService:
|
||||
'cost': mistral_cost
|
||||
}
|
||||
|
||||
# Add other providers (Video, Audio, Image, Image Edit) for comprehensive breakdown
|
||||
# Video (WaveSpeed, HuggingFace, etc.)
|
||||
video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_cost = getattr(summary, "video_cost", 0.0) or 0.0
|
||||
if video_calls > 0 and video_cost == 0.0:
|
||||
video_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.VIDEO,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if video_logs:
|
||||
video_cost = sum(float(log.cost_total or 0.0) for log in video_logs)
|
||||
|
||||
provider_breakdown['video'] = {
|
||||
'calls': video_calls,
|
||||
'tokens': 0,
|
||||
'cost': video_cost
|
||||
}
|
||||
|
||||
# Audio (WaveSpeed, etc.)
|
||||
audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_cost = getattr(summary, "audio_cost", 0.0) or 0.0
|
||||
if audio_calls > 0 and audio_cost == 0.0:
|
||||
audio_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.AUDIO,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if audio_logs:
|
||||
audio_cost = sum(float(log.cost_total or 0.0) for log in audio_logs)
|
||||
|
||||
provider_breakdown['audio'] = {
|
||||
'calls': audio_calls,
|
||||
'tokens': 0,
|
||||
'cost': audio_cost
|
||||
}
|
||||
|
||||
# Image Generation (Stability/WaveSpeed)
|
||||
stability_calls = getattr(summary, "stability_calls", 0) or 0
|
||||
stability_cost = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
if stability_calls > 0 and stability_cost == 0.0:
|
||||
stability_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.STABILITY,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if stability_logs:
|
||||
stability_cost = sum(float(log.cost_total or 0.0) for log in stability_logs)
|
||||
|
||||
provider_breakdown['image'] = {
|
||||
'calls': stability_calls,
|
||||
'tokens': 0,
|
||||
'cost': stability_cost
|
||||
}
|
||||
|
||||
# Image Editing (WaveSpeed)
|
||||
image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_cost = getattr(summary, "image_edit_cost", 0.0) or 0.0
|
||||
if image_edit_calls > 0 and image_edit_cost == 0.0:
|
||||
image_edit_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.IMAGE_EDIT,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if image_edit_logs:
|
||||
image_edit_cost = sum(float(log.cost_total or 0.0) for log in image_edit_logs)
|
||||
|
||||
provider_breakdown['image_edit'] = {
|
||||
'calls': image_edit_calls,
|
||||
'tokens': 0,
|
||||
'cost': image_edit_cost
|
||||
}
|
||||
|
||||
# Search APIs
|
||||
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
|
||||
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0
|
||||
provider_breakdown['tavily'] = {
|
||||
'calls': tavily_calls,
|
||||
'tokens': 0,
|
||||
'cost': tavily_cost
|
||||
}
|
||||
|
||||
serper_calls = getattr(summary, "serper_calls", 0) or 0
|
||||
serper_cost = getattr(summary, "serper_cost", 0.0) or 0.0
|
||||
provider_breakdown['serper'] = {
|
||||
'calls': serper_calls,
|
||||
'tokens': 0,
|
||||
'cost': serper_cost
|
||||
}
|
||||
|
||||
exa_calls = getattr(summary, "exa_calls", 0) or 0
|
||||
exa_cost = getattr(summary, "exa_cost", 0.0) or 0.0
|
||||
provider_breakdown['exa'] = {
|
||||
'calls': exa_calls,
|
||||
'tokens': 0,
|
||||
'cost': exa_cost
|
||||
}
|
||||
|
||||
# Calculate total cost from provider breakdown if summary total_cost is 0
|
||||
calculated_total_cost = gemini_cost + mistral_cost
|
||||
calculated_total_cost = (
|
||||
gemini_cost + mistral_cost + video_cost + audio_cost +
|
||||
stability_cost + image_edit_cost + tavily_cost + serper_cost + exa_cost
|
||||
)
|
||||
summary_total_cost = summary.total_cost or 0.0
|
||||
# Use calculated cost if summary cost is 0, otherwise use summary cost (it's more accurate)
|
||||
final_total_cost = summary_total_cost if summary_total_cost > 0 else calculated_total_cost
|
||||
|
||||
# If we calculated costs from logs, update the summary for future requests
|
||||
if calculated_total_cost > 0 and summary_total_cost == 0.0:
|
||||
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}")
|
||||
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}, video_cost={video_cost:.6f}, audio_cost={audio_cost:.6f}, image_cost={stability_cost:.6f}")
|
||||
summary.total_cost = final_total_cost
|
||||
summary.gemini_cost = gemini_cost
|
||||
summary.mistral_cost = mistral_cost
|
||||
# Update other provider costs if they exist
|
||||
if hasattr(summary, 'video_cost'):
|
||||
summary.video_cost = video_cost
|
||||
if hasattr(summary, 'audio_cost'):
|
||||
summary.audio_cost = audio_cost
|
||||
if hasattr(summary, 'stability_cost'):
|
||||
summary.stability_cost = stability_cost
|
||||
if hasattr(summary, 'image_edit_cost'):
|
||||
summary.image_edit_cost = image_edit_cost
|
||||
try:
|
||||
self.db.commit()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1053,11 +1053,11 @@ class VideoStudioService:
|
||||
return base_cost * duration * model_multiplier * resolution_multiplier
|
||||
|
||||
def _get_default_model(self, operation_type: str) -> str:
|
||||
"""Get default model for operation type."""
|
||||
"""Get default model for operation type (OSS-focused defaults)."""
|
||||
defaults = {
|
||||
"text-to-video": "hunyuan-video-1.5",
|
||||
"image-to-video": "alibaba/wan-2.5",
|
||||
"text-to-video": "wan-2.5", # OSS: WAN 2.5 ($0.25) vs HunyuanVideo ($0.10) - better quality/value
|
||||
"image-to-video": "wan-2.5", # OSS: WAN 2.5 (same as text-to-video)
|
||||
"avatar": "wavespeed/mocha",
|
||||
"enhancement": "wavespeed/flashvsr",
|
||||
}
|
||||
return defaults.get(operation_type, "hunyuan-video-1.5")
|
||||
return defaults.get(operation_type, "wan-2.5") # Default to OSS model
|
||||
@@ -72,6 +72,7 @@ class ImageGenerator:
|
||||
model_paths = {
|
||||
"ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
|
||||
"qwen-image": "wavespeed-ai/qwen-image/text-to-image",
|
||||
"flux-kontext-pro": "wavespeed-ai/flux-kontext-pro/text-to-image",
|
||||
}
|
||||
|
||||
model_path = model_paths.get(model)
|
||||
|
||||
Reference in New Issue
Block a user