AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -12,7 +12,7 @@ from datetime import datetime
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.content_asset_service import ContentAssetService
|
||||
from models.content_asset_models import AssetType, AssetSource
|
||||
from models.content_asset_models import AssetType, AssetSource, AssetCollection
|
||||
|
||||
router = APIRouter(prefix="/api/content-assets", tags=["Content Assets"])
|
||||
|
||||
@@ -62,6 +62,11 @@ async def get_assets(
|
||||
search: Optional[str] = Query(None, description="Search query"),
|
||||
tags: Optional[str] = Query(None, description="Comma-separated tags"),
|
||||
favorites_only: bool = Query(False, description="Only favorites"),
|
||||
collection_id: Optional[int] = Query(None, description="Filter by collection ID"),
|
||||
date_from: Optional[str] = Query(None, description="Filter from date (ISO format)"),
|
||||
date_to: Optional[str] = Query(None, description="Filter to date (ISO format)"),
|
||||
sort_by: str = Query("created_at", description="Sort by: created_at, updated_at, cost, file_size, title"),
|
||||
sort_order: str = Query("desc", description="Sort order: asc or desc"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -95,6 +100,29 @@ async def get_assets(
|
||||
if tags:
|
||||
tags_list = [tag.strip() for tag in tags.split(",")]
|
||||
|
||||
# Parse date filters
|
||||
date_from_obj = None
|
||||
if date_from:
|
||||
try:
|
||||
date_from_obj = datetime.fromisoformat(date_from.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid date_from format. Use ISO format.")
|
||||
|
||||
date_to_obj = None
|
||||
if date_to:
|
||||
try:
|
||||
date_to_obj = datetime.fromisoformat(date_to.replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid date_to format. Use ISO format.")
|
||||
|
||||
# Validate sort parameters
|
||||
valid_sort_by = ["created_at", "updated_at", "cost", "file_size", "title"]
|
||||
if sort_by not in valid_sort_by:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid sort_by. Must be one of: {', '.join(valid_sort_by)}")
|
||||
|
||||
if sort_order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Invalid sort_order. Must be 'asc' or 'desc'")
|
||||
|
||||
assets, total = service.get_user_assets(
|
||||
user_id=user_id,
|
||||
asset_type=asset_type_enum,
|
||||
@@ -102,6 +130,11 @@ async def get_assets(
|
||||
search_query=search,
|
||||
tags=tags_list,
|
||||
favorites_only=favorites_only,
|
||||
collection_id=collection_id,
|
||||
date_from=date_from_obj,
|
||||
date_to=date_to_obj,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
@@ -330,3 +363,303 @@ async def get_statistics(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching statistics: {str(e)}")
|
||||
|
||||
|
||||
# ==================== Collection Endpoints ====================
|
||||
|
||||
class CollectionResponse(BaseModel):
|
||||
"""Response model for collection data."""
|
||||
id: int
|
||||
user_id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
is_public: bool = False
|
||||
cover_asset_id: Optional[int] = None
|
||||
asset_count: int = 0
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CollectionListResponse(BaseModel):
|
||||
"""Response model for collection list."""
|
||||
collections: List[CollectionResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CollectionCreateRequest(BaseModel):
|
||||
"""Request model for creating a collection."""
|
||||
name: str = Field(..., description="Collection name")
|
||||
description: Optional[str] = Field(None, description="Collection description")
|
||||
is_public: bool = Field(False, description="Whether collection is public")
|
||||
|
||||
|
||||
class CollectionUpdateRequest(BaseModel):
|
||||
"""Request model for updating a collection."""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_public: Optional[bool] = None
|
||||
cover_asset_id: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/collections", response_model=CollectionResponse)
|
||||
async def create_collection(
|
||||
collection_data: CollectionCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new asset collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
collection = service.create_collection(
|
||||
user_id=user_id,
|
||||
name=collection_data.name,
|
||||
description=collection_data.description,
|
||||
is_public=collection_data.is_public,
|
||||
)
|
||||
|
||||
# Get asset count
|
||||
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
|
||||
asset_count = len(assets)
|
||||
|
||||
response = CollectionResponse.model_validate(collection)
|
||||
response.asset_count = asset_count
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/collections", response_model=CollectionListResponse)
|
||||
async def get_collections(
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get user's collections."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
collections, total = service.get_user_collections(user_id, limit=limit, offset=offset)
|
||||
|
||||
# Get asset counts for each collection
|
||||
collection_responses = []
|
||||
for collection in collections:
|
||||
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
|
||||
response = CollectionResponse.model_validate(collection)
|
||||
response.asset_count = len(assets)
|
||||
collection_responses.append(response)
|
||||
|
||||
return CollectionListResponse(
|
||||
collections=collection_responses,
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching collections: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/collections/{collection_id}", response_model=CollectionResponse)
|
||||
async def get_collection(
|
||||
collection_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
collection = service.get_collection_by_id(collection_id, user_id)
|
||||
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
|
||||
response = CollectionResponse.model_validate(collection)
|
||||
response.asset_count = len(assets)
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching collection: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/collections/{collection_id}", response_model=CollectionResponse)
|
||||
async def update_collection(
|
||||
collection_id: int,
|
||||
update_data: CollectionUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update collection metadata."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
collection = service.update_collection(
|
||||
collection_id=collection_id,
|
||||
user_id=user_id,
|
||||
name=update_data.name,
|
||||
description=update_data.description,
|
||||
is_public=update_data.is_public,
|
||||
cover_asset_id=update_data.cover_asset_id,
|
||||
)
|
||||
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
assets, _ = service.get_collection_assets(collection.id, user_id, limit=1, offset=0)
|
||||
response = CollectionResponse.model_validate(collection)
|
||||
response.asset_count = len(assets)
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error updating collection: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/collections/{collection_id}", response_model=Dict[str, Any])
|
||||
async def delete_collection(
|
||||
collection_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
success = service.delete_collection(collection_id, user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
return {"collection_id": collection_id, "deleted": True}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/collections/{collection_id}/assets", response_model=AssetListResponse)
|
||||
async def get_collection_assets(
|
||||
collection_id: int,
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get all assets in a collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
collection = service.get_collection_by_id(collection_id, user_id)
|
||||
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
assets, total = service.get_collection_assets(collection_id, user_id, limit=limit, offset=offset)
|
||||
|
||||
return AssetListResponse(
|
||||
assets=[AssetResponse.model_validate(asset) for asset in assets],
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching collection assets: {str(e)}")
|
||||
|
||||
|
||||
class CollectionAssetsRequest(BaseModel):
|
||||
"""Request model for adding/removing assets from collection."""
|
||||
asset_ids: List[int] = Field(..., description="List of asset IDs")
|
||||
|
||||
|
||||
@router.post("/collections/{collection_id}/assets", response_model=Dict[str, Any])
|
||||
async def add_assets_to_collection(
|
||||
collection_id: int,
|
||||
request: CollectionAssetsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Add assets to a collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
count = service.add_assets_to_collection(collection_id, user_id, request.asset_ids)
|
||||
|
||||
return {
|
||||
"collection_id": collection_id,
|
||||
"assets_added": count,
|
||||
"asset_ids": request.asset_ids,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error adding assets to collection: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/collections/{collection_id}/assets", response_model=Dict[str, Any])
|
||||
async def remove_assets_from_collection(
|
||||
collection_id: int,
|
||||
request: CollectionAssetsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Remove assets from a collection."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
count = service.remove_assets_from_collection(collection_id, user_id, request.asset_ids)
|
||||
|
||||
return {
|
||||
"collection_id": collection_id,
|
||||
"assets_removed": count,
|
||||
"asset_ids": request.asset_ids,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error removing assets from collection: {str(e)}")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
import uuid
|
||||
import asyncio
|
||||
from models.research_intent_models import TrendAnalysis
|
||||
|
||||
from services.database import get_db
|
||||
from services.research.core import (
|
||||
@@ -379,7 +380,7 @@ class AnalyzeIntentRequest(BaseModel):
|
||||
|
||||
|
||||
class AnalyzeIntentResponse(BaseModel):
|
||||
"""Response from intent analysis."""
|
||||
"""Response from intent analysis with optimized provider parameters."""
|
||||
success: bool
|
||||
intent: Dict[str, Any]
|
||||
analysis_summary: str
|
||||
@@ -387,7 +388,16 @@ class AnalyzeIntentResponse(BaseModel):
|
||||
suggested_keywords: List[str]
|
||||
suggested_angles: List[str]
|
||||
quick_options: List[Dict[str, Any]]
|
||||
confidence_reason: Optional[str] = None
|
||||
great_example: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# Unified: Optimized provider parameters based on intent
|
||||
optimized_config: Optional[Dict[str, Any]] = None # Provider settings auto-configured from intent
|
||||
recommended_provider: Optional[str] = None # Best provider for this intent (exa, tavily, google)
|
||||
|
||||
# Google Trends configuration (if trends in deliverables)
|
||||
trends_config: Optional[Dict[str, Any]] = None # Trends keywords and settings with justifications
|
||||
|
||||
|
||||
class IntentDrivenResearchRequest(BaseModel):
|
||||
@@ -406,6 +416,9 @@ class IntentDrivenResearchRequest(BaseModel):
|
||||
include_domains: List[str] = Field(default_factory=list)
|
||||
exclude_domains: List[str] = Field(default_factory=list)
|
||||
|
||||
# Google Trends configuration (from intent analysis)
|
||||
trends_config: Optional[Dict[str, Any]] = None # Trends keywords and settings
|
||||
|
||||
# Skip intent inference (for re-runs with same intent)
|
||||
skip_inference: bool = False
|
||||
|
||||
@@ -445,6 +458,9 @@ class IntentDrivenResearchResponse(BaseModel):
|
||||
# The inferred/confirmed intent
|
||||
intent: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Google Trends data (if trends were analyzed)
|
||||
google_trends_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Error handling
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@@ -480,14 +496,14 @@ async def analyze_research_intent(
|
||||
|
||||
if request.use_persona or request.use_competitor_data:
|
||||
from services.research.research_persona_service import ResearchPersonaService
|
||||
from services.onboarding_service import OnboardingService
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get database session
|
||||
db = next(get_db())
|
||||
try:
|
||||
persona_service = ResearchPersonaService(db)
|
||||
onboarding_service = OnboardingService()
|
||||
onboarding_service = OnboardingDatabaseService(db=db)
|
||||
|
||||
if request.use_persona:
|
||||
research_persona = persona_service.get_or_generate(user_id)
|
||||
@@ -497,37 +513,91 @@ async def analyze_research_intent(
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Infer intent
|
||||
intent_service = ResearchIntentInference()
|
||||
response = await intent_service.infer_intent(
|
||||
# Use Unified Research Analyzer (single AI call for intent + queries + params)
|
||||
from services.research.intent.unified_research_analyzer import UnifiedResearchAnalyzer
|
||||
|
||||
analyzer = UnifiedResearchAnalyzer()
|
||||
unified_result = await analyzer.analyze(
|
||||
user_input=request.user_input,
|
||||
keywords=request.keywords,
|
||||
research_persona=research_persona,
|
||||
competitor_data=competitor_data,
|
||||
industry=research_persona.default_industry if research_persona else None,
|
||||
target_audience=research_persona.default_target_audience if research_persona else None,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Generate targeted queries
|
||||
query_generator = IntentQueryGenerator()
|
||||
query_result = await query_generator.generate_queries(
|
||||
intent=response.intent,
|
||||
research_persona=research_persona,
|
||||
)
|
||||
if not unified_result.get("success", False):
|
||||
logger.warning("Unified analysis failed, using fallback")
|
||||
|
||||
# Update response with queries
|
||||
response.suggested_queries = [q.dict() for q in query_result.get("queries", [])]
|
||||
response.suggested_keywords = query_result.get("enhanced_keywords", [])
|
||||
response.suggested_angles = query_result.get("research_angles", [])
|
||||
# Extract results
|
||||
intent = unified_result.get("intent")
|
||||
queries = unified_result.get("queries", [])
|
||||
exa_config = unified_result.get("exa_config", {})
|
||||
tavily_config = unified_result.get("tavily_config", {})
|
||||
trends_config = unified_result.get("trends_config", {}) # NEW: Google Trends config
|
||||
|
||||
# Build optimized config with AI-driven justifications
|
||||
optimized_config = {
|
||||
"provider": unified_result.get("recommended_provider", "exa"),
|
||||
"provider_justification": unified_result.get("provider_justification", ""),
|
||||
# Exa settings with justifications
|
||||
"exa_type": exa_config.get("type", "auto"),
|
||||
"exa_type_justification": exa_config.get("type_justification", ""),
|
||||
"exa_category": exa_config.get("category"),
|
||||
"exa_category_justification": exa_config.get("category_justification", ""),
|
||||
"exa_include_domains": exa_config.get("includeDomains", []),
|
||||
"exa_include_domains_justification": exa_config.get("includeDomains_justification", ""),
|
||||
"exa_num_results": exa_config.get("numResults", 10),
|
||||
"exa_num_results_justification": exa_config.get("numResults_justification", ""),
|
||||
"exa_date_filter": exa_config.get("startPublishedDate"),
|
||||
"exa_date_justification": exa_config.get("date_justification", ""),
|
||||
"exa_highlights": exa_config.get("highlights", True),
|
||||
"exa_highlights_justification": exa_config.get("highlights_justification", ""),
|
||||
"exa_context": exa_config.get("context", True),
|
||||
"exa_context_justification": exa_config.get("context_justification", ""),
|
||||
# Tavily settings with justifications
|
||||
"tavily_topic": tavily_config.get("topic", "general"),
|
||||
"tavily_topic_justification": tavily_config.get("topic_justification", ""),
|
||||
"tavily_search_depth": tavily_config.get("search_depth", "advanced"),
|
||||
"tavily_search_depth_justification": tavily_config.get("search_depth_justification", ""),
|
||||
"tavily_include_answer": tavily_config.get("include_answer", True),
|
||||
"tavily_include_answer_justification": tavily_config.get("include_answer_justification", ""),
|
||||
"tavily_time_range": tavily_config.get("time_range"),
|
||||
"tavily_time_range_justification": tavily_config.get("time_range_justification", ""),
|
||||
"tavily_max_results": tavily_config.get("max_results", 10),
|
||||
"tavily_max_results_justification": tavily_config.get("max_results_justification", ""),
|
||||
"tavily_raw_content": tavily_config.get("include_raw_content", "markdown"),
|
||||
"tavily_raw_content_justification": tavily_config.get("include_raw_content_justification", ""),
|
||||
}
|
||||
|
||||
# Build trends config response (if enabled)
|
||||
trends_config_response = None
|
||||
if trends_config.get("enabled", False):
|
||||
trends_config_response = {
|
||||
"enabled": True,
|
||||
"keywords": trends_config.get("keywords", []),
|
||||
"keywords_justification": trends_config.get("keywords_justification", ""),
|
||||
"timeframe": trends_config.get("timeframe", "today 12-m"),
|
||||
"timeframe_justification": trends_config.get("timeframe_justification", ""),
|
||||
"geo": trends_config.get("geo", "US"),
|
||||
"geo_justification": trends_config.get("geo_justification", ""),
|
||||
"expected_insights": trends_config.get("expected_insights", []),
|
||||
}
|
||||
|
||||
return AnalyzeIntentResponse(
|
||||
success=True,
|
||||
intent=response.intent.dict(),
|
||||
analysis_summary=response.analysis_summary,
|
||||
suggested_queries=response.suggested_queries,
|
||||
suggested_keywords=response.suggested_keywords,
|
||||
suggested_angles=response.suggested_angles,
|
||||
quick_options=response.quick_options,
|
||||
intent=intent.dict() if hasattr(intent, 'dict') else intent,
|
||||
analysis_summary=unified_result.get("analysis_summary", ""),
|
||||
suggested_queries=[q.dict() if hasattr(q, 'dict') else q for q in queries],
|
||||
suggested_keywords=unified_result.get("enhanced_keywords", []),
|
||||
suggested_angles=unified_result.get("research_angles", []),
|
||||
quick_options=[], # Deprecated in unified approach
|
||||
confidence_reason=intent.confidence_reason if hasattr(intent, 'confidence_reason') else "",
|
||||
great_example=intent.great_example if hasattr(intent, 'great_example') else "",
|
||||
optimized_config=optimized_config,
|
||||
recommended_provider=unified_result.get("recommended_provider", "exa"),
|
||||
trends_config=trends_config_response, # NEW: Google Trends configuration
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -540,6 +610,8 @@ async def analyze_research_intent(
|
||||
suggested_keywords=[],
|
||||
suggested_angles=[],
|
||||
quick_options=[],
|
||||
confidence_reason=None,
|
||||
great_example=None,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
@@ -591,6 +663,7 @@ async def execute_intent_driven_research(
|
||||
intent_response = await intent_service.infer_intent(
|
||||
user_input=request.user_input,
|
||||
research_persona=research_persona,
|
||||
user_id=user_id,
|
||||
)
|
||||
intent = intent_response.intent
|
||||
else:
|
||||
@@ -613,6 +686,7 @@ async def execute_intent_driven_research(
|
||||
query_result = await query_generator.generate_queries(
|
||||
intent=intent,
|
||||
research_persona=research_persona,
|
||||
user_id=user_id,
|
||||
)
|
||||
queries = query_result.get("queries", [])
|
||||
|
||||
@@ -648,8 +722,35 @@ async def execute_intent_driven_research(
|
||||
exclude_domains=request.exclude_domains,
|
||||
)
|
||||
|
||||
# Execute research
|
||||
raw_result = await engine.research(context)
|
||||
# Execute research and trends in parallel
|
||||
research_task = asyncio.create_task(engine.research(context))
|
||||
|
||||
# Execute Google Trends analysis in parallel (if enabled)
|
||||
trends_task = None
|
||||
trends_data = None
|
||||
if request.trends_config and request.trends_config.get("enabled"):
|
||||
from services.research.trends.google_trends_service import GoogleTrendsService
|
||||
trends_service = GoogleTrendsService()
|
||||
trends_task = asyncio.create_task(
|
||||
trends_service.analyze_trends(
|
||||
keywords=request.trends_config.get("keywords", []),
|
||||
timeframe=request.trends_config.get("timeframe", "today 12-m"),
|
||||
geo=request.trends_config.get("geo", "US"),
|
||||
user_id=user_id
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for research to complete
|
||||
raw_result = await research_task
|
||||
|
||||
# Wait for trends if it was started
|
||||
if trends_task:
|
||||
try:
|
||||
trends_data = await trends_task
|
||||
logger.info(f"Google Trends data fetched: {len(trends_data.get('interest_over_time', []))} time points")
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
trends_data = None
|
||||
|
||||
# Analyze results using intent-aware analyzer
|
||||
analyzer = IntentAwareAnalyzer()
|
||||
@@ -661,8 +762,13 @@ async def execute_intent_driven_research(
|
||||
},
|
||||
intent=intent,
|
||||
research_persona=research_persona,
|
||||
user_id=user_id, # Required for subscription checking
|
||||
)
|
||||
|
||||
# Merge Google Trends data into trends analysis
|
||||
if trends_data and analyzed_result.trends:
|
||||
analyzed_result = _merge_trends_data(analyzed_result, trends_data)
|
||||
|
||||
# Build response
|
||||
return IntentDrivenResearchResponse(
|
||||
success=True,
|
||||
@@ -687,6 +793,7 @@ async def execute_intent_driven_research(
|
||||
gaps_identified=analyzed_result.gaps_identified,
|
||||
follow_up_queries=analyzed_result.follow_up_queries,
|
||||
intent=intent.dict(),
|
||||
google_trends_data=trends_data, # Include Google Trends data in response
|
||||
)
|
||||
|
||||
finally:
|
||||
@@ -737,3 +844,67 @@ def _map_provider_to_preference(provider: str) -> ProviderPreference:
|
||||
}
|
||||
return mapping.get(provider, ProviderPreference.AUTO)
|
||||
|
||||
|
||||
def _merge_trends_data(
|
||||
analyzed_result: Any,
|
||||
trends_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Merge Google Trends data into analyzed result trends.
|
||||
|
||||
Enhances AI-extracted trends with Google Trends data.
|
||||
"""
|
||||
from services.research.intent.intent_aware_analyzer import IntentDrivenResearchResult
|
||||
from models.research_intent_models import TrendAnalysis
|
||||
|
||||
if not analyzed_result.trends:
|
||||
return analyzed_result
|
||||
|
||||
# Enhance each trend with Google Trends data
|
||||
enhanced_trends = []
|
||||
for trend in analyzed_result.trends:
|
||||
# Create enhanced trend with Google Trends data
|
||||
trend_dict = trend.dict() if hasattr(trend, 'dict') else trend
|
||||
trend_dict["google_trends_data"] = trends_data
|
||||
|
||||
# Add interest score if available
|
||||
if trends_data.get("interest_over_time"):
|
||||
# Calculate average interest score
|
||||
interest_values = []
|
||||
for point in trends_data["interest_over_time"]:
|
||||
for key, value in point.items():
|
||||
if key not in ["date", "isPartial"] and isinstance(value, (int, float)):
|
||||
interest_values.append(value)
|
||||
if interest_values:
|
||||
trend_dict["interest_score"] = sum(interest_values) / len(interest_values)
|
||||
|
||||
# Add related topics/queries
|
||||
if trends_data.get("related_topics"):
|
||||
top_topics = [t.get("topic_title", "") for t in trends_data["related_topics"].get("top", [])[:5]]
|
||||
rising_topics = [t.get("topic_title", "") for t in trends_data["related_topics"].get("rising", [])[:5]]
|
||||
trend_dict["related_topics"] = {"top": top_topics, "rising": rising_topics}
|
||||
|
||||
if trends_data.get("related_queries"):
|
||||
top_queries = [q.get("query", "") for q in trends_data["related_queries"].get("top", [])[:5]]
|
||||
rising_queries = [q.get("query", "") for q in trends_data["related_queries"].get("rising", [])[:5]]
|
||||
trend_dict["related_queries"] = {"top": top_queries, "rising": rising_queries}
|
||||
|
||||
# Add regional interest
|
||||
if trends_data.get("interest_by_region"):
|
||||
regional_interest = {}
|
||||
for region in trends_data["interest_by_region"][:10]: # Top 10 regions
|
||||
region_name = region.get("geoName", "")
|
||||
if region_name:
|
||||
# Get interest value (first numeric column)
|
||||
for key, value in region.items():
|
||||
if key != "geoName" and isinstance(value, (int, float)):
|
||||
regional_interest[region_name] = value
|
||||
break
|
||||
trend_dict["regional_interest"] = regional_interest
|
||||
|
||||
enhanced_trends.append(TrendAnalysis(**trend_dict))
|
||||
|
||||
# Update analyzed result with enhanced trends
|
||||
analyzed_result.trends = enhanced_trends
|
||||
return analyzed_result
|
||||
|
||||
|
||||
@@ -236,6 +236,56 @@ async def get_persona_defaults(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/research-persona/verify")
|
||||
async def verify_research_persona_exists(
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Verify if research persona exists in database (for debugging).
|
||||
Returns detailed information about persona status.
|
||||
"""
|
||||
try:
|
||||
user_id = str(current_user.get('id'))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database not available")
|
||||
|
||||
persona_service = ResearchPersonaService(db_session=db)
|
||||
persona_data = persona_service._get_persona_data_record(user_id)
|
||||
|
||||
if not persona_data:
|
||||
return {
|
||||
"exists": False,
|
||||
"reason": "No persona_data record found",
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
has_persona = (
|
||||
persona_data.research_persona is not None
|
||||
and persona_data.research_persona != {}
|
||||
and persona_data.research_persona != ""
|
||||
and isinstance(persona_data.research_persona, dict)
|
||||
and len(persona_data.research_persona) > 0
|
||||
)
|
||||
|
||||
cache_valid = persona_service.is_cache_valid(persona_data) if has_persona else False
|
||||
|
||||
return {
|
||||
"exists": has_persona,
|
||||
"cache_valid": cache_valid,
|
||||
"generated_at": persona_data.research_persona_generated_at.isoformat() if persona_data.research_persona_generated_at else None,
|
||||
"persona_type": type(persona_data.research_persona).__name__ if persona_data.research_persona else None,
|
||||
"persona_keys": list(persona_data.research_persona.keys()) if has_persona and isinstance(persona_data.research_persona, dict) else None,
|
||||
"user_id": user_id
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[ResearchConfig] Error verifying research persona: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to verify research persona: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/research-persona")
|
||||
async def get_research_persona(
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
@@ -261,6 +311,14 @@ async def get_research_persona(
|
||||
raise HTTPException(status_code=500, detail="Database not available")
|
||||
|
||||
persona_service = ResearchPersonaService(db_session=db)
|
||||
|
||||
# First check if persona exists (without generating)
|
||||
existing_persona = persona_service.get_cached_only(user_id)
|
||||
if existing_persona and not force_refresh:
|
||||
logger.info(f"[ResearchConfig] Returning existing research persona for user {user_id} (force_refresh={force_refresh})")
|
||||
return existing_persona.dict()
|
||||
|
||||
# Only generate if persona doesn't exist or force_refresh is True
|
||||
research_persona = persona_service.get_or_generate(user_id, force_refresh=force_refresh)
|
||||
|
||||
if not research_persona:
|
||||
@@ -286,8 +344,9 @@ async def get_research_config(
|
||||
):
|
||||
"""
|
||||
Get complete research configuration including provider availability and persona defaults.
|
||||
|
||||
Requires authentication - user must be logged in.
|
||||
"""
|
||||
user_id = None
|
||||
try:
|
||||
user_id = str(current_user.get('id'))
|
||||
logger.info(f"[ResearchConfig] Starting get_research_config for user {user_id}")
|
||||
@@ -378,19 +437,30 @@ async def get_research_config(
|
||||
|
||||
# Get research persona (optional, may not exist for all users)
|
||||
# CRITICAL: Use get_cached_only() to avoid triggering rate limit checks
|
||||
# Only return persona if it's already cached - don't generate on config load
|
||||
# Returns persona if it exists in database, regardless of cache validity
|
||||
research_persona = None
|
||||
persona_scheduled = False
|
||||
try:
|
||||
logger.debug(f"[ResearchConfig] Getting cached research persona for user {user_id}")
|
||||
logger.info(f"[ResearchConfig] 🔍 Getting research persona for user {user_id}")
|
||||
persona_service = ResearchPersonaService(db_session=db)
|
||||
research_persona = persona_service.get_cached_only(user_id)
|
||||
|
||||
logger.info(
|
||||
f"[ResearchConfig] Research persona check for user {user_id}: "
|
||||
f"persona_exists={research_persona is not None}, "
|
||||
f"onboarding_completed={onboarding_completed}"
|
||||
)
|
||||
if research_persona:
|
||||
# Check cache validity for logging
|
||||
persona_data_record = persona_service._get_persona_data_record(user_id)
|
||||
cache_valid = persona_data_record and persona_service.is_cache_valid(persona_data_record) if persona_data_record else False
|
||||
logger.info(
|
||||
f"[ResearchConfig] ✅ Research persona FOUND for user {user_id}: "
|
||||
f"exists=True, cache_valid={cache_valid}, "
|
||||
f"industry={research_persona.default_industry}, "
|
||||
f"onboarding_completed={onboarding_completed}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[ResearchConfig] ⚠️ Research persona NOT FOUND for user {user_id}: "
|
||||
f"persona_exists=False, "
|
||||
f"onboarding_completed={onboarding_completed}"
|
||||
)
|
||||
|
||||
# If onboarding is completed but persona doesn't exist, schedule generation
|
||||
if onboarding_completed and not research_persona:
|
||||
@@ -412,13 +482,24 @@ async def get_research_config(
|
||||
# get_cached_only() never raises HTTPException, but catch any unexpected errors
|
||||
logger.warning(f"[ResearchConfig] Could not load cached research persona for user {user_id}: {e}", exc_info=True)
|
||||
|
||||
# FastAPI will automatically serialize the ResearchPersona Pydantic model
|
||||
# Convert ResearchPersona to dict for proper serialization
|
||||
# FastAPI should handle Pydantic models, but explicit conversion ensures compatibility
|
||||
research_persona_dict = None
|
||||
if research_persona:
|
||||
try:
|
||||
research_persona_dict = research_persona.dict() if hasattr(research_persona, 'dict') else research_persona
|
||||
logger.debug(f"[ResearchConfig] Converted research persona to dict for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchConfig] Failed to convert research persona to dict: {e}")
|
||||
research_persona_dict = None
|
||||
|
||||
# FastAPI will automatically serialize the ResearchPersona dict
|
||||
# If there's a serialization issue, we catch it and log it
|
||||
try:
|
||||
response = ResearchConfigResponse(
|
||||
provider_availability=provider_availability,
|
||||
persona_defaults=persona_defaults,
|
||||
research_persona=research_persona,
|
||||
research_persona=research_persona_dict,
|
||||
onboarding_completed=onboarding_completed,
|
||||
persona_scheduled=persona_scheduled
|
||||
)
|
||||
@@ -434,9 +515,10 @@ async def get_research_config(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[ResearchConfig] Response for user {user_id}: "
|
||||
f"[ResearchConfig] 📤 Response for user {user_id}: "
|
||||
f"onboarding_completed={onboarding_completed}, "
|
||||
f"persona_exists={research_persona is not None}, "
|
||||
f"persona_dict_exists={research_persona_dict is not None}, "
|
||||
f"persona_scheduled={persona_scheduled}"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user