diff --git a/backend/api/blog_writer/router.py b/backend/api/blog_writer/router.py index 0d3b60fe..12c6a72e 100644 --- a/backend/api/blog_writer/router.py +++ b/backend/api/blog_writer/router.py @@ -185,10 +185,20 @@ async def get_research_status(task_id: str) -> Dict[str, Any]: # Outline Endpoints @router.post("/outline/start") -async def start_outline_generation(request: BlogOutlineRequest) -> Dict[str, Any]: +async def start_outline_generation( + request: BlogOutlineRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: """Start an outline generation operation and return a task ID for polling.""" try: - task_id = task_manager.start_outline_task(request) + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + user_id = str(current_user.get('id')) + if not user_id: + raise HTTPException(status_code=401, detail="User ID not found in authentication token") + + task_id = task_manager.start_outline_task(request, user_id) return {"task_id": task_id, "status": "started"} except Exception as e: logger.error(f"Failed to start outline generation: {e}") @@ -272,12 +282,22 @@ async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse: @router.post("/content/start") -async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]: +async def start_content_generation( + request: Dict[str, Any], + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: """Start full content generation and return a task id for polling. Accepts a payload compatible with MediumBlogGenerateRequest to minimize duplication. """ try: + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + user_id = str(current_user.get('id')) + if not user_id: + raise HTTPException(status_code=401, detail="User ID not found in authentication token") + # Map dict to MediumBlogGenerateRequest for reuse from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo sections = [MediumSectionOutline(**s) for s in request.get("sections", [])] @@ -293,7 +313,7 @@ async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]: globalTargetWords=request.get("globalTargetWords", 1000), researchKeywords=request.get("researchKeywords") or request.get("keywords"), ) - task_id = task_manager.start_content_generation_task(req) + task_id = task_manager.start_content_generation_task(req, user_id) return {"task_id": task_id, "status": "started"} except Exception as e: logger.error(f"Failed to start content generation: {e}") @@ -307,6 +327,51 @@ async def content_generation_status(task_id: str) -> Dict[str, Any]: status = await task_manager.get_task_status(task_id) if status is None: raise HTTPException(status_code=404, detail="Task not found") + + # If task failed with subscription error, return HTTP error so frontend interceptor can catch it + if status.get('status') == 'failed' and status.get('error_status') in [429, 402]: + error_data = status.get('error_data', {}) or {} + error_status = status.get('error_status', 429) + + if not isinstance(error_data, dict): + logger.warning(f"Content generation task {task_id} error_data not dict: {error_data}") + error_data = {'error': str(error_data)} + + # Determine provider and usage info + stored_error_message = status.get('error', error_data.get('error')) + provider = error_data.get('provider', 'unknown') + usage_info = error_data.get('usage_info') + + if not usage_info: + usage_info = { + 'provider': provider, + 'message': stored_error_message, + 'error_type': error_data.get('error_type', 'unknown') + } + # Include any known fields from error_data + for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']: + if key in error_data: + usage_info[key] = error_data[key] + + # Build error message for detail + error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded') + + # Log the subscription error with all context + logger.warning(f"Content generation task {task_id} failed with subscription error {error_status}: {error_msg}") + logger.warning(f" Provider: {provider}, Usage Info: {usage_info}") + + # Use JSONResponse to ensure detail is returned as-is, not wrapped in an array + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=error_status, + content={ + 'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'), + 'message': error_msg, + 'provider': provider, + 'usage_info': usage_info + } + ) + return status except HTTPException: raise @@ -499,14 +564,24 @@ async def get_outline_cache_entries(limit: int = 20): # --------------------------- @router.post("/generate/medium/start") -async def start_medium_generation(request: MediumBlogGenerateRequest): +async def start_medium_generation( + request: MediumBlogGenerateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): """Start medium-length blog generation (ā¤1000 words) and return a task id.""" try: + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + user_id = str(current_user.get('id')) + if not user_id: + raise HTTPException(status_code=401, detail="User ID not found in authentication token") + # Simple server-side guard if (request.globalTargetWords or 1000) > 1000: raise HTTPException(status_code=400, detail="Global target words exceed 1000; use per-section generation") - task_id = task_manager.start_medium_generation_task(request) + task_id = task_manager.start_medium_generation_task(request, user_id) return {"task_id": task_id, "status": "started"} except HTTPException: raise @@ -522,6 +597,51 @@ async def medium_generation_status(task_id: str): status = await task_manager.get_task_status(task_id) if status is None: raise HTTPException(status_code=404, detail="Task not found") + + # If task failed with subscription error, return HTTP error so frontend interceptor can catch it + if status.get('status') == 'failed' and status.get('error_status') in [429, 402]: + error_data = status.get('error_data', {}) or {} + error_status = status.get('error_status', 429) + + if not isinstance(error_data, dict): + logger.warning(f"Medium generation task {task_id} error_data not dict: {error_data}") + error_data = {'error': str(error_data)} + + # Determine provider and usage info + stored_error_message = status.get('error', error_data.get('error')) + provider = error_data.get('provider', 'unknown') + usage_info = error_data.get('usage_info') + + if not usage_info: + usage_info = { + 'provider': provider, + 'message': stored_error_message, + 'error_type': error_data.get('error_type', 'unknown') + } + # Include any known fields from error_data + for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']: + if key in error_data: + usage_info[key] = error_data[key] + + # Build error message for detail + error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded') + + # Log the subscription error with all context + logger.warning(f"Medium generation task {task_id} failed with subscription error {error_status}: {error_msg}") + logger.warning(f" Provider: {provider}, Usage Info: {usage_info}") + + # Use JSONResponse to ensure detail is returned as-is, not wrapped in an array + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=error_status, + content={ + 'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'), + 'message': error_msg, + 'provider': provider, + 'usage_info': usage_info + } + ) + return status except HTTPException: raise diff --git a/backend/api/blog_writer/seo_analysis.py b/backend/api/blog_writer/seo_analysis.py index dee0df5e..84deaf21 100644 --- a/backend/api/blog_writer/seo_analysis.py +++ b/backend/api/blog_writer/seo_analysis.py @@ -5,7 +5,7 @@ Provides API endpoint for analyzing blog content SEO with parallel processing and CopilotKit integration for real-time progress updates. """ -from fastapi import APIRouter, HTTPException, BackgroundTasks +from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends from pydantic import BaseModel from typing import Dict, Any, Optional from loguru import logger @@ -13,6 +13,7 @@ from datetime import datetime from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer from services.blog_writer.core.blog_writer_service import BlogWriterService +from middleware.auth_middleware import get_current_user router = APIRouter(prefix="/api/blog-writer/seo", tags=["Blog SEO Analysis"]) @@ -56,7 +57,10 @@ blog_writer_service = BlogWriterService() @router.post("/analyze", response_model=SEOAnalysisResponse) -async def analyze_blog_seo(request: SEOAnalysisRequest): +async def analyze_blog_seo( + request: SEOAnalysisRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): """ Analyze blog content for SEO optimization @@ -69,6 +73,7 @@ async def analyze_blog_seo(request: SEOAnalysisRequest): Args: request: SEOAnalysisRequest containing blog content and research data + current_user: Authenticated user from middleware Returns: SEOAnalysisResponse with comprehensive analysis results @@ -76,6 +81,14 @@ async def analyze_blog_seo(request: SEOAnalysisRequest): try: logger.info(f"Starting SEO analysis for blog content") + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + # Validate request if not request.blog_content or not request.blog_content.strip(): raise HTTPException(status_code=400, detail="Blog content is required") @@ -91,7 +104,8 @@ async def analyze_blog_seo(request: SEOAnalysisRequest): analysis_results = await seo_analyzer.analyze_blog_content( blog_content=request.blog_content, research_data=request.research_data, - blog_title=request.blog_title + blog_title=request.blog_title, + user_id=user_id ) # Check for errors @@ -131,7 +145,10 @@ async def analyze_blog_seo(request: SEOAnalysisRequest): @router.post("/analyze-with-progress") -async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest): +async def analyze_blog_seo_with_progress( + request: SEOAnalysisRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): """ Analyze blog content for SEO with real-time progress updates @@ -140,6 +157,7 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest): Args: request: SEOAnalysisRequest containing blog content and research data + current_user: Authenticated user from middleware Returns: Generator yielding progress updates and final results @@ -147,6 +165,14 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest): try: logger.info(f"Starting SEO analysis with progress for blog content") + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + # Validate request if not request.blog_content or not request.blog_content.strip(): raise HTTPException(status_code=400, detail="Blog content is required") @@ -209,7 +235,9 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest): # Perform actual analysis analysis_results = await seo_analyzer.analyze_blog_content( blog_content=request.blog_content, - research_data=request.research_data + research_data=request.research_data, + blog_title=request.blog_title, + user_id=user_id ) # Final result diff --git a/backend/api/blog_writer/task_manager.py b/backend/api/blog_writer/task_manager.py index acbc15f0..59952a18 100644 --- a/backend/api/blog_writer/task_manager.py +++ b/backend/api/blog_writer/task_manager.py @@ -88,8 +88,12 @@ class TaskManager: response["error"] = task["error"] if "error_status" in task: response["error_status"] = task["error_status"] + logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_status={task['error_status']} in response") if "error_data" in task: response["error_data"] = task["error_data"] + logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_data with keys: {list(task['error_data'].keys()) if isinstance(task['error_data'], dict) else 'not-dict'}") + else: + logger.warning(f"[TaskManager] get_task_status for {task_id}: Task failed but no error_data found. Task keys: {list(task.keys())}") return response @@ -127,29 +131,33 @@ class TaskManager: asyncio.create_task(self._run_research_task(task_id, request, user_id)) return task_id - def start_outline_task(self, request: BlogOutlineRequest) -> str: + def start_outline_task(self, request: BlogOutlineRequest, user_id: str) -> str: """Start an outline generation operation and return a task ID.""" task_id = self.create_task("outline") # Start the outline generation operation in the background - asyncio.create_task(self._run_outline_generation_task(task_id, request)) + asyncio.create_task(self._run_outline_generation_task(task_id, request, user_id)) return task_id - def start_medium_generation_task(self, request: MediumBlogGenerateRequest) -> str: + def start_medium_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str: """Start a medium (ā¤1000 words) full-blog generation task.""" task_id = self.create_task("medium_generation") - asyncio.create_task(self._run_medium_generation_task(task_id, request)) + asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id)) return task_id - def start_content_generation_task(self, request: MediumBlogGenerateRequest) -> str: + def start_content_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str: """Start content generation (full blog via sections) with provider parity. Internally reuses medium generator pipeline for now but tracked under distinct task_type 'content_generation' and same polling contract. + + Args: + request: Content generation request + user_id: User ID (required for subscription checks and usage tracking) """ task_id = self.create_task("content_generation") - asyncio.create_task(self._run_medium_generation_task(task_id, request)) + asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id)) return task_id async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str): @@ -205,7 +213,7 @@ class TaskManager: self.task_storage[task_id]["status"] = "failed" self.task_storage[task_id]["error"] = "Research completed with unknown status" - async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest): + async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest, user_id: str): """Background task to run outline generation and update status with progress messages.""" try: # Update status to running @@ -215,21 +223,31 @@ class TaskManager: # Send initial progress message await self.update_progress(task_id, "š§© Starting outline generation...") - # Run the actual outline generation with progress updates - result = await self.service.generate_outline_with_progress(request, task_id) + # Run the actual outline generation with progress updates (pass user_id for subscription checks) + result = await self.service.generate_outline_with_progress(request, task_id, user_id) # Update status to completed await self.update_progress(task_id, f"ā Outline generated successfully! Created {len(result.outline)} sections with {len(result.title_options)} title options.") self.task_storage[task_id]["status"] = "completed" self.task_storage[task_id]["result"] = result.dict() + except HTTPException as http_error: + # Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend + error_detail = http_error.detail + error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail) + await self.update_progress(task_id, f"ā {error_message}") + self.task_storage[task_id]["status"] = "failed" + self.task_storage[task_id]["error"] = error_message + # Store HTTP error details for frontend modal + self.task_storage[task_id]["error_status"] = http_error.status_code + self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)} except Exception as e: await self.update_progress(task_id, f"ā Outline generation failed: {str(e)}") # Update status to failed self.task_storage[task_id]["status"] = "failed" self.task_storage[task_id]["error"] = str(e) - async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest): + async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str): """Background task to generate a medium blog using a single structured JSON call.""" try: self.task_storage[task_id]["status"] = "running" @@ -245,6 +263,7 @@ class TaskManager: result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress( request, task_id, + user_id ) if not result or not getattr(result, "sections", None): @@ -263,10 +282,38 @@ class TaskManager: self.task_storage[task_id]["result"] = result.dict() await self.update_progress(task_id, f"ā Generated {len(result.sections)} sections successfully.") - except Exception as e: - await self.update_progress(task_id, f"ā Medium generation failed: {str(e)}") + except HTTPException as http_error: + # Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend + logger.info(f"[TaskManager] Caught HTTPException in medium generation task {task_id}: status={http_error.status_code}, detail={http_error.detail}") + error_detail = http_error.detail + error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail) + await self.update_progress(task_id, f"ā {error_message}") self.task_storage[task_id]["status"] = "failed" - self.task_storage[task_id]["error"] = str(e) + self.task_storage[task_id]["error"] = error_message + # Store HTTP error details for frontend modal + self.task_storage[task_id]["error_status"] = http_error.status_code + self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)} + logger.info(f"[TaskManager] Stored error_status={http_error.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}") + except Exception as e: + # Check if this is an HTTPException that got wrapped (can happen in async tasks) + # HTTPException has status_code and detail attributes + logger.info(f"[TaskManager] Caught Exception in medium generation task {task_id}: type={type(e).__name__}, has_status_code={hasattr(e, 'status_code')}, has_detail={hasattr(e, 'detail')}") + if hasattr(e, 'status_code') and hasattr(e, 'detail'): + # This is an HTTPException that was caught as generic Exception + logger.info(f"[TaskManager] Detected HTTPException in Exception handler: status={e.status_code}, detail={e.detail}") + error_detail = e.detail + error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail) + await self.update_progress(task_id, f"ā {error_message}") + self.task_storage[task_id]["status"] = "failed" + self.task_storage[task_id]["error"] = error_message + # Store HTTP error details for frontend modal + self.task_storage[task_id]["error_status"] = e.status_code + self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)} + logger.info(f"[TaskManager] Stored error_status={e.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}") + else: + await self.update_progress(task_id, f"ā Medium generation failed: {str(e)}") + self.task_storage[task_id]["status"] = "failed" + self.task_storage[task_id]["error"] = str(e) # Global task manager instance diff --git a/backend/api/subscription_api.py b/backend/api/subscription_api.py index 074a77d6..9abc7ceb 100644 --- a/backend/api/subscription_api.py +++ b/backend/api/subscription_api.py @@ -12,6 +12,7 @@ from functools import lru_cache from services.database import get_db from services.subscription import UsageTrackingService, PricingService +from services.subscription.schema_utils import ensure_subscription_plan_columns from middleware.auth_middleware import get_current_user from models.subscription_models import ( APIProvider, SubscriptionPlan, UserSubscription, UsageSummary, @@ -79,6 +80,8 @@ async def get_subscription_plans( """Get all available subscription plans.""" try: + # Ensure required columns exist (handles environments without migrations applied yet) + ensure_subscription_plan_columns(db) plans = db.query(SubscriptionPlan).filter( SubscriptionPlan.is_active == True ).order_by(SubscriptionPlan.price_monthly).all() @@ -137,6 +140,7 @@ async def get_user_subscription( raise HTTPException(status_code=403, detail="Access denied") try: + ensure_subscription_plan_columns(db) subscription = db.query(UserSubscription).filter( UserSubscription.user_id == user_id, UserSubscription.is_active == True @@ -234,6 +238,7 @@ async def get_subscription_status( raise HTTPException(status_code=403, detail="Access denied") try: + ensure_subscription_plan_columns(db) subscription = db.query(UserSubscription).filter( UserSubscription.user_id == user_id, UserSubscription.is_active == True @@ -346,6 +351,7 @@ async def subscribe_to_plan( raise HTTPException(status_code=403, detail="Access denied") try: + ensure_subscription_plan_columns(db) plan_id = subscription_data.get('plan_id') billing_cycle = subscription_data.get('billing_cycle', 'monthly') @@ -427,11 +433,16 @@ async def subscribe_to_plan( logger.info(f" š No usage summary found for period {current_period} (will be created on reset)") # Clear subscription limits cache to force refresh on next check + # IMPORTANT: Do this BEFORE resetting usage to ensure cache is cleared first try: from services.subscription import PricingService # Clear cache for this specific user (class-level cache shared across all instances) cleared_count = PricingService.clear_user_cache(user_id) logger.info(f" šļø Cleared {cleared_count} subscription cache entries for user {user_id}") + + # Also expire all SQLAlchemy objects to force fresh reads + db.expire_all() + logger.info(f" š Expired all SQLAlchemy objects to force fresh reads") except Exception as cache_err: logger.error(f" ā Failed to clear cache after subscribe: {cache_err}") @@ -441,12 +452,22 @@ async def subscribe_to_plan( usage_service = UsageTrackingService(db) reset_result = await usage_service.reset_current_billing_period(user_id) - # Re-query usage summary from DB after reset to get fresh data + # Force commit to ensure reset is persisted + db.commit() + + # Expire all SQLAlchemy objects to force fresh reads + db.expire_all() + + # Re-query usage summary from DB after reset to get fresh data (fresh query) usage_after = db.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period == current_period ).first() + # Refresh the usage object if found to ensure we have latest data + if usage_after: + db.refresh(usage_after) + if reset_result.get('reset'): logger.info(f" ā Usage counters RESET successfully") if usage_after: @@ -635,6 +656,7 @@ async def get_dashboard_data( """Get comprehensive dashboard data for usage monitoring.""" try: + ensure_subscription_plan_columns(db) # Serve from short TTL cache to avoid hammering DB on bursts import time now = time.time() diff --git a/backend/api/wix_routes.py b/backend/api/wix_routes.py index 9e35f25d..5f9ac837 100644 --- a/backend/api/wix_routes.py +++ b/backend/api/wix_routes.py @@ -535,15 +535,33 @@ async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]: if not member_id: raise HTTPException(status_code=400, detail="Unable to resolve member_id from token") + # Extract SEO metadata if provided + seo_metadata = payload.get("seo_metadata") + + # Extract category/tag IDs or names + # Can be either: + # - IDs: List of UUID strings + # - Names: List of name strings (will be looked up/created) + category_ids = payload.get("category_ids") or payload.get("category_names") + tag_ids = payload.get("tag_ids") or payload.get("tag_names") + + # If SEO metadata has categories/tags but they weren't explicitly provided, use them + if seo_metadata: + if not category_ids and seo_metadata.get("blog_categories"): + category_ids = seo_metadata.get("blog_categories") + if not tag_ids and seo_metadata.get("blog_tags"): + tag_ids = seo_metadata.get("blog_tags") + result = wix_service.create_blog_post( access_token=access_token, title=payload.get("title") or "Untitled", content=payload.get("content") or "", cover_image_url=payload.get("cover_image_url"), - category_ids=payload.get("category_ids") or None, - tag_ids=payload.get("tag_ids") or None, + category_ids=category_ids, + tag_ids=tag_ids, publish=bool(payload.get("publish", True)), member_id=member_id, + seo_metadata=seo_metadata, ) return { diff --git a/backend/app.py b/backend/app.py index e4a9cdca..80f673ea 100644 --- a/backend/app.py +++ b/backend/app.py @@ -11,6 +11,7 @@ import asyncio from datetime import datetime from services.subscription import monitoring_middleware + # Import modular utilities from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager, OnboardingManager diff --git a/backend/database/migrations/006_add_exa_provider.sql b/backend/database/migrations/006_add_exa_provider.sql new file mode 100644 index 00000000..47fe3e8c --- /dev/null +++ b/backend/database/migrations/006_add_exa_provider.sql @@ -0,0 +1,17 @@ +-- Add EXA to subscription plans +ALTER TABLE subscription_plans +ADD COLUMN exa_calls_limit INT DEFAULT 0; + +-- Add EXA to usage summaries +ALTER TABLE usage_summaries +ADD COLUMN exa_calls INT DEFAULT 0; + +ALTER TABLE usage_summaries +ADD COLUMN exa_cost FLOAT DEFAULT 0.0; + +-- Update default limits for existing plans +UPDATE subscription_plans SET exa_calls_limit = 100 WHERE tier = 'free'; +UPDATE subscription_plans SET exa_calls_limit = 500 WHERE tier = 'basic'; +UPDATE subscription_plans SET exa_calls_limit = 2000 WHERE tier = 'pro'; +UPDATE subscription_plans SET exa_calls_limit = 0 WHERE tier = 'enterprise'; + diff --git a/backend/models/blog_models.py b/backend/models/blog_models.py index 51a01113..01b9ea32 100644 --- a/backend/models/blog_models.py +++ b/backend/models/blog_models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field from typing import List, Optional, Dict, Any +from enum import Enum class PersonaInfo(BaseModel): @@ -50,6 +51,51 @@ class GroundingMetadata(BaseModel): web_search_queries: List[str] = [] +class ResearchMode(str, Enum): + """Research modes for different depth levels.""" + BASIC = "basic" + COMPREHENSIVE = "comprehensive" + TARGETED = "targeted" + + +class SourceType(str, Enum): + """Types of sources to include in research.""" + WEB = "web" + ACADEMIC = "academic" + NEWS = "news" + INDUSTRY = "industry" + EXPERT = "expert" + + +class DateRange(str, Enum): + """Date range filters for research.""" + LAST_WEEK = "last_week" + LAST_MONTH = "last_month" + LAST_3_MONTHS = "last_3_months" + LAST_6_MONTHS = "last_6_months" + LAST_YEAR = "last_year" + ALL_TIME = "all_time" + + +class ResearchProvider(str, Enum): + """Research provider options.""" + GOOGLE = "google" # Gemini native grounding + EXA = "exa" # Exa neural search + + +class ResearchConfig(BaseModel): + """Configuration for research execution.""" + mode: ResearchMode = ResearchMode.BASIC + provider: ResearchProvider = ResearchProvider.GOOGLE + date_range: Optional[DateRange] = None + source_types: List[SourceType] = [] + max_sources: int = 10 + include_statistics: bool = True + include_expert_quotes: bool = True + include_competitors: bool = True + include_trends: bool = True + + class BlogResearchRequest(BaseModel): keywords: List[str] topic: Optional[str] = None @@ -58,6 +104,8 @@ class BlogResearchRequest(BaseModel): tone: Optional[str] = None word_count_target: Optional[int] = 1500 persona: Optional[PersonaInfo] = None + research_mode: Optional[ResearchMode] = ResearchMode.BASIC + config: Optional[ResearchConfig] = None class BlogResearchResponse(BaseModel): diff --git a/backend/models/subscription_models.py b/backend/models/subscription_models.py index 0e770ca3..a6c73945 100644 --- a/backend/models/subscription_models.py +++ b/backend/models/subscription_models.py @@ -34,6 +34,7 @@ class APIProvider(enum.Enum): METAPHOR = "metaphor" FIRECRAWL = "firecrawl" STABILITY = "stability" + EXA = "exa" class BillingCycle(enum.Enum): MONTHLY = "monthly" @@ -66,6 +67,7 @@ class SubscriptionPlan(Base): metaphor_calls_limit = Column(Integer, default=0) firecrawl_calls_limit = Column(Integer, default=0) stability_calls_limit = Column(Integer, default=0) # Image generation + exa_calls_limit = Column(Integer, default=0) # Exa neural search # Token Limits (for LLM providers) gemini_tokens_limit = Column(Integer, default=0) @@ -182,6 +184,7 @@ class UsageSummary(Base): metaphor_calls = Column(Integer, default=0) firecrawl_calls = Column(Integer, default=0) stability_calls = Column(Integer, default=0) + exa_calls = Column(Integer, default=0) # Token Usage gemini_tokens = Column(Integer, default=0) @@ -199,6 +202,7 @@ class UsageSummary(Base): metaphor_cost = Column(Float, default=0.0) firecrawl_cost = Column(Float, default=0.0) stability_cost = Column(Float, default=0.0) + exa_cost = Column(Float, default=0.0) # Totals total_calls = Column(Integer, default=0) diff --git a/backend/requirements.txt b/backend/requirements.txt index edfa663f..76557032 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -38,6 +38,7 @@ aiohttp>=3.9.0 # Data processing pandas>=2.0.0 numpy>=1.24.0 +markdown>=3.5.0 # SEO Analysis dependencies advertools>=0.14.0 diff --git a/backend/scripts/update_basic_plan_limits.py b/backend/scripts/update_basic_plan_limits.py index e88350dc..3a5ec090 100644 --- a/backend/scripts/update_basic_plan_limits.py +++ b/backend/scripts/update_basic_plan_limits.py @@ -3,7 +3,7 @@ Script to update Basic plan subscription limits for testing rate limits and rene Updates: - LLM Calls (all providers): 10 calls (was 500-1000) -- LLM Tokens (all providers): 2000 tokens (was 200k-1M) +- LLM Tokens (all providers): 5000 tokens (increased from 2000 to support research workflow) - Images: 5 images (was 50) This script updates the SubscriptionPlan table, which automatically applies to all users @@ -69,11 +69,11 @@ def update_basic_plan_limits(): basic_plan.anthropic_calls_limit = 200 basic_plan.mistral_calls_limit = 500 - # Update all LLM provider token limits to 2000 - basic_plan.gemini_tokens_limit = 2000 - basic_plan.openai_tokens_limit = 2000 - basic_plan.anthropic_tokens_limit = 2000 - basic_plan.mistral_tokens_limit = 2000 + # Update all LLM provider token limits to 20000 (increased from 5000 for better stability) + basic_plan.gemini_tokens_limit = 20000 + basic_plan.openai_tokens_limit = 20000 + basic_plan.anthropic_tokens_limit = 20000 + basic_plan.mistral_tokens_limit = 20000 # Update image generation limit to 5 basic_plan.stability_calls_limit = 5 @@ -83,7 +83,7 @@ def update_basic_plan_limits(): logger.info("\nš New Basic plan limits:") logger.info(f" LLM Calls (all providers): 10") - logger.info(f" LLM Tokens (all providers): 2000") + logger.info(f" LLM Tokens (all providers): 20000 (increased from 5000)") logger.info(f" Images: 5") # Count and get affected users @@ -118,7 +118,7 @@ def update_basic_plan_limits(): # New limits - use unified AI text generation limit if available new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit - new_token_limit = basic_plan.gemini_tokens_limit # 2000 + new_token_limit = basic_plan.gemini_tokens_limit # 5000 (increased from 2000) new_image_limit = basic_plan.stability_calls_limit # 5 for sub in user_subscriptions: @@ -253,7 +253,7 @@ if __name__ == "__main__": logger.info("="*60) logger.info("This will update Basic plan limits for testing rate limits:") logger.info(" - LLM Calls: 10 (all providers)") - logger.info(" - LLM Tokens: 2000 (all providers)") + logger.info(" - LLM Tokens: 20000 (all providers, increased from 5000)") logger.info(" - Images: 5") logger.info("="*60) diff --git a/backend/services/blog_writer/content/medium_blog_generator.py b/backend/services/blog_writer/content/medium_blog_generator.py index df243c4b..00231c0d 100644 --- a/backend/services/blog_writer/content/medium_blog_generator.py +++ b/backend/services/blog_writer/content/medium_blog_generator.py @@ -8,6 +8,7 @@ import time import json from typing import Dict, Any, List from loguru import logger +from fastapi import HTTPException from models.blog_models import ( MediumBlogGenerateRequest, @@ -25,8 +26,20 @@ class MediumBlogGenerator: def __init__(self): self.cache = persistent_content_cache - async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult: - """Use Gemini structured JSON to generate a medium-length blog in one call.""" + async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult: + """Use Gemini structured JSON to generate a medium-length blog in one call. + + Args: + req: Medium blog generation request + task_id: Task ID for progress updates + user_id: User ID (required for subscription checks and usage tracking) + + Raises: + ValueError: If user_id is not provided + """ + if not user_id: + raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)") + import time start = time.time() @@ -156,7 +169,7 @@ class MediumBlogGenerator: - Use language that resonates with {audience} - Maintain consistent voice that reflects this persona's expertise """ - + prompt = ( f"Write blog content for the following sections. Each section should be {req.globalTargetWords or 1000} words total, distributed across all sections.\n\n" f"Blog Title: {req.title}\n\n" @@ -176,11 +189,20 @@ class MediumBlogGenerator: f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}" ) - ai_resp = llm_text_gen( - prompt=prompt, - json_struct=schema, - system_prompt=system, - ) + try: + ai_resp = llm_text_gen( + prompt=prompt, + json_struct=schema, + system_prompt=system, + user_id=user_id + ) + except HTTPException: + # Re-raise HTTPExceptions (e.g., 429 subscription limit) to preserve error details + raise + except Exception as llm_error: + # Wrap other errors + logger.error(f"AI generation failed: {llm_error}") + raise Exception(f"AI generation failed: {str(llm_error)}") # Check for errors in AI response if not ai_resp or ai_resp.get("error"): diff --git a/backend/services/blog_writer/core/blog_writer_service.py b/backend/services/blog_writer/core/blog_writer_service.py index b8324420..89b7c919 100644 --- a/backend/services/blog_writer/core/blog_writer_service.py +++ b/backend/services/blog_writer/core/blog_writer_service.py @@ -105,13 +105,20 @@ class BlogWriterService: return await self.research_service.research_with_progress(request, task_id, user_id) # Outline Methods - async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: - """Generate AI-powered outline from research data.""" - return await self.outline_service.generate_outline(request) + async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse: + """Generate AI-powered outline from research data. + + Args: + request: Outline generation request with research data + user_id: User ID (required for subscription checks and usage tracking) + """ + if not user_id: + raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)") + return await self.outline_service.generate_outline(request, user_id) - async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: + async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse: """Generate outline with real-time progress updates.""" - return await self.outline_service.generate_outline_with_progress(request, task_id) + return await self.outline_service.generate_outline_with_progress(request, task_id, user_id) async def refine_outline(self, request: BlogOutlineRefineRequest) -> BlogOutlineResponse: """Refine outline with HITL operations.""" @@ -334,9 +341,17 @@ class BlogWriterService: # TODO: Move to content module return BlogPublishResponse(success=True, platform=request.platform, url="https://example.com/post") - async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult: - """Use Gemini structured JSON to generate a medium-length blog in one call.""" - return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id) + async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult: + """Use Gemini structured JSON to generate a medium-length blog in one call. + + Args: + req: Medium blog generation request + task_id: Task ID for progress updates + user_id: User ID (required for subscription checks and usage tracking) + """ + if not user_id: + raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)") + return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id, user_id) async def analyze_flow_basic(self, request: Dict[str, Any]) -> Dict[str, Any]: """Analyze flow metrics for entire blog using single AI call (cost-effective).""" diff --git a/backend/services/blog_writer/outline/outline_generator.py b/backend/services/blog_writer/outline/outline_generator.py index a99ed5c3..40bfe0ae 100644 --- a/backend/services/blog_writer/outline/outline_generator.py +++ b/backend/services/blog_writer/outline/outline_generator.py @@ -42,10 +42,20 @@ class OutlineGenerator: self.response_processor = ResponseProcessor() self.parallel_processor = ParallelProcessor(self.source_mapper, self.grounding_engine) - async def generate(self, request: BlogOutlineRequest) -> BlogOutlineResponse: + async def generate(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse: """ - Generate AI-powered outline using research results + Generate AI-powered outline using research results. + + Args: + request: Outline generation request with research data + user_id: User ID (required for subscription checks and usage tracking) + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)") + # Extract research insights research = request.research primary_keywords = research.keyword_analysis.get('primary', []) @@ -68,15 +78,15 @@ class OutlineGenerator: # Define schema with proper property ordering (critical for Gemini API) outline_schema = self.prompt_builder.get_outline_schema() - # Generate outline using structured JSON response with retry logic - outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema) + # Generate outline using structured JSON response with retry logic (user_id required) + outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id) # Convert to BlogOutlineSection objects outline_sections = self.response_processor.convert_to_sections(outline_data, sources) - # Run parallel processing for speed optimization + # Run parallel processing for speed optimization (user_id required) mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing_async( - outline_sections, research + outline_sections, research, user_id ) # Enhance sections with grounding insights @@ -85,9 +95,9 @@ class OutlineGenerator: mapped_sections, research.grounding_metadata, grounding_insights ) - # Optimize outline for better flow, SEO, and engagement + # Optimize outline for better flow, SEO, and engagement (user_id required) logger.info("Optimizing outline for better flow and engagement...") - optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization") + optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id) # Rebalance word counts for optimal distribution target_words = request.word_count or 1500 @@ -118,10 +128,21 @@ class OutlineGenerator: research_coverage=research_coverage ) - async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: + async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse: """ Outline generation method with progress updates for real-time feedback. + + Args: + request: Outline generation request with research data + task_id: Task ID for progress updates + user_id: User ID (required for subscription checks and usage tracking) + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)") + from api.blog_writer.task_manager import task_manager # Extract research insights @@ -150,17 +171,17 @@ class OutlineGenerator: await task_manager.update_progress(task_id, "š Making AI request to generate structured outline...") - # Generate outline using structured JSON response with retry logic - outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, task_id) + # Generate outline using structured JSON response with retry logic (user_id required for subscription checks) + outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id, task_id) await task_manager.update_progress(task_id, "š Processing outline structure and validating sections...") # Convert to BlogOutlineSection objects outline_sections = self.response_processor.convert_to_sections(outline_data, sources) - # Run parallel processing for speed optimization + # Run parallel processing for speed optimization (user_id required for subscription checks) mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing( - outline_sections, research, task_id + outline_sections, research, user_id, task_id ) # Enhance sections with grounding insights (depends on both previous tasks) @@ -169,9 +190,9 @@ class OutlineGenerator: mapped_sections, research.grounding_metadata, grounding_insights ) - # Optimize outline for better flow, SEO, and engagement + # Optimize outline for better flow, SEO, and engagement (user_id required for subscription checks) await task_manager.update_progress(task_id, "šÆ Optimizing outline for better flow and engagement...") - optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization") + optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id) # Rebalance word counts for optimal distribution await task_manager.update_progress(task_id, "āļø Rebalancing word count distribution...") diff --git a/backend/services/blog_writer/outline/outline_optimizer.py b/backend/services/blog_writer/outline/outline_optimizer.py index e5c4c0fb..8fa36e2a 100644 --- a/backend/services/blog_writer/outline/outline_optimizer.py +++ b/backend/services/blog_writer/outline/outline_optimizer.py @@ -13,8 +13,23 @@ from models.blog_models import BlogOutlineSection class OutlineOptimizer: """Optimizes outlines for better flow, SEO, and engagement.""" - async def optimize(self, outline: List[BlogOutlineSection], focus: str = "general optimization") -> List[BlogOutlineSection]: - """Optimize entire outline for better flow, SEO, and engagement.""" + async def optimize(self, outline: List[BlogOutlineSection], focus: str, user_id: str) -> List[BlogOutlineSection]: + """Optimize entire outline for better flow, SEO, and engagement. + + Args: + outline: List of outline sections to optimize + focus: Optimization focus (e.g., "general optimization") + user_id: User ID (required for subscription checks and usage tracking) + + Returns: + List of optimized outline sections + + Raises: + ValueError: If user_id is not provided + """ + if not user_id: + raise ValueError("user_id is required for outline optimization (subscription checks and usage tracking)") + outline_text = "\n".join([f"{i+1}. {s.heading}" for i, s in enumerate(outline)]) optimization_prompt = f"""Optimize this blog outline for better flow, engagement, and SEO: @@ -67,7 +82,8 @@ Return JSON format: optimized_data = llm_text_gen( prompt=optimization_prompt, json_struct=optimization_schema, - system_prompt=None + system_prompt=None, + user_id=user_id ) # Handle the new schema format with "outline" wrapper diff --git a/backend/services/blog_writer/outline/outline_service.py b/backend/services/blog_writer/outline/outline_service.py index 2dc6d506..1118b3de 100644 --- a/backend/services/blog_writer/outline/outline_service.py +++ b/backend/services/blog_writer/outline/outline_service.py @@ -29,11 +29,21 @@ class OutlineService: self.outline_optimizer = OutlineOptimizer() self.section_enhancer = SectionEnhancer() - async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: + async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse: """ - Stage 2: Content Planning with AI-generated outline using research results - Uses Gemini with research data to create comprehensive, SEO-optimized outline + Stage 2: Content Planning with AI-generated outline using research results. + Uses Gemini with research data to create comprehensive, SEO-optimized outline. + + Args: + request: Outline generation request with research data + user_id: User ID (required for subscription checks and usage tracking) + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)") + # Extract cache parameters - use original user keywords for consistent caching keywords = request.research.original_keywords or request.research.keyword_analysis.get('primary', []) industry = getattr(request.persona, 'industry', 'general') if request.persona else 'general' @@ -56,9 +66,9 @@ class OutlineService: logger.info(f"Using cached outline for keywords: {keywords}") return BlogOutlineResponse(**cached_result) - # Generate new outline if not cached + # Generate new outline if not cached (user_id required) logger.info(f"Generating new outline for keywords: {keywords}") - result = await self.outline_generator.generate(request) + result = await self.outline_generator.generate(request, user_id) # Cache the result persistent_outline_cache.cache_outline( @@ -73,7 +83,7 @@ class OutlineService: return result - async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: + async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse: """ Outline generation method with progress updates for real-time feedback. """ @@ -104,7 +114,7 @@ class OutlineService: # Generate new outline if not cached logger.info(f"Generating new outline for keywords: {keywords} (with progress updates)") - result = await self.outline_generator.generate_with_progress(request, task_id) + result = await self.outline_generator.generate_with_progress(request, task_id, user_id) # Cache the result persistent_outline_cache.cache_outline( diff --git a/backend/services/blog_writer/outline/parallel_processor.py b/backend/services/blog_writer/outline/parallel_processor.py index 395f8dda..61e066ed 100644 --- a/backend/services/blog_writer/outline/parallel_processor.py +++ b/backend/services/blog_writer/outline/parallel_processor.py @@ -17,18 +17,25 @@ class ParallelProcessor: self.source_mapper = source_mapper self.grounding_engine = grounding_engine - async def run_parallel_processing(self, outline_sections, research, task_id: str = None) -> Tuple[Any, Any]: + async def run_parallel_processing(self, outline_sections, research, user_id: str, task_id: str = None) -> Tuple[Any, Any]: """ Run source mapping and grounding insights extraction in parallel. Args: outline_sections: List of outline sections to process research: Research data object + user_id: User ID (required for subscription checks and usage tracking) task_id: Optional task ID for progress updates Returns: Tuple of (mapped_sections, grounding_insights) + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)") + if task_id: from api.blog_writer.task_manager import task_manager await task_manager.update_progress(task_id, "ā” Running parallel processing for maximum speed...") @@ -37,7 +44,7 @@ class ParallelProcessor: # Run these tasks in parallel to save time source_mapping_task = asyncio.create_task( - self._run_source_mapping(outline_sections, research, task_id) + self._run_source_mapping(outline_sections, research, task_id, user_id) ) grounding_insights_task = asyncio.create_task( @@ -52,22 +59,29 @@ class ParallelProcessor: return mapped_sections, grounding_insights - async def run_parallel_processing_async(self, outline_sections, research) -> Tuple[Any, Any]: + async def run_parallel_processing_async(self, outline_sections, research, user_id: str) -> Tuple[Any, Any]: """ Run parallel processing without progress updates (for non-progress methods). Args: outline_sections: List of outline sections to process research: Research data object + user_id: User ID (required for subscription checks and usage tracking) Returns: Tuple of (mapped_sections, grounding_insights) + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)") + logger.info("Running parallel processing for maximum speed...") # Run these tasks in parallel to save time source_mapping_task = asyncio.create_task( - self._run_source_mapping_async(outline_sections, research) + self._run_source_mapping_async(outline_sections, research, user_id) ) grounding_insights_task = asyncio.create_task( @@ -82,12 +96,12 @@ class ParallelProcessor: return mapped_sections, grounding_insights - async def _run_source_mapping(self, outline_sections, research, task_id): + async def _run_source_mapping(self, outline_sections, research, task_id, user_id: str): """Run source mapping in parallel.""" if task_id: from api.blog_writer.task_manager import task_manager await task_manager.update_progress(task_id, "š Applying intelligent source-to-section mapping...") - return self.source_mapper.map_sources_to_sections(outline_sections, research) + return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id) async def _run_grounding_insights_extraction(self, research, task_id): """Run grounding insights extraction in parallel.""" @@ -96,10 +110,10 @@ class ParallelProcessor: await task_manager.update_progress(task_id, "š§ Extracting grounding metadata insights...") return self.grounding_engine.extract_contextual_insights(research.grounding_metadata) - async def _run_source_mapping_async(self, outline_sections, research): + async def _run_source_mapping_async(self, outline_sections, research, user_id: str): """Run source mapping in parallel (async version without progress updates).""" logger.info("Applying intelligent source-to-section mapping...") - return self.source_mapper.map_sources_to_sections(outline_sections, research) + return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id) async def _run_grounding_insights_extraction_async(self, research): """Run grounding insights extraction in parallel (async version without progress updates).""" diff --git a/backend/services/blog_writer/outline/response_processor.py b/backend/services/blog_writer/outline/response_processor.py index c8d8479c..826d3509 100644 --- a/backend/services/blog_writer/outline/response_processor.py +++ b/backend/services/blog_writer/outline/response_processor.py @@ -18,8 +18,21 @@ class ResponseProcessor: """Initialize the response processor.""" pass - async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], task_id: str = None) -> Dict[str, Any]: - """Generate outline with retry logic for API failures.""" + async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], user_id: str, task_id: str = None) -> Dict[str, Any]: + """Generate outline with retry logic for API failures. + + Args: + prompt: The prompt for outline generation + schema: JSON schema for structured response + user_id: User ID (required for subscription checks and usage tracking) + task_id: Optional task ID for progress updates + + Raises: + ValueError: If user_id is not provided + """ + if not user_id: + raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)") + from services.llm_providers.main_text_generation import llm_text_gen from api.blog_writer.task_manager import task_manager @@ -34,7 +47,8 @@ class ResponseProcessor: outline_data = llm_text_gen( prompt=prompt, json_struct=schema, - system_prompt=None + system_prompt=None, + user_id=user_id ) # Log response for debugging diff --git a/backend/services/blog_writer/outline/section_enhancer.py b/backend/services/blog_writer/outline/section_enhancer.py index 936576bc..8cd47890 100644 --- a/backend/services/blog_writer/outline/section_enhancer.py +++ b/backend/services/blog_writer/outline/section_enhancer.py @@ -12,8 +12,23 @@ from models.blog_models import BlogOutlineSection class SectionEnhancer: """Enhances individual outline sections using AI.""" - async def enhance(self, section: BlogOutlineSection, focus: str = "general improvement") -> BlogOutlineSection: - """Enhance a section using AI with research context.""" + async def enhance(self, section: BlogOutlineSection, focus: str, user_id: str) -> BlogOutlineSection: + """Enhance a section using AI with research context. + + Args: + section: Outline section to enhance + focus: Enhancement focus (e.g., "general improvement") + user_id: User ID (required for subscription checks and usage tracking) + + Returns: + Enhanced outline section + + Raises: + ValueError: If user_id is not provided + """ + if not user_id: + raise ValueError("user_id is required for section enhancement (subscription checks and usage tracking)") + enhancement_prompt = f""" Enhance the following blog section to make it more engaging, comprehensive, and valuable: @@ -61,7 +76,8 @@ class SectionEnhancer: enhanced_data = llm_text_gen( prompt=enhancement_prompt, json_struct=enhancement_schema, - system_prompt=None + system_prompt=None, + user_id=user_id ) if isinstance(enhanced_data, dict) and 'error' not in enhanced_data: diff --git a/backend/services/blog_writer/outline/source_mapper.py b/backend/services/blog_writer/outline/source_mapper.py index d630b22b..8ff766be 100644 --- a/backend/services/blog_writer/outline/source_mapper.py +++ b/backend/services/blog_writer/outline/source_mapper.py @@ -52,7 +52,8 @@ class SourceToSectionMapper: def map_sources_to_sections( self, sections: List[BlogOutlineSection], - research_data: BlogResearchResponse + research_data: BlogResearchResponse, + user_id: str ) -> List[BlogOutlineSection]: """ Map research sources to outline sections using intelligent algorithms. @@ -60,10 +61,17 @@ class SourceToSectionMapper: Args: sections: List of outline sections to map sources to research_data: Research data containing sources and metadata + user_id: User ID (required for subscription checks and usage tracking) Returns: List of outline sections with intelligently mapped sources + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for source mapping (subscription checks and usage tracking)") + if not sections or not research_data.sources: logger.warning("No sections or sources to map") return sections @@ -73,8 +81,8 @@ class SourceToSectionMapper: # Step 1: Algorithmic mapping mapping_results = self._algorithmic_source_mapping(sections, research_data) - # Step 2: AI validation and improvement (single prompt) - validated_mapping = self._ai_validate_mapping(mapping_results, research_data) + # Step 2: AI validation and improvement (single prompt, user_id required for subscription checks) + validated_mapping = self._ai_validate_mapping(mapping_results, research_data, user_id) # Step 3: Apply validated mapping to sections mapped_sections = self._apply_mapping_to_sections(sections, validated_mapping) @@ -261,7 +269,8 @@ class SourceToSectionMapper: def _ai_validate_mapping( self, mapping_results: Dict[str, List[Tuple[ResearchSource, float]]], - research_data: BlogResearchResponse + research_data: BlogResearchResponse, + user_id: str ) -> Dict[str, List[Tuple[ResearchSource, float]]]: """ Use AI to validate and improve the algorithmic mapping results. @@ -269,18 +278,25 @@ class SourceToSectionMapper: Args: mapping_results: Algorithmic mapping results research_data: Research data for context + user_id: User ID (required for subscription checks and usage tracking) Returns: AI-validated and improved mapping results + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for AI validation (subscription checks and usage tracking)") + try: logger.info("Starting AI validation of source-to-section mapping...") # Build AI validation prompt validation_prompt = self._build_validation_prompt(mapping_results, research_data) - # Get AI validation response - validation_response = self._get_ai_validation_response(validation_prompt) + # Get AI validation response (user_id required for subscription checks) + validation_response = self._get_ai_validation_response(validation_prompt, user_id) # Parse and apply AI validation results validated_mapping = self._parse_validation_response(validation_response, mapping_results, research_data) @@ -548,23 +564,31 @@ Analyze the mapping and provide your recommendations. return prompt - def _get_ai_validation_response(self, prompt: str) -> str: + def _get_ai_validation_response(self, prompt: str, user_id: str) -> str: """ Get AI validation response using LLM provider. Args: prompt: Validation prompt + user_id: User ID (required for subscription checks and usage tracking) Returns: AI validation response + + Raises: + ValueError: If user_id is not provided """ + if not user_id: + raise ValueError("user_id is required for AI validation response (subscription checks and usage tracking)") + try: from services.llm_providers.main_text_generation import llm_text_gen response = llm_text_gen( prompt=prompt, json_struct=None, - system_prompt=None + system_prompt=None, + user_id=user_id ) return response diff --git a/backend/services/blog_writer/research/__init__.py b/backend/services/blog_writer/research/__init__.py index ecf76b11..87035e12 100644 --- a/backend/services/blog_writer/research/__init__.py +++ b/backend/services/blog_writer/research/__init__.py @@ -13,11 +13,17 @@ from .keyword_analyzer import KeywordAnalyzer from .competitor_analyzer import CompetitorAnalyzer from .content_angle_generator import ContentAngleGenerator from .data_filter import ResearchDataFilter +from .base_provider import ResearchProvider as BaseResearchProvider +from .google_provider import GoogleResearchProvider +from .exa_provider import ExaResearchProvider __all__ = [ 'ResearchService', 'KeywordAnalyzer', 'CompetitorAnalyzer', 'ContentAngleGenerator', - 'ResearchDataFilter' + 'ResearchDataFilter', + 'BaseResearchProvider', + 'GoogleResearchProvider', + 'ExaResearchProvider', ] diff --git a/backend/services/blog_writer/research/base_provider.py b/backend/services/blog_writer/research/base_provider.py new file mode 100644 index 00000000..72aae9b2 --- /dev/null +++ b/backend/services/blog_writer/research/base_provider.py @@ -0,0 +1,37 @@ +""" +Base Research Provider Interface + +Abstract base class for research provider implementations. +Ensures consistency across different research providers (Google, Exa, etc.) +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any + + +class ResearchProvider(ABC): + """Abstract base class for research providers.""" + + @abstractmethod + async def search( + self, + prompt: str, + topic: str, + industry: str, + target_audience: str, + config: Any, # ResearchConfig + user_id: str + ) -> Dict[str, Any]: + """Execute research and return raw results.""" + pass + + @abstractmethod + def get_provider_enum(self): + """Return APIProvider enum for subscription tracking.""" + pass + + @abstractmethod + def estimate_tokens(self) -> int: + """Estimate token usage for pre-flight validation.""" + pass + diff --git a/backend/services/blog_writer/research/exa_provider.py b/backend/services/blog_writer/research/exa_provider.py new file mode 100644 index 00000000..6c20a0d6 --- /dev/null +++ b/backend/services/blog_writer/research/exa_provider.py @@ -0,0 +1,188 @@ +""" +Exa Research Provider + +Neural search implementation using Exa API for high-quality, citation-rich research. +""" + +from exa_py import Exa +import os +from loguru import logger +from models.subscription_models import APIProvider +from .base_provider import ResearchProvider as BaseProvider + + +class ExaResearchProvider(BaseProvider): + """Exa neural search provider.""" + + def __init__(self): + self.api_key = os.getenv("EXA_API_KEY") + if not self.api_key: + raise RuntimeError("EXA_API_KEY not configured") + self.exa = Exa(self.api_key) + logger.info("ā Exa Research Provider initialized") + + async def search(self, prompt, topic, industry, target_audience, config, user_id): + """Execute Exa neural search and return standardized results.""" + # Build Exa query + query = f"{topic} {industry} {target_audience}" + + # Map source types to Exa categories + category = self._map_source_type_to_category(config.source_types) + + logger.info(f"[Exa Research] Executing search: {query}") + + # Execute Exa search + results = self.exa.search_and_contents( + query, + type="auto", + category=category, + num_results=min(config.max_sources, 25), + contents={ + 'text': {'max_characters': 1000}, + 'summary': {'query': f"Key insights about {topic}"}, + 'highlights': { + 'num_sentences': 2, + 'highlights_per_url': 3 + } + } + ) + + # Transform to standardized format + sources = self._transform_sources(results.results) + content = self._aggregate_content(results.results) + search_type = getattr(results, 'resolvedSearchType', 'neural') if hasattr(results, 'resolvedSearchType') else 'neural' + + # Get cost if available + cost = 0.005 # Default Exa cost for 1-25 results + if hasattr(results, 'costDollars'): + if hasattr(results.costDollars, 'total'): + cost = results.costDollars.total + + logger.info(f"[Exa Research] Search completed: {len(sources)} sources, type: {search_type}") + + return { + 'sources': sources, + 'content': content, + 'search_type': search_type, + 'provider': 'exa', + 'search_queries': [query], + 'cost': {'total': cost} + } + + def get_provider_enum(self): + """Return EXA provider enum for subscription tracking.""" + return APIProvider.EXA + + def estimate_tokens(self) -> int: + """Estimate token usage for Exa (not token-based).""" + return 0 # Exa is per-search, not token-based + + def _map_source_type_to_category(self, source_types): + """Map SourceType enum to Exa category parameter.""" + if not source_types: + return None + + category_map = { + 'research paper': 'research paper', + 'news': 'news', + 'web': 'personal site', + 'industry': 'company', + 'expert': 'linkedin profile' + } + + for st in source_types: + if st.value in category_map: + return category_map[st.value] + + return None + + def _transform_sources(self, results): + """Transform Exa results to ResearchSource format.""" + sources = [] + for idx, result in enumerate(results): + source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '') + + sources.append({ + 'title': result.title if hasattr(result, 'title') else '', + 'url': result.url if hasattr(result, 'url') else '', + 'excerpt': self._get_excerpt(result), + 'credibility_score': 0.85, # Exa results are high quality + 'published_at': result.publishedDate if hasattr(result, 'publishedDate') else None, + 'index': idx, + 'source_type': source_type, + 'content': result.text if hasattr(result, 'text') else '', + 'highlights': result.highlights if hasattr(result, 'highlights') else [], + 'summary': result.summary if hasattr(result, 'summary') else '' + }) + + return sources + + def _get_excerpt(self, result): + """Extract excerpt from Exa result.""" + if hasattr(result, 'text') and result.text: + return result.text[:500] + elif hasattr(result, 'summary') and result.summary: + return result.summary + return '' + + def _determine_source_type(self, url): + """Determine source type from URL.""" + if not url: + return 'web' + + url_lower = url.lower() + if 'arxiv.org' in url_lower or 'research' in url_lower: + return 'academic' + elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com']): + return 'news' + elif 'linkedin.com' in url_lower: + return 'expert' + else: + return 'web' + + def _aggregate_content(self, results): + """Aggregate content from Exa results for LLM analysis.""" + content_parts = [] + + for idx, result in enumerate(results): + if hasattr(result, 'summary') and result.summary: + content_parts.append(f"Source {idx + 1}: {result.summary}") + elif hasattr(result, 'text') and result.text: + content_parts.append(f"Source {idx + 1}: {result.text[:1000]}") + + return "\n\n".join(content_parts) + + def track_exa_usage(self, user_id: str, cost: float): + """Track Exa API usage after successful call.""" + from services.database import get_db + from services.subscription import PricingService + from sqlalchemy import text + + db = next(get_db()) + try: + pricing_service = PricingService(db) + current_period = pricing_service.get_current_billing_period(user_id) + + # Update exa_calls and exa_cost via SQL UPDATE + update_query = text(""" + UPDATE usage_summaries + SET exa_calls = COALESCE(exa_calls, 0) + 1, + exa_cost = COALESCE(exa_cost, 0) + :cost, + total_calls = total_calls + 1, + total_cost = total_cost + :cost + WHERE user_id = :user_id AND billing_period = :period + """) + db.execute(update_query, { + 'cost': cost, + 'user_id': user_id, + 'period': current_period + }) + db.commit() + + logger.info(f"[Exa] Tracked usage: user={user_id}, cost=${cost}") + except Exception as e: + logger.error(f"[Exa] Failed to track usage: {e}") + db.rollback() + finally: + db.close() + diff --git a/backend/services/blog_writer/research/google_provider.py b/backend/services/blog_writer/research/google_provider.py new file mode 100644 index 00000000..b0aa06d8 --- /dev/null +++ b/backend/services/blog_writer/research/google_provider.py @@ -0,0 +1,40 @@ +""" +Google Research Provider + +Wrapper for Gemini native Google Search grounding to match base provider interface. +""" + +from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider +from models.subscription_models import APIProvider +from .base_provider import ResearchProvider as BaseProvider +from loguru import logger + + +class GoogleResearchProvider(BaseProvider): + """Google research provider using Gemini native grounding.""" + + def __init__(self): + self.gemini = GeminiGroundedProvider() + + async def search(self, prompt, topic, industry, target_audience, config, user_id): + """Call Gemini grounding with pre-flight validation.""" + logger.info(f"[Google Research] Executing search for topic: {topic}") + + result = await self.gemini.generate_grounded_content( + prompt=prompt, + content_type="research", + max_tokens=2000, + user_id=user_id, + validate_subsequent_operations=True + ) + + return result + + def get_provider_enum(self): + """Return GEMINI provider enum for subscription tracking.""" + return APIProvider.GEMINI + + def estimate_tokens(self) -> int: + """Estimate token usage for Google grounding.""" + return 1200 # Conservative estimate + diff --git a/backend/services/blog_writer/research/research_service.py b/backend/services/blog_writer/research/research_service.py index f7c0e2d4..533fab13 100644 --- a/backend/services/blog_writer/research/research_service.py +++ b/backend/services/blog_writer/research/research_service.py @@ -16,6 +16,9 @@ from models.blog_models import ( GroundingChunk, GroundingSupport, Citation, + ResearchConfig, + ResearchMode, + ResearchProvider, ) from services.blog_writer.logger_config import blog_writer_logger, log_function_call from fastapi import HTTPException @@ -24,6 +27,7 @@ from .keyword_analyzer import KeywordAnalyzer from .competitor_analyzer import CompetitorAnalyzer from .content_angle_generator import ContentAngleGenerator from .data_filter import ResearchDataFilter +from .research_strategies import get_strategy_for_mode class ResearchService: @@ -44,7 +48,6 @@ class ResearchService: Includes intelligent caching for exact keyword matches. """ try: - from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider from services.cache.research_cache import research_cache topic = request.topic or ", ".join(request.keywords) @@ -79,62 +82,104 @@ class ResearchService: # Cache miss - proceed with API call logger.info(f"Cache miss - making API call for keywords: {request.keywords}") - blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research") - gemini = GeminiGroundedProvider() + blog_writer_logger.log_operation_start("research_api_call", api_name="research", operation="research") - # Single comprehensive research prompt - Gemini handles Google Search automatically - research_prompt = f""" - Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including: - - 1. Current trends and insights (2024-2025) - 2. Key statistics and data points with sources - 3. Industry expert opinions and quotes - 4. Recent developments and news - 5. Market analysis and forecasts - 6. Best practices and case studies - 7. Keyword analysis: primary, secondary, and long-tail opportunities - 8. Competitor analysis: top players and content gaps - 9. Content angle suggestions: 5 compelling angles for blog posts - - Focus on factual, up-to-date information from credible sources. - Include specific data points, percentages, and recent developments. - Structure your response with clear sections for each analysis area. - """ + # Determine research mode and get appropriate strategy + research_mode = request.research_mode or ResearchMode.BASIC + config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE) + strategy = get_strategy_for_mode(research_mode) - # Single Gemini call with native Google Search grounding - no fallbacks - # Validation is handled inside generate_grounded_content when validate_subsequent_operations=True - import time - api_start_time = time.time() - gemini_result = await gemini.generate_grounded_content( - prompt=research_prompt, - content_type="research", - max_tokens=2000, - user_id=user_id, - validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls - ) - api_duration_ms = (time.time() - api_start_time) * 1000 + logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}") - # Log API call performance - blog_writer_logger.log_api_call( - "gemini_grounded", - "generate_grounded_content", - api_duration_ms, - token_usage=gemini_result.get("token_usage", {}), - content_length=len(gemini_result.get("content", "")) - ) + # Build research prompt based on strategy + research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config) - # Extract sources from grounding metadata - sources = self._extract_sources_from_grounding(gemini_result) + # Route to appropriate provider + if config.provider == ResearchProvider.EXA: + # Exa research workflow + from .exa_provider import ExaResearchProvider + from services.subscription.preflight_validator import validate_exa_research_operations + from services.database import get_db + from services.subscription import PricingService + import os + import time + + # Pre-flight validation + db_val = next(get_db()) + try: + pricing_service = PricingService(db_val) + gpt_provider = os.getenv("GPT_PROVIDER", "google") + validate_exa_research_operations(pricing_service, user_id, gpt_provider) + finally: + db_val.close() + + # Execute Exa search + api_start_time = time.time() + try: + exa_provider = ExaResearchProvider() + raw_result = await exa_provider.search( + research_prompt, topic, industry, target_audience, config, user_id + ) + api_duration_ms = (time.time() - api_start_time) * 1000 + + # Track usage + cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005 + exa_provider.track_exa_usage(user_id, cost) + + # Log API call performance + blog_writer_logger.log_api_call( + "exa_search", + "search_and_contents", + api_duration_ms, + token_usage={}, + content_length=len(raw_result.get('content', '')) + ) + + # Extract content for downstream analysis + content = raw_result.get('content', '') + sources = raw_result.get('sources', []) + search_widget = "" # Exa doesn't provide search widgets + search_queries = raw_result.get('search_queries', []) + grounding_metadata = None # Exa doesn't provide grounding metadata + + except RuntimeError as e: + if "EXA_API_KEY not configured" in str(e): + logger.warning("Exa not configured, falling back to Google") + config.provider = ResearchProvider.GOOGLE + # Continue to Google flow below + raw_result = None + else: + raise + + if config.provider != ResearchProvider.EXA: + # Google research (existing flow) or fallback from Exa + from .google_provider import GoogleResearchProvider + import time + + api_start_time = time.time() + google_provider = GoogleResearchProvider() + gemini_result = await google_provider.search( + research_prompt, topic, industry, target_audience, config, user_id + ) + api_duration_ms = (time.time() - api_start_time) * 1000 + + # Log API call performance + blog_writer_logger.log_api_call( + "gemini_grounded", + "generate_grounded_content", + api_duration_ms, + token_usage=gemini_result.get("token_usage", {}), + content_length=len(gemini_result.get("content", "")) + ) + + # Extract sources and content + sources = self._extract_sources_from_grounding(gemini_result) + content = gemini_result.get("content", "") + search_widget = gemini_result.get("search_widget", "") or "" + search_queries = gemini_result.get("search_queries", []) or [] + grounding_metadata = self._extract_grounding_metadata(gemini_result) - # Extract grounding metadata for detailed UI display - grounding_metadata = self._extract_grounding_metadata(gemini_result) - - # Extract search widget and queries for UI display - search_widget = gemini_result.get("search_widget", "") or "" - search_queries = gemini_result.get("search_queries", []) or [] - - # Parse the comprehensive response for different analysis components - content = gemini_result.get("content", "") + # Continue with common analysis (same for both providers) keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) @@ -261,7 +306,6 @@ class ResearchService: Research method with progress updates for real-time feedback. """ try: - from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider from services.cache.research_cache import research_cache from services.cache.persistent_research_cache import persistent_research_cache from api.blog_writer.task_manager import task_manager @@ -293,66 +337,100 @@ class ResearchService: logger.info(f"Returning cached research result for keywords: {request.keywords}") return BlogResearchResponse(**cached_result) - # User ID validation (validation logic is now in Google Grounding provider) + # User ID validation if not user_id: await task_manager.update_progress(task_id, "ā Error: User ID is required for research operation") raise ValueError("user_id is required for research operation. Please provide Clerk user ID.") - # Cache miss - proceed with API call - await task_manager.update_progress(task_id, "š Cache miss - connecting to Google Search grounding...") - logger.info(f"Cache miss - making API call for keywords: {request.keywords}") - gemini = GeminiGroundedProvider() - - # Single comprehensive research prompt - Gemini handles Google Search automatically - research_prompt = f""" - Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including: - - 1. Current trends and insights (2024-2025) - 2. Key statistics and data points with sources - 3. Industry expert opinions and quotes - 4. Recent developments and news - 5. Market analysis and forecasts - 6. Best practices and case studies - 7. Keyword analysis: primary, secondary, and long-tail opportunities - 8. Competitor analysis: top players and content gaps - 9. Content angle suggestions: 5 compelling angles for blog posts - - Focus on factual, up-to-date information from credible sources. - Include specific data points, percentages, and recent developments. - Structure your response with clear sections for each analysis area. - """ + # Determine research mode and get appropriate strategy + research_mode = request.research_mode or ResearchMode.BASIC + config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE) + strategy = get_strategy_for_mode(research_mode) - await task_manager.update_progress(task_id, "š¤ Making AI request to Gemini with Google Search grounding...") - # Single Gemini call with native Google Search grounding - no fallbacks - # Validation is handled inside generate_grounded_content when validate_subsequent_operations=True - try: - gemini_result = await gemini.generate_grounded_content( - prompt=research_prompt, - content_type="research", - max_tokens=2000, - user_id=user_id, - validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls - ) - except HTTPException as http_error: - # Re-raise HTTPException so it can be properly handled by task manager - logger.error(f"Subscription limit exceeded for research: {http_error.detail}") - await task_manager.update_progress(task_id, f"ā Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}") - raise # Re-raise HTTPException to preserve status code and error details + logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}") - await task_manager.update_progress(task_id, "š Processing research results and extracting insights...") - # Extract sources from grounding metadata - sources = self._extract_sources_from_grounding(gemini_result) + # Build research prompt based on strategy + research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config) - # Extract grounding metadata for detailed UI display - grounding_metadata = self._extract_grounding_metadata(gemini_result) - - # Extract search widget and queries for UI display - search_widget = gemini_result.get("search_widget", "") or "" - search_queries = gemini_result.get("search_queries", []) or [] + # Route to appropriate provider + if config.provider == ResearchProvider.EXA: + # Exa research workflow + from .exa_provider import ExaResearchProvider + from services.subscription.preflight_validator import validate_exa_research_operations + from services.database import get_db + from services.subscription import PricingService + import os + + await task_manager.update_progress(task_id, "š Connecting to Exa neural search...") + + # Pre-flight validation + db_val = next(get_db()) + try: + pricing_service = PricingService(db_val) + gpt_provider = os.getenv("GPT_PROVIDER", "google") + validate_exa_research_operations(pricing_service, user_id, gpt_provider) + except HTTPException as http_error: + logger.error(f"Subscription limit exceeded for Exa research: {http_error.detail}") + await task_manager.update_progress(task_id, f"ā Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}") + raise + finally: + db_val.close() + + # Execute Exa search + await task_manager.update_progress(task_id, "š¤ Executing Exa neural search...") + try: + exa_provider = ExaResearchProvider() + raw_result = await exa_provider.search( + research_prompt, topic, industry, target_audience, config, user_id + ) + + # Track usage + cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005 + exa_provider.track_exa_usage(user_id, cost) + + # Extract content for downstream analysis + content = raw_result.get('content', '') + sources = raw_result.get('sources', []) + search_widget = "" # Exa doesn't provide search widgets + search_queries = raw_result.get('search_queries', []) + grounding_metadata = None # Exa doesn't provide grounding metadata + + except RuntimeError as e: + if "EXA_API_KEY not configured" in str(e): + logger.warning("Exa not configured, falling back to Google") + await task_manager.update_progress(task_id, "ā ļø Exa not configured, falling back to Google Search") + config.provider = ResearchProvider.GOOGLE + # Continue to Google flow below + else: + raise + + if config.provider != ResearchProvider.EXA: + # Google research (existing flow) + from .google_provider import GoogleResearchProvider + + await task_manager.update_progress(task_id, "š Connecting to Google Search grounding...") + google_provider = GoogleResearchProvider() + + await task_manager.update_progress(task_id, "š¤ Making AI request to Gemini with Google Search grounding...") + try: + gemini_result = await google_provider.search( + research_prompt, topic, industry, target_audience, config, user_id + ) + except HTTPException as http_error: + logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}") + await task_manager.update_progress(task_id, f"ā Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}") + raise + + await task_manager.update_progress(task_id, "š Processing research results and extracting insights...") + # Extract sources and content + sources = self._extract_sources_from_grounding(gemini_result) + content = gemini_result.get("content", "") + search_widget = gemini_result.get("search_widget", "") or "" + search_queries = gemini_result.get("search_queries", []) or [] + grounding_metadata = self._extract_grounding_metadata(gemini_result) + # Continue with common analysis (same for both providers) await task_manager.update_progress(task_id, "š Analyzing keywords and content angles...") - # Parse the comprehensive response for different analysis components - content = gemini_result.get("content", "") keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) diff --git a/backend/services/blog_writer/research/research_strategies.py b/backend/services/blog_writer/research/research_strategies.py new file mode 100644 index 00000000..5de20e34 --- /dev/null +++ b/backend/services/blog_writer/research/research_strategies.py @@ -0,0 +1,234 @@ +""" +Research Strategy Pattern Implementation + +Different strategies for executing research based on depth and focus. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any +from loguru import logger + +from models.blog_models import BlogResearchRequest, ResearchMode, ResearchConfig +from .keyword_analyzer import KeywordAnalyzer +from .competitor_analyzer import CompetitorAnalyzer +from .content_angle_generator import ContentAngleGenerator + + +class ResearchStrategy(ABC): + """Base class for research strategies.""" + + def __init__(self): + self.keyword_analyzer = KeywordAnalyzer() + self.competitor_analyzer = CompetitorAnalyzer() + self.content_angle_generator = ContentAngleGenerator() + + @abstractmethod + def build_research_prompt( + self, + topic: str, + industry: str, + target_audience: str, + config: ResearchConfig + ) -> str: + """Build the research prompt for the strategy.""" + pass + + @abstractmethod + def get_mode(self) -> ResearchMode: + """Return the research mode this strategy handles.""" + pass + + +class BasicResearchStrategy(ResearchStrategy): + """Basic research strategy - keyword focused, minimal analysis.""" + + def get_mode(self) -> ResearchMode: + return ResearchMode.BASIC + + def build_research_prompt( + self, + topic: str, + industry: str, + target_audience: str, + config: ResearchConfig + ) -> str: + """Build basic research prompt focused on keywords and quick insights.""" + prompt = f"""You are a professional blog content strategist researching for a {industry} blog targeting {target_audience}. + +Research Topic: "{topic}" + +Provide analysis in this EXACT format: + +## CURRENT TRENDS (2024-2025) +- [Trend 1 with specific data and source URL] +- [Trend 2 with specific data and source URL] +- [Trend 3 with specific data and source URL] + +## KEY STATISTICS +- [Statistic 1: specific number/percentage with source URL] +- [Statistic 2: specific number/percentage with source URL] +- [Statistic 3: specific number/percentage with source URL] +- [Statistic 4: specific number/percentage with source URL] +- [Statistic 5: specific number/percentage with source URL] + +## PRIMARY KEYWORDS +1. "{topic}" (main keyword) +2. [Variation 1] +3. [Variation 2] + +## SECONDARY KEYWORDS +[5 related keywords for blog content] + +## CONTENT ANGLES (Top 5) +1. [Angle 1: specific unique approach] +2. [Angle 2: specific unique approach] +3. [Angle 3: specific unique approach] +4. [Angle 4: specific unique approach] +5. [Angle 5: specific unique approach] + +REQUIREMENTS: +- Cite EVERY claim with authoritative source URLs +- Use 2024-2025 data when available +- Include specific numbers, dates, examples +- Focus on actionable blog insights for {target_audience}""" + return prompt.strip() + + +class ComprehensiveResearchStrategy(ResearchStrategy): + """Comprehensive research strategy - full analysis with all components.""" + + def get_mode(self) -> ResearchMode: + return ResearchMode.COMPREHENSIVE + + def build_research_prompt( + self, + topic: str, + industry: str, + target_audience: str, + config: ResearchConfig + ) -> str: + """Build comprehensive research prompt with all analysis components.""" + date_filter = f"\nDate Focus: {config.date_range.value.replace('_', ' ')}" if config.date_range else "" + source_filter = f"\nPriority Sources: {', '.join([s.value for s in config.source_types])}" if config.source_types else "" + + prompt = f"""You are a senior blog content strategist conducting comprehensive research for a {industry} blog targeting {target_audience}. + +Research Topic: "{topic}"{date_filter}{source_filter} + +Provide COMPLETE analysis in this EXACT format: + +## TRENDS AND INSIGHTS (2024-2025) +[5-7 trends with specific data, numbers, and source URLs] + +## KEY STATISTICS +[7-10 statistics with exact numbers, percentages, dates, and source URLs] + +## EXPERT OPINIONS +[4-5 expert quotes with full attribution and source URLs] + +## RECENT DEVELOPMENTS +[5-7 recent news/developments with dates and source URLs] + +## MARKET ANALYSIS +[3-5 market insights with data points and source URLs] + +## BEST PRACTICES & CASE STUDIES +[3-5 examples with specific outcomes/metrics and source URLs] + +## KEYWORD ANALYSIS +Primary Keywords: [3 main variations] +Secondary Keywords: [7-10 related keywords] +Long-Tail Opportunities: [5-7 specific search phrases] + +## COMPETITOR ANALYSIS +Top Competitors: [5 competitors with brief descriptions] +Content Gaps: [5 topics competitors are missing] +Competitive Advantages: [5 unique angles we can own] + +## CONTENT ANGLES (Exactly 5) +1. [Unique angle with reasoning and target benefit] +2. [Unique angle with reasoning and target benefit] +3. [Unique angle with reasoning and target benefit] +4. [Unique angle with reasoning and target benefit] +5. [Unique angle with reasoning and target benefit] + +VERIFICATION REQUIREMENTS: +- Minimum 2 authoritative sources per major claim +- Prioritize: Industry publications > Research papers > News > Blogs +- 2024-2025 data strongly preferred +- All numbers must include context (timeframe, sample size, methodology) +- Every recommendation must be actionable for {target_audience}""" + return prompt.strip() + + +class TargetedResearchStrategy(ResearchStrategy): + """Targeted research strategy - focused on specific aspects.""" + + def get_mode(self) -> ResearchMode: + return ResearchMode.TARGETED + + def build_research_prompt( + self, + topic: str, + industry: str, + target_audience: str, + config: ResearchConfig + ) -> str: + """Build targeted research prompt based on config preferences.""" + sections = [] + + if config.include_trends: + sections.append("""## CURRENT TRENDS +[3-5 trends with data and source URLs]""") + + if config.include_statistics: + sections.append("""## KEY STATISTICS +[5-7 statistics with numbers and source URLs]""") + + if config.include_expert_quotes: + sections.append("""## EXPERT OPINIONS +[3-4 expert quotes with attribution and source URLs]""") + + if config.include_competitors: + sections.append("""## COMPETITOR ANALYSIS +Top Competitors: [3-5] +Content Gaps: [3-5]""") + + # Always include keywords and angles + sections.append("""## KEYWORD ANALYSIS +Primary: [2-3 variations] +Secondary: [5-7 keywords] +Long-Tail: [3-5 phrases]""") + + sections.append("""## CONTENT ANGLES (3-5) +[Unique blog angles with reasoning]""") + + sections_str = "\n\n".join(sections) + + prompt = f"""You are a blog content strategist conducting targeted research for a {industry} blog targeting {target_audience}. + +Research Topic: "{topic}" + +Provide focused analysis in this EXACT format: + +{sections_str} + +REQUIREMENTS: +- Cite all claims with authoritative source URLs +- Include specific numbers, dates, examples +- Focus on actionable insights for {target_audience} +- Use 2024-2025 data when available""" + return prompt.strip() + + +def get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy: + """Factory function to get the appropriate strategy for a mode.""" + strategy_map = { + ResearchMode.BASIC: BasicResearchStrategy, + ResearchMode.COMPREHENSIVE: ComprehensiveResearchStrategy, + ResearchMode.TARGETED: TargetedResearchStrategy, + } + + strategy_class = strategy_map.get(mode, BasicResearchStrategy) + return strategy_class() + diff --git a/backend/services/integrations/wix/__init__.py b/backend/services/integrations/wix/__init__.py index 132171a6..67c5a433 100644 --- a/backend/services/integrations/wix/__init__.py +++ b/backend/services/integrations/wix/__init__.py @@ -2,4 +2,14 @@ Wix integration modular services package. """ +from services.integrations.wix.seo import build_seo_data +from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api +from services.integrations.wix.blog_publisher import create_blog_post + +__all__ = [ + 'build_seo_data', + 'markdown_to_html', + 'convert_via_wix_api', + 'create_blog_post', +] diff --git a/backend/services/integrations/wix/blog.py b/backend/services/integrations/wix/blog.py index 305158ce..6476e05a 100644 --- a/backend/services/integrations/wix/blog.py +++ b/backend/services/integrations/wix/blog.py @@ -20,6 +20,40 @@ class WixBlogService: return h def create_draft_post(self, access_token: str, payload: Dict[str, Any], extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + # Log the exact payload being sent for debugging + import json + logger.warning(f"š¤ Sending to Wix Blog API:") + logger.warning(f" Endpoint: {self.base_url}/blog/v3/draft-posts") + logger.warning(f" Payload top-level keys: {list(payload.keys())}") + if 'draftPost' in payload: + dp = payload['draftPost'] + logger.warning(f" draftPost keys: {list(dp.keys())}") + if 'richContent' in dp: + rc = dp['richContent'] + logger.warning(f" richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}") + if isinstance(rc, dict) and 'nodes' in rc: + nodes = rc['nodes'] + logger.warning(f" richContent.nodes count: {len(nodes) if isinstance(nodes, list) else 'N/A'}") + # Inspect first LIST_ITEM node if any + for i, node in enumerate(nodes[:10]): + if isinstance(node, dict) and node.get('type') == 'LIST_ITEM': + logger.warning(f" Found LIST_ITEM at index {i}:") + logger.warning(f" Keys: {list(node.keys())}") + logger.warning(f" Has listItemData: {'listItemData' in node}") + if 'listItemData' in node: + logger.warning(f" listItemData type: {type(node['listItemData'])}, value: {node['listItemData']}") + if 'nodes' in node: + nested = node['nodes'] + logger.warning(f" Nested nodes count: {len(nested) if isinstance(nested, list) else 'N/A'}") + for j, n_node in enumerate(nested[:3]): + if isinstance(n_node, dict): + logger.warning(f" Nested node {j}: type={n_node.get('type')}, keys={list(n_node.keys())}") + if n_node.get('type') == 'PARAGRAPH' and 'paragraphData' in n_node: + logger.warning(f" paragraphData type: {type(n_node['paragraphData'])}, value: {n_node['paragraphData']}") + break # Only inspect first LIST_ITEM + + logger.warning(f" Full Payload JSON (first 8000 chars):\n{json.dumps(payload, indent=2, ensure_ascii=False)[:8000]}...") + response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=self.headers(access_token, extra_headers), json=payload) response.raise_for_status() return response.json() diff --git a/backend/services/integrations/wix/blog_publisher.py b/backend/services/integrations/wix/blog_publisher.py new file mode 100644 index 00000000..39c2d916 --- /dev/null +++ b/backend/services/integrations/wix/blog_publisher.py @@ -0,0 +1,716 @@ +""" +Blog Post Publisher for Wix + +Handles blog post creation, validation, and publishing to Wix. +""" + +import json +import uuid +import requests +import jwt +from typing import Dict, Any, Optional, List +from loguru import logger +from services.integrations.wix.blog import WixBlogService +from services.integrations.wix.content import convert_content_to_ricos +from services.integrations.wix.ricos_converter import convert_via_wix_api +from services.integrations.wix.seo import build_seo_data + + +def validate_ricos_content(ricos_content: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and normalize Ricos document structure. + + Args: + ricos_content: Ricos document dict + + Returns: + Validated and normalized Ricos document + """ + # Validate Ricos document structure before using + if not ricos_content or not isinstance(ricos_content, dict): + logger.error("Invalid Ricos content - not a dict") + raise ValueError("Failed to convert content to valid Ricos format") + + if 'type' not in ricos_content: + ricos_content['type'] = 'DOCUMENT' + logger.debug("Added missing richContent type 'DOCUMENT'") + if ricos_content.get('type') != 'DOCUMENT': + logger.warning(f"richContent type expected 'DOCUMENT', got {ricos_content.get('type')}, correcting") + ricos_content['type'] = 'DOCUMENT' + + if 'id' not in ricos_content or not isinstance(ricos_content.get('id'), str): + ricos_content['id'] = str(uuid.uuid4()) + logger.debug("Added missing richContent id") + + if 'nodes' not in ricos_content: + logger.warning("Ricos document missing 'nodes' field, adding empty nodes array") + ricos_content['nodes'] = [] + + logger.debug(f"Ricos document structure: nodes={len(ricos_content.get('nodes', []))}") + + # Validate richContent is a proper object with nodes array + # Per Wix API: richContent must be a RichContent object with nodes array + if not isinstance(ricos_content, dict): + raise ValueError(f"richContent must be a dict object, got {type(ricos_content)}") + + # Ensure nodes array exists and is valid + if 'nodes' not in ricos_content: + logger.warning("richContent missing 'nodes', adding empty array") + ricos_content['nodes'] = [] + + if not isinstance(ricos_content['nodes'], list): + raise ValueError(f"richContent.nodes must be a list, got {type(ricos_content['nodes'])}") + + # Recursive function to validate and fix nodes at any depth + def validate_node_recursive(node: Dict[str, Any], path: str = "root") -> None: + """ + Recursively validate a node and all its nested children, ensuring: + 1. All required data fields exist for each node type + 2. All 'nodes' arrays are proper lists + 3. No None values in critical fields + """ + if not isinstance(node, dict): + logger.error(f"{path}: Node is not a dict: {type(node)}") + return + + # Ensure type and id exist + if 'type' not in node: + logger.error(f"{path}: Missing 'type' field - REQUIRED") + node['type'] = 'PARAGRAPH' # Default fallback + if 'id' not in node: + node['id'] = str(uuid.uuid4()) + logger.debug(f"{path}: Added missing 'id'") + + node_type = node.get('type') + + # CRITICAL: Per Wix API schema, data fields like paragraphData, bulletedListData, etc. + # are OPTIONAL and should be OMITTED entirely when empty, not included as {} + # Only validate fields that have required properties + + # Special handling: Remove listItemData if it exists (not in Wix API schema) + if node_type == 'LIST_ITEM' and 'listItemData' in node: + logger.debug(f"{path}: Removing incorrect listItemData field from LIST_ITEM") + del node['listItemData'] + + # Only validate HEADING nodes - they require headingData with level property + if node_type == 'HEADING': + if 'headingData' not in node or not isinstance(node.get('headingData'), dict): + logger.warning(f"{path} (HEADING): Missing headingData, adding default level 1") + node['headingData'] = {'level': 1} + elif 'level' not in node['headingData']: + logger.warning(f"{path} (HEADING): Missing level in headingData, adding default") + node['headingData']['level'] = 1 + + # TEXT nodes must have textData + if node_type == 'TEXT': + if 'textData' not in node or not isinstance(node.get('textData'), dict): + logger.error(f"{path} (TEXT): Missing/invalid textData - node will be problematic") + node['textData'] = {'text': '', 'decorations': []} + + # LINK and IMAGE nodes must have their data fields + if node_type == 'LINK' and ('linkData' not in node or not isinstance(node.get('linkData'), dict)): + logger.error(f"{path} (LINK): Missing/invalid linkData - node will be problematic") + if node_type == 'IMAGE' and ('imageData' not in node or not isinstance(node.get('imageData'), dict)): + logger.error(f"{path} (IMAGE): Missing/invalid imageData - node will be problematic") + + # Remove None values from any data fields that exist (Wix API rejects None) + for data_field in ['headingData', 'paragraphData', 'blockquoteData', 'bulletedListData', + 'orderedListData', 'textData', 'linkData', 'imageData']: + if data_field in node and isinstance(node[data_field], dict): + data_value = node[data_field] + keys_to_remove = [k for k, v in data_value.items() if v is None] + if keys_to_remove: + logger.debug(f"{path} ({node_type}): Removing None values from {data_field}: {keys_to_remove}") + for key in keys_to_remove: + del data_value[key] + + # Ensure 'nodes' field exists for container nodes + container_types = ['HEADING', 'PARAGRAPH', 'BLOCKQUOTE', 'LIST_ITEM', 'LINK', + 'BULLETED_LIST', 'ORDERED_LIST'] + if node_type in container_types: + if 'nodes' not in node: + logger.warning(f"{path} ({node_type}): Missing 'nodes' field, adding empty array") + node['nodes'] = [] + elif not isinstance(node['nodes'], list): + logger.error(f"{path} ({node_type}): Invalid 'nodes' field (not a list), fixing") + node['nodes'] = [] + + # Recursively validate all nested nodes + for nested_idx, nested_node in enumerate(node['nodes']): + nested_path = f"{path}.nodes[{nested_idx}]" + validate_node_recursive(nested_node, nested_path) + + # Validate all top-level nodes recursively + for idx, node in enumerate(ricos_content['nodes']): + validate_node_recursive(node, f"nodes[{idx}]") + + # Ensure documentStyle exists and is a dict (required by Wix API when provided) + if 'metadata' not in ricos_content or not isinstance(ricos_content.get('metadata'), dict): + ricos_content['metadata'] = {'version': 1, 'id': str(uuid.uuid4())} + logger.debug("Added default metadata to richContent") + else: + ricos_content['metadata'].setdefault('version', 1) + ricos_content['metadata'].setdefault('id', str(uuid.uuid4())) + + if 'documentStyle' not in ricos_content or not isinstance(ricos_content.get('documentStyle'), dict): + ricos_content['documentStyle'] = { + 'paragraph': { + 'decorations': [], + 'nodeStyle': {}, + 'lineHeight': '1.5' + } + } + logger.debug("Added default documentStyle to richContent") + + logger.debug(f"ā Validated richContent: {len(ricos_content['nodes'])} nodes, has_metadata={bool(ricos_content.get('metadata'))}, has_documentStyle={bool(ricos_content.get('documentStyle'))}") + + return ricos_content + + +def validate_payload_no_none(obj, path=""): + """Recursively validate that no None values exist in the payload""" + if obj is None: + raise ValueError(f"Found None value at path: {path}") + if isinstance(obj, dict): + for key, value in obj.items(): + validate_payload_no_none(value, f"{path}.{key}" if path else key) + elif isinstance(obj, list): + for idx, item in enumerate(obj): + validate_payload_no_none(item, f"{path}[{idx}]" if path else f"[{idx}]") + + +def create_blog_post( + blog_service: WixBlogService, + access_token: str, + title: str, + content: str, + member_id: str, + cover_image_url: str = None, + category_ids: List[str] = None, + tag_ids: List[str] = None, + publish: bool = True, + seo_metadata: Dict[str, Any] = None, + import_image_func = None, + lookup_categories_func = None, + lookup_tags_func = None, + base_url: str = 'https://www.wixapis.com' +) -> Dict[str, Any]: + """ + Create and optionally publish a blog post on Wix + + Args: + blog_service: WixBlogService instance + access_token: Valid access token + title: Blog post title + content: Blog post content (markdown) + member_id: Required for third-party apps - the member ID of the post author + cover_image_url: Optional cover image URL + category_ids: Optional list of category IDs or names + tag_ids: Optional list of tag IDs or names + publish: Whether to publish immediately or save as draft + seo_metadata: Optional SEO metadata dict + import_image_func: Function to import images (optional) + lookup_categories_func: Function to lookup/create categories (optional) + lookup_tags_func: Function to lookup/create tags (optional) + base_url: Wix API base URL + + Returns: + Created blog post information + """ + if not member_id: + raise ValueError("memberId is required for third-party apps creating blog posts") + + headers = { + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json' + } + + # Build valid Ricos rich content + # Ensure content is not empty + if not content or not content.strip(): + content = "This is a post from ALwrity." + logger.warning("ā ļø Content was empty, using default text") + + # Try Wix API first (more reliable), fall back to custom parser + ricos_content = None + try: + logger.warning("š Attempting to convert markdown to Ricos via Wix API...") + ricos_content = convert_via_wix_api(content, access_token, base_url) + logger.warning(f"ā Wix API conversion successful. Ricos document has {len(ricos_content.get('nodes', []))} nodes") + except Exception as e: + logger.warning(f"ā ļø Wix Ricos API conversion failed: {e}. Falling back to custom parser...") + # Fall back to custom parser + ricos_content = convert_content_to_ricos(content, None) + logger.warning(f"ā Custom parser conversion complete. Ricos document has {len(ricos_content.get('nodes', []))} nodes") + + # Validate Ricos content + ricos_content = validate_ricos_content(ricos_content) + + # Minimal payload per Wix docs: title, memberId, and richContent + # CRITICAL: Only include fields that have valid values (no None, no empty strings for required fields) + blog_data = { + 'draftPost': { + 'title': str(title).strip() if title else "Untitled", + 'memberId': str(member_id).strip(), # Required for third-party apps (validated above) + 'richContent': ricos_content, # Must be a valid Ricos document object + }, + 'publish': bool(publish), + 'fieldsets': ['URL'] # Simplified fieldsets + } + + # Add excerpt only if content exists and is not empty (avoid None or empty strings) + excerpt = (content or '').strip()[:200] if content else None + if excerpt and len(excerpt) > 0: + blog_data['draftPost']['excerpt'] = str(excerpt) + + # Add cover image if provided + if cover_image_url and import_image_func: + try: + media_id = import_image_func(access_token, cover_image_url, f'Cover: {title}') + # Ensure media_id is a string and not None + if media_id and isinstance(media_id, str): + blog_data['draftPost']['media'] = { + 'wixMedia': { + 'image': {'id': str(media_id).strip()} + }, + 'displayed': True, + 'custom': True + } + else: + logger.warning(f"Invalid media_id type or value: {type(media_id)}, skipping media") + except Exception as e: + logger.warning(f"Failed to import cover image: {e}") + + # Handle categories - can be either IDs (list of strings) or names (for lookup) + category_ids_to_use = None + if category_ids: + # Check if these are IDs (UUIDs) or names + if isinstance(category_ids, list) and len(category_ids) > 0: + # Assume IDs if first item looks like UUID (has hyphens and is long) + first_item = str(category_ids[0]) + if '-' in first_item and len(first_item) > 30: + category_ids_to_use = category_ids + elif lookup_categories_func: + # These are names, need to lookup/create + extra_headers = {} + if 'wix-site-id' in headers: + extra_headers['wix-site-id'] = headers['wix-site-id'] + category_ids_to_use = lookup_categories_func( + access_token, category_ids, extra_headers if extra_headers else None + ) + + # Handle tags - can be either IDs (list of strings) or names (for lookup) + tag_ids_to_use = None + if tag_ids: + # Check if these are IDs (UUIDs) or names + if isinstance(tag_ids, list) and len(tag_ids) > 0: + # Assume IDs if first item looks like UUID (has hyphens and is long) + first_item = str(tag_ids[0]) + if '-' in first_item and len(first_item) > 30: + tag_ids_to_use = tag_ids + elif lookup_tags_func: + # These are names, need to lookup/create + extra_headers = {} + if 'wix-site-id' in headers: + extra_headers['wix-site-id'] = headers['wix-site-id'] + tag_ids_to_use = lookup_tags_func( + access_token, tag_ids, extra_headers if extra_headers else None + ) + + # Add categories if we have IDs (must be non-empty list of strings) + # CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings + if category_ids_to_use and isinstance(category_ids_to_use, list) and len(category_ids_to_use) > 0: + # Filter out None, empty strings, and ensure all are valid UUID strings + valid_category_ids = [str(cid).strip() for cid in category_ids_to_use if cid and str(cid).strip()] + if valid_category_ids: + blog_data['draftPost']['categoryIds'] = valid_category_ids + logger.debug(f"Added {len(valid_category_ids)} category IDs") + else: + logger.warning("All category IDs were invalid, not including categoryIds in payload") + + # Add tags if we have IDs (must be non-empty list of strings) + # CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings + if tag_ids_to_use and isinstance(tag_ids_to_use, list) and len(tag_ids_to_use) > 0: + # Filter out None, empty strings, and ensure all are valid UUID strings + valid_tag_ids = [str(tid).strip() for tid in tag_ids_to_use if tid and str(tid).strip()] + if valid_tag_ids: + blog_data['draftPost']['tagIds'] = valid_tag_ids + logger.debug(f"Added {len(valid_tag_ids)} tag IDs") + else: + logger.warning("All tag IDs were invalid, not including tagIds in payload") + + # Build SEO data from metadata if provided + # TESTING: Skip SEO data temporarily to confirm richContent fix + test_skip_seo = True + if test_skip_seo: + logger.warning("š§Ŗ TESTING: Skipping SEO data to isolate richContent vs seoData issue") + seo_data = None + elif seo_metadata: + logger.warning(f"š Building SEO data from metadata. Keys: {list(seo_metadata.keys())}") + seo_data = build_seo_data(seo_metadata, title) + if seo_data: + # Log detailed SEO structure + logger.warning(f"š SEO data built: {len(seo_data.get('tags', []))} tags, {len(seo_data.get('settings', {}).get('keywords', []))} keywords") + + # Log each SEO tag for debugging (key ones only to avoid too much output) + if seo_data.get('tags'): + for idx, tag in enumerate(seo_data['tags'][:3]): # First 3 tags only + tag_type = tag.get('type') + if tag_type == 'title': + logger.warning(f" SEO tag {idx+1}: type={tag_type}, children={str(tag.get('children', ''))[:50]}...") + else: + props = tag.get('props', {}) + content_preview = str(props.get('content', props.get('href', props.get('name', ''))))[:50] + logger.warning(f" SEO tag {idx+1}: type={tag_type}, props={list(props.keys())}, content={content_preview}...") + if len(seo_data['tags']) > 3: + logger.warning(f" ... and {len(seo_data['tags']) - 3} more SEO tags") + + blog_data['draftPost']['seoData'] = seo_data + logger.warning(f"ā Added seoData to blog post with {len(seo_data.get('tags', []))} tags") + else: + logger.warning("ā ļø SEO data was empty after building - check build_seo_data function") + + # Add SEO slug if provided (separate field from seoData) + if seo_metadata and seo_metadata.get('url_slug'): + blog_data['draftPost']['seoSlug'] = str(seo_metadata.get('url_slug')).strip() + logger.warning(f"ā Added SEO slug: {blog_data['draftPost']['seoSlug']}") + + if test_skip_seo: + logger.warning("ā ļø SEO data skipped for testing - will add back once richContent is confirmed working") + elif not seo_metadata: + logger.warning("ā ļø No SEO metadata provided to create_blog_post") + + # Log the payload structure for debugging (without sensitive data) + logger.warning(f"š Creating blog post with title: '{title}'") + logger.warning(f"š Draft post fields: {list(blog_data['draftPost'].keys())}") + + # Detailed SEO logging + if 'seoData' in blog_data['draftPost']: + seo_data_debug = blog_data['draftPost']['seoData'] + logger.warning(f"š SEO data in payload: {len(seo_data_debug.get('tags', []))} tags, {len(seo_data_debug.get('settings', {}).get('keywords', []))} keywords") + + # Log sample SEO tags (first 2 only to avoid too much output) + if seo_data_debug.get('tags'): + logger.warning("š SEO Tags sample:") + for i, tag in enumerate(seo_data_debug['tags'][:2]): # First 2 tags + logger.warning(f" Tag {i+1}: type={tag.get('type')}, custom={tag.get('custom')}, disabled={tag.get('disabled')}") + if len(seo_data_debug['tags']) > 2: + logger.warning(f" ... and {len(seo_data_debug['tags']) - 2} more tags") + + if seo_data_debug.get('settings', {}).get('keywords'): + keywords_list = [k.get('term') for k in seo_data_debug['settings']['keywords'][:3]] + logger.warning(f"š Keywords: {keywords_list}") + + # Log FULL seoData structure for debugging + import json + try: + seo_json = json.dumps(seo_data_debug, indent=2, ensure_ascii=False) + logger.warning(f"š FULL seoData JSON:\n{seo_json[:2000]}...") # First 2000 chars + except Exception as e: + logger.error(f"Failed to serialize seoData: {e}") + else: + logger.warning("ā ļø No seoData in draft post payload!") + + try: + # Add wix-site-id header if we can extract it from token + extra_headers = {} + try: + token_str = str(access_token) + if token_str and token_str.startswith('OauthNG.JWS.'): + jwt_part = token_str[12:] + payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False}) + data_payload = payload.get('data', {}) + if isinstance(data_payload, str): + try: + data_payload = json.loads(data_payload) + except: + pass + instance_data = data_payload.get('instance', {}) + meta_site_id = instance_data.get('metaSiteId') + if isinstance(meta_site_id, str) and meta_site_id: + extra_headers['wix-site-id'] = meta_site_id + headers['wix-site-id'] = meta_site_id + except Exception as e: + logger.debug(f"Could not extract site ID from token: {e}") + + # Make the API call + logger.warning(f"š Calling Wix API: POST /blog/v3/draft-posts") + logger.warning(f"š¦ Payload: title='{blog_data['draftPost'].get('title')}', has_seoData={'seoData' in blog_data['draftPost']}, has_richContent={'richContent' in blog_data['draftPost']}") + + # Validate payload structure before sending + draft_post = blog_data.get('draftPost', {}) + if not isinstance(draft_post, dict): + raise ValueError("draftPost must be a dict object") + + # Validate richContent structure + if 'richContent' in draft_post: + rc = draft_post['richContent'] + if not isinstance(rc, dict): + raise ValueError(f"richContent must be a dict, got {type(rc)}") + if 'nodes' not in rc: + raise ValueError("richContent missing 'nodes' field") + if not isinstance(rc['nodes'], list): + raise ValueError(f"richContent.nodes must be a list, got {type(rc['nodes'])}") + logger.debug(f"ā richContent validation passed: {len(rc.get('nodes', []))} nodes") + + # Validate seoData structure if present + if 'seoData' in draft_post: + seo = draft_post['seoData'] + if not isinstance(seo, dict): + raise ValueError(f"seoData must be a dict, got {type(seo)}") + if 'tags' in seo and not isinstance(seo['tags'], list): + raise ValueError(f"seoData.tags must be a list, got {type(seo.get('tags'))}") + if 'settings' in seo and not isinstance(seo['settings'], dict): + raise ValueError(f"seoData.settings must be a dict, got {type(seo.get('settings'))}") + logger.debug(f"ā seoData validation passed: {len(seo.get('tags', []))} tags") + + # Final validation: Ensure no None values in any nested objects + # Wix API rejects None values and expects proper types + try: + validate_payload_no_none(blog_data, "blog_data") + logger.debug("ā Payload validation passed: No None values found") + except ValueError as e: + logger.error(f"ā Payload validation failed: {e}") + raise + + # Log full payload structure for debugging (sanitized) + logger.warning(f"š¦ Full payload structure validation:") + logger.warning(f" - draftPost type: {type(draft_post)}") + logger.warning(f" - draftPost keys: {list(draft_post.keys())}") + logger.warning(f" - richContent type: {type(draft_post.get('richContent'))}") + if 'richContent' in draft_post: + rc = draft_post['richContent'] + logger.warning(f" - richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}") + logger.warning(f" - richContent.nodes type: {type(rc.get('nodes'))}, count: {len(rc.get('nodes', []))}") + logger.warning(f" - richContent.metadata type: {type(rc.get('metadata'))}") + logger.warning(f" - richContent.documentStyle type: {type(rc.get('documentStyle'))}") + logger.warning(f" - seoData type: {type(draft_post.get('seoData'))}") + if 'seoData' in draft_post: + seo = draft_post['seoData'] + logger.warning(f" - seoData keys: {list(seo.keys()) if isinstance(seo, dict) else 'N/A'}") + logger.warning(f" - seoData.tags type: {type(seo.get('tags'))}, count: {len(seo.get('tags', []))}") + logger.warning(f" - seoData.settings type: {type(seo.get('settings'))}") + if 'categoryIds' in draft_post: + logger.warning(f" - categoryIds type: {type(draft_post.get('categoryIds'))}, count: {len(draft_post.get('categoryIds', []))}") + if 'tagIds' in draft_post: + logger.warning(f" - tagIds type: {type(draft_post.get('tagIds'))}, count: {len(draft_post.get('tagIds', []))}") + + # Log a sample of the payload JSON to see exact structure (first 2000 chars) + try: + import json + payload_json = json.dumps(blog_data, indent=2, ensure_ascii=False) + logger.warning(f"š Payload JSON preview (first 3000 chars):\n{payload_json[:3000]}...") + + # Also log a deep structure inspection of richContent.nodes (first few nodes) + if 'richContent' in blog_data['draftPost']: + nodes = blog_data['draftPost']['richContent'].get('nodes', []) + if nodes: + logger.warning(f"š Inspecting first 5 richContent.nodes:") + for i, node in enumerate(nodes[:5]): + logger.warning(f" Node {i+1}: type={node.get('type')}, keys={list(node.keys())}") + # Check for any None values in node + for key, value in node.items(): + if value is None: + logger.error(f" ā ļø Node {i+1}.{key} is None!") + elif isinstance(value, dict): + for k, v in value.items(): + if v is None: + logger.error(f" ā ļø Node {i+1}.{key}.{k} is None!") + # Deep check: if it's a list-type node, inspect list items + if node.get('type') in ['BULLETED_LIST', 'ORDERED_LIST']: + list_items = node.get('nodes', []) + if list_items: + logger.warning(f" List has {len(list_items)} items, checking first LIST_ITEM:") + first_item = list_items[0] + logger.warning(f" LIST_ITEM keys: {list(first_item.keys())}") + # Verify listItemData is NOT present (correct per Wix API spec) + if 'listItemData' in first_item: + logger.error(f" ā LIST_ITEM incorrectly has listItemData!") + else: + logger.debug(f" ā LIST_ITEM correctly has no listItemData") + # Check nested PARAGRAPH nodes + nested_nodes = first_item.get('nodes', []) + if nested_nodes: + logger.warning(f" LIST_ITEM has {len(nested_nodes)} nested nodes") + for n_idx, n_node in enumerate(nested_nodes[:2]): + logger.warning(f" Nested node {n_idx+1}: type={n_node.get('type')}, keys={list(n_node.keys())}") + except Exception as e: + logger.warning(f"Could not serialize payload for logging: {e}") + + # Note: All node validation is done by validate_ricos_content() which runs earlier + # The recursive validation ensures all required data fields are present at any depth + + # Final deep validation: Serialize and deserialize to catch any JSON-serialization issues + # This will raise an error if there are any objects that can't be serialized + try: + import json + test_json = json.dumps(blog_data, ensure_ascii=False) + test_parsed = json.loads(test_json) + logger.debug("ā Payload JSON serialization test passed") + except (TypeError, ValueError) as e: + logger.error(f"ā Payload JSON serialization failed: {e}") + raise ValueError(f"Payload contains non-serializable data: {e}") + + # Final check: Ensure documentStyle and metadata are valid objects (not None, not empty strings) + rc = blog_data['draftPost']['richContent'] + if 'documentStyle' in rc: + doc_style = rc['documentStyle'] + if doc_style is None or doc_style == "": + logger.warning("ā ļø documentStyle is None or empty string, removing it") + del rc['documentStyle'] + elif not isinstance(doc_style, dict): + logger.warning(f"ā ļø documentStyle is not a dict ({type(doc_style)}), removing it") + del rc['documentStyle'] + + if 'metadata' in rc: + metadata = rc['metadata'] + if metadata is None or metadata == "": + logger.warning("ā ļø metadata is None or empty string, removing it") + del rc['metadata'] + elif not isinstance(metadata, dict): + logger.warning(f"ā ļø metadata is not a dict ({type(metadata)}), removing it") + del rc['metadata'] + + # Check for any None values in critical nested structures + def check_none_in_dict(d, path=""): + """Recursively check for None values that shouldn't be there""" + issues = [] + if isinstance(d, dict): + for key, value in d.items(): + current_path = f"{path}.{key}" if path else key + if value is None: + # Some fields can legitimately be None, but most shouldn't + if key not in ['decorations', 'nodeStyle', 'props']: + issues.append(current_path) + elif isinstance(value, dict): + issues.extend(check_none_in_dict(value, current_path)) + elif isinstance(value, list): + for i, item in enumerate(value): + if item is None: + issues.append(f"{current_path}[{i}]") + elif isinstance(item, dict): + issues.extend(check_none_in_dict(item, f"{current_path}[{i}]")) + return issues + + none_issues = check_none_in_dict(blog_data['draftPost']['richContent']) + if none_issues: + logger.error(f"ā Found None values in richContent at: {none_issues[:10]}") # Limit to first 10 + # Remove None values from critical paths + for issue_path in none_issues[:5]: # Fix first 5 + parts = issue_path.split('.') + try: + obj = blog_data['draftPost']['richContent'] + for part in parts[:-1]: + if '[' in part: + key, idx = part.split('[') + idx = int(idx.rstrip(']')) + obj = obj[key][idx] + else: + obj = obj[part] + final_key = parts[-1] + if '[' in final_key: + key, idx = final_key.split('[') + idx = int(idx.rstrip(']')) + obj[key][idx] = {} + else: + obj[final_key] = {} + logger.warning(f"Fixed None value at {issue_path}") + except: + pass + + # Log the final payload structure one more time before sending + logger.warning(f"š¤ Final payload ready - draftPost keys: {list(blog_data['draftPost'].keys())}") + logger.warning(f"š¤ RichContent nodes count: {len(blog_data['draftPost']['richContent'].get('nodes', []))}") + logger.warning(f"š¤ RichContent has metadata: {bool(blog_data['draftPost']['richContent'].get('metadata'))}") + logger.warning(f"š¤ RichContent has documentStyle: {bool(blog_data['draftPost']['richContent'].get('documentStyle'))}") + + # Try sending WITHOUT SEO data first to isolate the issue + test_without_seo = False # Disabled - listItemData issue fixed + if test_without_seo and 'seoData' in blog_data['draftPost']: + logger.warning("š§Ŗ TESTING WITHOUT SEO DATA to isolate issue...") + # Clone the payload without SEO data + test_payload_no_seo = { + 'draftPost': { + 'title': blog_data['draftPost']['title'], + 'memberId': blog_data['draftPost']['memberId'], + 'richContent': blog_data['draftPost']['richContent'], + 'excerpt': blog_data['draftPost'].get('excerpt', '') + }, + 'publish': False, + 'fieldsets': ['URL'] + } + try: + logger.warning("š§Ŗ Attempting without SEO data...") + test_result = blog_service.create_draft_post(access_token, test_payload_no_seo, extra_headers or None) + logger.warning(f"ā WITHOUT SEO DATA SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}") + logger.error("ā ļøā ļøā ļø ISSUE IS WITH SEO DATA STRUCTURE!") + # If this succeeds, don't send the full payload, just return this result + return test_result + except Exception as e: + logger.warning(f"ā WITHOUT SEO DATA ALSO FAILED: {e}") + logger.warning("ā ļø Issue is NOT with SEO data, continuing with full payload...") + + # Try sending with minimal structure first to isolate the issue + # Create a test payload with just required fields + minimal_test = False # Set to True to test with minimal payload + if minimal_test: + logger.warning("š§Ŗ TESTING WITH MINIMAL PAYLOAD (title + memberId + simple richContent)") + test_payload = { + 'draftPost': { + 'title': blog_data['draftPost']['title'], + 'memberId': blog_data['draftPost']['memberId'], + 'richContent': { + 'nodes': [ + { + 'id': str(uuid.uuid4()), + 'type': 'PARAGRAPH', + 'nodes': [ + { + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': 'Test paragraph', + 'decorations': [] + } + } + ], + 'paragraphData': {} + } + ], + 'metadata': {'version': 1, 'id': str(uuid.uuid4())}, + 'documentStyle': {} + } + }, + 'publish': False, + 'fieldsets': ['URL'] + } + logger.warning("š§Ŗ Attempting minimal payload first...") + try: + test_result = blog_service.create_draft_post(access_token, test_payload, extra_headers or None) + logger.warning(f"ā MINIMAL PAYLOAD SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}") + logger.warning("ā ļø Issue is with complex content, not basic structure") + except Exception as e: + logger.error(f"ā MINIMAL PAYLOAD ALSO FAILED: {e}") + logger.error("ā ļø Issue is with basic structure or permissions") + + result = blog_service.create_draft_post(access_token, blog_data, extra_headers or None) + + # Log response + draft_post = result.get('draftPost', {}) + logger.warning(f"ā Blog post created successfully! Post ID: {draft_post.get('id', 'N/A')}") + + # Check if SEO data was preserved in response + if 'seoData' in draft_post: + seo_response = draft_post['seoData'] + logger.warning(f"ā SEO data confirmed in response: {len(seo_response.get('tags', []))} tags, {len(seo_response.get('settings', {}).get('keywords', []))} keywords") + else: + logger.warning("ā ļø No seoData in response - it may have been filtered out by Wix API") + logger.warning(f"š Response fields: {list(draft_post.keys())}") + + return result + except requests.RequestException as e: + logger.error(f"Failed to create blog post: {e}") + if hasattr(e, 'response') and e.response is not None: + logger.error(f"Response body: {e.response.text}") + raise + diff --git a/backend/services/integrations/wix/content.py b/backend/services/integrations/wix/content.py index 216b19df..0a31aec4 100644 --- a/backend/services/integrations/wix/content.py +++ b/backend/services/integrations/wix/content.py @@ -1,58 +1,460 @@ +import re +import uuid from typing import Any, Dict, List +def parse_markdown_inline(text: str) -> List[Dict[str, Any]]: + """ + Parse inline markdown formatting (bold, italic, links) into Ricos text nodes. + Returns a list of text nodes with decorations. + Handles: **bold**, *italic*, [links](url), `code`, and combinations. + """ + if not text: + return [{ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': {'text': '', 'decorations': []} + }] + + nodes = [] + + # Process text character by character to handle nested/adjacent formatting + # This is more robust than regex for complex cases + i = 0 + current_text = '' + current_decorations = [] + + while i < len(text): + # Check for bold **text** (must come before single * check) + if i < len(text) - 1 and text[i:i+2] == '**': + # Save any accumulated text + if current_text: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': current_text, + 'decorations': current_decorations.copy() + } + }) + current_text = '' + + # Find closing ** + end_bold = text.find('**', i + 2) + if end_bold != -1: + bold_text = text[i + 2:end_bold] + # Recursively parse the bold text for nested formatting + bold_nodes = parse_markdown_inline(bold_text) + # Add BOLD decoration to all text nodes within + for node in bold_nodes: + if node['type'] == 'TEXT': + node_decorations = node['textData'].get('decorations', []).copy() + if 'BOLD' not in node_decorations: + node_decorations.append('BOLD') + node['textData']['decorations'] = node_decorations + nodes.append(node) + i = end_bold + 2 + continue + + # Check for link [text](url) + elif text[i] == '[': + # Save any accumulated text + if current_text: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': current_text, + 'decorations': current_decorations.copy() + } + }) + current_text = '' + current_decorations = [] + + # Find matching ] + link_end = text.find(']', i) + if link_end != -1 and link_end < len(text) - 1 and text[link_end + 1] == '(': + link_text = text[i + 1:link_end] + url_start = link_end + 2 + url_end = text.find(')', url_start) + if url_end != -1: + url = text[url_start:url_end] + # Create link node + link_node_id = str(uuid.uuid4()) + text_node_id = str(uuid.uuid4()) + link_text_nodes = parse_markdown_inline(link_text) + # Wrap link text in LINK node + nodes.append({ + 'id': link_node_id, + 'type': 'LINK', + 'nodes': link_text_nodes if link_text_nodes else [{ + 'id': text_node_id, + 'type': 'TEXT', + 'textData': {'text': link_text, 'decorations': []} + }], + 'linkData': { + 'link': { + 'url': url, + 'target': '_blank' + } + } + }) + i = url_end + 1 + continue + + # Check for code `text` + elif text[i] == '`': + # Save any accumulated text + if current_text: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': current_text, + 'decorations': current_decorations.copy() + } + }) + current_text = '' + current_decorations = [] + + # Find closing ` + code_end = text.find('`', i + 1) + if code_end != -1: + code_text = text[i + 1:code_end] + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': code_text, + 'decorations': ['CODE'] + } + }) + i = code_end + 1 + continue + + # Check for italic *text* (only if not part of **) + elif text[i] == '*' and (i == 0 or text[i-1] != '*') and (i == len(text) - 1 or text[i+1] != '*'): + # Save any accumulated text + if current_text: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': current_text, + 'decorations': current_decorations.copy() + } + }) + current_text = '' + current_decorations = [] + + # Find closing * (but not **) + italic_end = text.find('*', i + 1) + if italic_end != -1: + # Make sure it's not part of ** + if italic_end == len(text) - 1 or text[italic_end + 1] != '*': + italic_text = text[i + 1:italic_end] + italic_nodes = parse_markdown_inline(italic_text) + # Add ITALIC decoration + for node in italic_nodes: + if node['type'] == 'TEXT': + node_decorations = node['textData'].get('decorations', []).copy() + if 'ITALIC' not in node_decorations: + node_decorations.append('ITALIC') + node['textData']['decorations'] = node_decorations + nodes.append(node) + i = italic_end + 1 + continue + + # Regular character + current_text += text[i] + i += 1 + + # Add any remaining text + if current_text: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': current_text, + 'decorations': current_decorations.copy() + } + }) + + # If no nodes created, return single plain text node + if not nodes: + nodes.append({ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': text, + 'decorations': [] + } + }) + + return nodes + + def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str, Any]: """ - Convert simple markdown-like text into minimal valid Ricos JSON. + Convert markdown content into valid Ricos JSON format. + Supports headings, paragraphs, lists, bold, italic, links, and images. """ - paragraphs = content.split('\n\n') + if not content: + content = "This is a post from ALwrity." + nodes = [] - - import uuid - - for paragraph in paragraphs: - text = paragraph.strip() - if not text: + lines = content.split('\n') + + i = 0 + while i < len(lines): + line = lines[i].strip() + + if not line: + i += 1 continue + node_id = str(uuid.uuid4()) - text_node_id = str(uuid.uuid4()) - - if text.startswith('#'): - level = len(text) - len(text.lstrip('#')) - heading_text = text.lstrip('# ').strip() + + # Check for headings + if line.startswith('#'): + level = len(line) - len(line.lstrip('#')) + heading_text = line.lstrip('# ').strip() + text_nodes = parse_markdown_inline(heading_text) nodes.append({ 'id': node_id, 'type': 'HEADING', - 'nodes': [{ - 'id': text_node_id, - 'type': 'TEXT', - 'textData': { - 'text': heading_text, - 'decorations': [] - } - }], - 'headingData': { 'level': min(level, 6) } + 'nodes': text_nodes, + 'headingData': {'level': min(level, 6)} }) - else: - nodes.append({ - 'id': node_id, + i += 1 + + # Check for blockquotes + elif line.startswith('>'): + quote_text = line.lstrip('> ').strip() + # Continue reading consecutive blockquote lines + quote_lines = [quote_text] + i += 1 + while i < len(lines) and lines[i].strip().startswith('>'): + quote_lines.append(lines[i].strip().lstrip('> ').strip()) + i += 1 + quote_content = ' '.join(quote_lines) + text_nodes = parse_markdown_inline(quote_content) + # CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE + paragraph_node = { + 'id': str(uuid.uuid4()), 'type': 'PARAGRAPH', - 'nodes': [{ - 'id': text_node_id, - 'type': 'TEXT', - 'textData': { - 'text': text, - 'decorations': [] - } - }], + 'nodes': text_nodes, 'paragraphData': {} - }) - + } + blockquote_node = { + 'id': node_id, + 'type': 'BLOCKQUOTE', + 'nodes': [paragraph_node], + 'blockquoteData': {} + } + nodes.append(blockquote_node) + + # Check for unordered lists (handle both '- ' and '* ' markers) + elif (line.startswith('- ') or line.startswith('* ') or + (line.startswith('-') and len(line) > 1 and line[1] != '-') or + (line.startswith('*') and len(line) > 1 and line[1] != '*')): + list_items = [] + list_marker = '- ' if line.startswith('-') else '* ' + # Process list items + while i < len(lines): + current_line = lines[i].strip() + # Check if this is a list item + is_list_item = (current_line.startswith('- ') or current_line.startswith('* ') or + (current_line.startswith('-') and len(current_line) > 1 and current_line[1] != '-') or + (current_line.startswith('*') and len(current_line) > 1 and current_line[1] != '*')) + + if not is_list_item: + break + + # Extract item text (handle both '- ' and '-item' formats) + if current_line.startswith('- ') or current_line.startswith('* '): + item_text = current_line[2:].strip() + elif current_line.startswith('-'): + item_text = current_line[1:].strip() + elif current_line.startswith('*'): + item_text = current_line[1:].strip() + else: + item_text = current_line + + list_items.append(item_text) + i += 1 + + # Check for nested items (indented with 2+ spaces) + while i < len(lines): + next_line = lines[i] + # Must be indented and be a list marker + if next_line.startswith(' ') and (next_line.strip().startswith('- ') or + next_line.strip().startswith('* ') or + (next_line.strip().startswith('-') and len(next_line.strip()) > 1) or + (next_line.strip().startswith('*') and len(next_line.strip()) > 1)): + nested_text = next_line.strip() + if nested_text.startswith('- ') or nested_text.startswith('* '): + nested_text = nested_text[2:].strip() + elif nested_text.startswith('-'): + nested_text = nested_text[1:].strip() + elif nested_text.startswith('*'): + nested_text = nested_text[1:].strip() + list_items.append(nested_text) + i += 1 + else: + break + + # Build list items with proper formatting + # CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM + # NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema + # Wix API: omit empty data objects, don't include them as {} + list_node_items = [] + for item_text in list_items: + item_node_id = str(uuid.uuid4()) + text_nodes = parse_markdown_inline(item_text) + paragraph_node = { + 'id': str(uuid.uuid4()), + 'type': 'PARAGRAPH', + 'nodes': text_nodes, + 'paragraphData': {} + } + list_item_node = { + 'id': item_node_id, + 'type': 'LIST_ITEM', + 'nodes': [paragraph_node] + } + list_node_items.append(list_item_node) + + bulleted_list_node = { + 'id': node_id, + 'type': 'BULLETED_LIST', + 'nodes': list_node_items, + 'bulletedListData': {} + } + nodes.append(bulleted_list_node) + + # Check for ordered lists + elif re.match(r'^\d+\.\s+', line): + list_items = [] + while i < len(lines) and re.match(r'^\d+\.\s+', lines[i].strip()): + item_text = re.sub(r'^\d+\.\s+', '', lines[i].strip()) + list_items.append(item_text) + i += 1 + # Check for nested items + while i < len(lines) and lines[i].strip().startswith(' ') and re.match(r'^\s+\d+\.\s+', lines[i].strip()): + nested_text = re.sub(r'^\s+\d+\.\s+', '', lines[i].strip()) + list_items.append(nested_text) + i += 1 + + # CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM + # NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema + # Wix API: omit empty data objects, don't include them as {} + list_node_items = [] + for item_text in list_items: + item_node_id = str(uuid.uuid4()) + text_nodes = parse_markdown_inline(item_text) + paragraph_node = { + 'id': str(uuid.uuid4()), + 'type': 'PARAGRAPH', + 'nodes': text_nodes, + 'paragraphData': {} + } + list_item_node = { + 'id': item_node_id, + 'type': 'LIST_ITEM', + 'nodes': [paragraph_node] + } + list_node_items.append(list_item_node) + + ordered_list_node = { + 'id': node_id, + 'type': 'ORDERED_LIST', + 'nodes': list_node_items, + 'orderedListData': {} + } + nodes.append(ordered_list_node) + + # Check for images + elif line.startswith('!['): + img_match = re.match(r'!\[([^\]]*)\]\(([^)]+)\)', line) + if img_match: + alt_text = img_match.group(1) + img_url = img_match.group(2) + nodes.append({ + 'id': node_id, + 'type': 'IMAGE', + 'nodes': [], + 'imageData': { + 'image': { + 'src': {'url': img_url}, + 'altText': alt_text + }, + 'containerData': { + 'alignment': 'CENTER', + 'width': {'size': 'CONTENT'} + } + } + }) + i += 1 + + # Regular paragraph + else: + # Collect consecutive non-empty lines as paragraph content + para_lines = [line] + i += 1 + while i < len(lines): + next_line = lines[i].strip() + if not next_line: + break + # Stop if next line is a special markdown element + if (next_line.startswith('#') or + next_line.startswith('- ') or + next_line.startswith('* ') or + next_line.startswith('>') or + next_line.startswith('![') or + re.match(r'^\d+\.\s+', next_line)): + break + para_lines.append(next_line) + i += 1 + + para_text = ' '.join(para_lines) + text_nodes = parse_markdown_inline(para_text) + + # Only add paragraph if there are text nodes + if text_nodes: + paragraph_node = { + 'id': node_id, + 'type': 'PARAGRAPH', + 'nodes': text_nodes, + 'paragraphData': {} + } + nodes.append(paragraph_node) + + # Ensure at least one node exists + # Wix API: omit empty data objects, don't include them as {} + if not nodes: + fallback_paragraph = { + 'id': str(uuid.uuid4()), + 'type': 'PARAGRAPH', + 'nodes': [{ + 'id': str(uuid.uuid4()), + 'type': 'TEXT', + 'textData': { + 'text': content[:500] if content else "This is a post from ALwrity.", + 'decorations': [] + } + }], + 'paragraphData': {} + } + nodes.append(fallback_paragraph) + return { + 'type': 'DOCUMENT', + 'id': str(uuid.uuid4()), 'nodes': nodes, - 'metadata': { 'version': 1, 'id': str(uuid.uuid4()) }, + 'metadata': {'version': 1, 'id': str(uuid.uuid4())}, 'documentStyle': { - 'paragraph': { 'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5' } + 'paragraph': {'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5'} } } diff --git a/backend/services/integrations/wix/media.py b/backend/services/integrations/wix/media.py index afd1c2ef..eab4ba8b 100644 --- a/backend/services/integrations/wix/media.py +++ b/backend/services/integrations/wix/media.py @@ -7,6 +7,12 @@ class WixMediaService: self.base_url = base_url def import_image(self, access_token: str, image_url: str, display_name: str) -> Dict[str, Any]: + """ + Import external image to Wix Media Manager. + + Official endpoint: https://www.wixapis.com/site-media/v1/files/import + Reference: https://dev.wix.com/docs/rest/assets/media/media-manager/files/import-file + """ headers = { 'Authorization': f'Bearer {access_token}', 'Content-Type': 'application/json', @@ -16,7 +22,9 @@ class WixMediaService: 'mediaType': 'IMAGE', 'displayName': display_name, } - response = requests.post(f"{self.base_url}/media/v1/files/import", headers=headers, json=payload) + # Correct endpoint per Wix API documentation + endpoint = f"{self.base_url}/site-media/v1/files/import" + response = requests.post(endpoint, headers=headers, json=payload) response.raise_for_status() return response.json() diff --git a/backend/services/integrations/wix/ricos_converter.py b/backend/services/integrations/wix/ricos_converter.py new file mode 100644 index 00000000..faba70c7 --- /dev/null +++ b/backend/services/integrations/wix/ricos_converter.py @@ -0,0 +1,277 @@ +""" +Ricos Document Converter for Wix + +Converts markdown content to Wix Ricos JSON format using either: +1. Wix's official Ricos Documents API (preferred) +2. Custom markdown parser (fallback) +""" + +import json +import requests +import jwt +from typing import Dict, Any, Optional +from loguru import logger + + +def markdown_to_html(markdown_content: str) -> str: + """ + Convert markdown content to HTML. + Uses a simple markdown parser for basic conversion. + + Args: + markdown_content: Markdown content to convert + + Returns: + HTML string + """ + try: + # Try using markdown library if available + import markdown + html = markdown.markdown(markdown_content, extensions=['fenced_code', 'tables']) + return html + except ImportError: + # Fallback: Simple regex-based conversion for basic markdown + logger.warning("markdown library not available, using basic markdown-to-HTML conversion") + import re + + if not markdown_content or not markdown_content.strip(): + return "
This is a post from ALwrity.
" + + lines = markdown_content.split('\n') + result = [] + in_list = False + list_type = None # 'ul' or 'ol' + in_code_block = False + code_block_content = [] + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # Handle code blocks first + if line.startswith('```'): + if not in_code_block: + in_code_block = True + code_block_content = [] + i += 1 + continue + else: + in_code_block = False + result.append(f'{"\n".join(code_block_content)}')
+ code_block_content = []
+ i += 1
+ continue
+
+ if in_code_block:
+ code_block_content.append(lines[i])
+ i += 1
+ continue
+
+ # Close any open lists
+ if in_list and not (line.startswith('- ') or line.startswith('* ') or re.match(r'^\d+\.\s+', line)):
+ result.append(f'{list_type}>')
+ in_list = False
+ list_type = None
+
+ if not line:
+ i += 1
+ continue
+
+ # Headers
+ if line.startswith('###'):
+ result.append(f'') + # Regular paragraphs + else: + para_text = line + # Process inline formatting + para_text = re.sub(r'\*\*(.*?)\*\*', r'\1', para_text) + para_text = re.sub(r'\*(.*?)\*', r'\1', para_text) + para_text = re.sub(r'\[([^\]]+)\]\(([^\)]+)\)', r'\1', para_text) + para_text = re.sub(r'`([^`]+)`', r'{quote_text}
\1', para_text)
+ result.append(f'{para_text}
') + + i += 1 + + # Close any open lists + if in_list: + result.append(f'{list_type}>') + + # Ensure we have at least one paragraph + if not result: + result.append('This is a post from ALwrity.
') + + html = '\n'.join(result) + + logger.debug(f"Converted {len(markdown_content)} chars markdown to {len(html)} chars HTML") + return html + + +def convert_via_wix_api(markdown_content: str, access_token: str, base_url: str = 'https://www.wixapis.com') -> Dict[str, Any]: + """ + Convert markdown to Ricos using Wix's official Ricos Documents API. + Uses HTML format for better reliability (per Wix documentation, HTML is fully supported). + + Reference: https://dev.wix.com/docs/api-reference/assets/rich-content/ricos-documents/convert-to-ricos-document + + Args: + markdown_content: Markdown content to convert (will be converted to HTML) + access_token: Wix access token + base_url: Wix API base URL (default: https://www.wixapis.com) + + Returns: + Ricos JSON document + """ + # Validate content is not empty + markdown_stripped = markdown_content.strip() if markdown_content else "" + if not markdown_stripped: + logger.error("Markdown content is empty or whitespace-only") + raise ValueError("Content cannot be empty for Wix Ricos API conversion") + + logger.debug(f"Converting markdown to HTML: input_length={len(markdown_stripped)} chars") + + # Convert markdown to HTML for better reliability with Wix API + # HTML format is more structured and less prone to parsing errors + html_content = markdown_to_html(markdown_stripped) + + # Validate HTML content is not empty - CRITICAL for Wix API + html_stripped = html_content.strip() if html_content else "" + if not html_stripped or len(html_stripped) == 0: + logger.error(f"HTML conversion produced empty content! Markdown length: {len(markdown_stripped)}") + logger.error(f"Markdown sample: {markdown_stripped[:500]}...") + logger.error(f"HTML result: '{html_content}' (type: {type(html_content)})") + # Fallback: use a minimal valid HTML if conversion failed + html_content = "Content from ALwrity blog writer.
" + logger.warning("Using fallback HTML due to empty conversion result") + else: + html_content = html_stripped + + logger.debug(f"ā Converted markdown to HTML: {len(html_content)} chars, preview: {html_content[:200]}...") + + headers = { + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json' + } + + # Add wix-site-id if available from token + try: + token_str = str(access_token) + if token_str and token_str.startswith('OauthNG.JWS.'): + jwt_part = token_str[12:] + payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False}) + data_payload = payload.get('data', {}) + if isinstance(data_payload, str): + try: + data_payload = json.loads(data_payload) + except: + pass + instance_data = data_payload.get('instance', {}) + meta_site_id = instance_data.get('metaSiteId') + if isinstance(meta_site_id, str) and meta_site_id: + headers['wix-site-id'] = meta_site_id + except Exception as e: + logger.debug(f"Could not extract site ID from token: {e}") + + # Call Wix Ricos Documents API: Convert to Ricos Document + # Official endpoint: https://www.wixapis.com/ricos/v1/ricos-document/convert/to-ricos + # Reference: https://dev.wix.com/docs/rest/assets/rich-content/ricos-documents/convert-to-ricos-document + endpoint = f"{base_url}/ricos/v1/ricos-document/convert/to-ricos" + + # Ensure HTML content is not empty or just whitespace + html_stripped = html_content.strip() if html_content else "" + if not html_stripped or len(html_stripped) == 0: + logger.error(f"HTML content is empty after conversion. Markdown length: {len(markdown_content)}") + logger.error(f"Markdown preview (first 500 chars): {markdown_content[:500] if markdown_content else 'N/A'}") + raise ValueError(f"HTML content cannot be empty. Original markdown had {len(markdown_content)} characters.") + + # Payload structure per Wix API: html/markdown/plainText field at root, optional plugins + payload = { + 'html': html_stripped, # Direct field, not nested in options + 'plugins': [] # Optional: empty array uses default plugins + } + + logger.warning(f"š¤ Sending to Wix Ricos API: html_length={len(payload['html'])}, plugins_count={len(payload['plugins'])}") + logger.debug(f"HTML preview (first 300 chars): {html_stripped[:300]}...") + + try: + # Log the exact payload being sent (for debugging) + logger.warning(f"š¤ Wix Ricos API Request:") + logger.warning(f" Endpoint: {endpoint}") + logger.warning(f" Payload keys: {list(payload.keys())}") + logger.warning(f" HTML length: {len(payload.get('html', ''))}") + logger.warning(f" Plugins: {payload.get('plugins', [])}") + logger.debug(f" Full payload (first 500 chars of HTML): {str(payload)[:500]}") + + response = requests.post( + endpoint, + headers=headers, + json=payload, + timeout=30 + ) + response.raise_for_status() + result = response.json() + + # Extract the ricos document from response + # Response structure: { "document": { "nodes": [...], "metadata": {...}, "documentStyle": {...} } } + ricos_document = result.get('document') + if not ricos_document: + # Fallback: try other possible response fields + ricos_document = result.get('ricosDocument') or result.get('ricos') or result + + if not ricos_document: + logger.error(f"Unexpected response structure from Wix API: {list(result.keys())}") + logger.error(f"Response: {result}") + raise ValueError("Wix API did not return a valid Ricos document") + + logger.warning(f"ā Successfully converted HTML to Ricos via Wix API: {len(ricos_document.get('nodes', []))} nodes") + return ricos_document + + except requests.RequestException as e: + logger.error(f"ā Wix Ricos API conversion failed: {e}") + if hasattr(e, 'response') and e.response is not None: + logger.error(f" Response status: {e.response.status_code}") + logger.error(f" Response headers: {dict(e.response.headers)}") + try: + error_body = e.response.json() + logger.error(f" Response JSON: {error_body}") + except: + logger.error(f" Response text: {e.response.text}") + logger.error(f" Request payload was: {json.dumps(payload, indent=2)[:1000]}...") # First 1000 chars + raise + diff --git a/backend/services/integrations/wix/seo.py b/backend/services/integrations/wix/seo.py new file mode 100644 index 00000000..febf48c4 --- /dev/null +++ b/backend/services/integrations/wix/seo.py @@ -0,0 +1,300 @@ +""" +SEO Data Builder for Wix Blog Posts + +Builds Wix-compatible seoData objects from ALwrity SEO metadata. +""" + +from typing import Dict, Any, Optional +from loguru import logger + + +def build_seo_data(seo_metadata: Dict[str, Any], default_title: str = None) -> Optional[Dict[str, Any]]: + """ + Build Wix seoData object from our SEO metadata format. + + Args: + seo_metadata: SEO metadata dict with fields like: + - seo_title: SEO optimized title + - meta_description: Meta description + - focus_keyword: Main keyword + - blog_tags: List of tag strings (for keywords) + - open_graph: Open Graph data dict + - canonical_url: Canonical URL + default_title: Fallback title if seo_title not provided + + Returns: + Wix seoData object with settings.keywords and tags array, or None if empty + """ + seo_data = { + 'settings': { + 'keywords': [] + }, + 'tags': [] + } + + # Build keywords array + keywords_list = [] + + # Add main keyword (focus_keyword) if provided + focus_keyword = seo_metadata.get('focus_keyword') + if focus_keyword: + keywords_list.append({ + 'term': str(focus_keyword), + 'isMain': True + }) + + # Add additional keywords from blog_tags or other sources + blog_tags = seo_metadata.get('blog_tags', []) + if isinstance(blog_tags, list): + for tag in blog_tags: + tag_str = str(tag).strip() + if tag_str and tag_str != focus_keyword: # Don't duplicate main keyword + keywords_list.append({ + 'term': tag_str, + 'isMain': False + }) + + # Add social hashtags as keywords if available + social_hashtags = seo_metadata.get('social_hashtags', []) + if isinstance(social_hashtags, list): + for hashtag in social_hashtags: + # Remove # if present + hashtag_str = str(hashtag).strip().lstrip('#') + if hashtag_str and hashtag_str != focus_keyword: + keywords_list.append({ + 'term': hashtag_str, + 'isMain': False + }) + + seo_data['settings']['keywords'] = keywords_list + + # Validate keywords list is not empty (or ensure at least one keyword exists) + if not seo_data['settings']['keywords']: + logger.warning("No keywords found in SEO metadata, adding empty keywords array") + + # Build tags array (meta tags, Open Graph, etc.) + tags_list = [] + + # Meta description + meta_description = seo_metadata.get('meta_description') + if meta_description: + tags_list.append({ + 'type': 'meta', + 'props': { + 'name': 'description', + 'content': str(meta_description) + }, + 'custom': True, + 'disabled': False + }) + + # SEO title - 'title' type uses 'children' field, not 'props.content' + seo_title = seo_metadata.get('seo_title') or default_title + if seo_title: + tags_list.append({ + 'type': 'title', + 'children': str(seo_title), # Title tags use 'children', not 'props.content' + 'custom': True, + 'disabled': False + }) + + # Open Graph tags + open_graph = seo_metadata.get('open_graph', {}) + if isinstance(open_graph, dict): + # OG Title + og_title = open_graph.get('title') or seo_title + if og_title: + tags_list.append({ + 'type': 'meta', + 'props': { + 'property': 'og:title', + 'content': str(og_title) + }, + 'custom': True, + 'disabled': False + }) + + # OG Description + og_description = open_graph.get('description') or meta_description + if og_description: + tags_list.append({ + 'type': 'meta', + 'props': { + 'property': 'og:description', + 'content': str(og_description) + }, + 'custom': True, + 'disabled': False + }) + + # OG Image + og_image = open_graph.get('image') + if og_image: + # Skip base64 images for OG tags (Wix needs URLs) + if isinstance(og_image, str) and (og_image.startswith('http://') or og_image.startswith('https://')): + tags_list.append({ + 'type': 'meta', + 'props': { + 'property': 'og:image', + 'content': og_image + }, + 'custom': True, + 'disabled': False + }) + + # OG Type + tags_list.append({ + 'type': 'meta', + 'props': { + 'property': 'og:type', + 'content': 'article' + }, + 'custom': True, + 'disabled': False + }) + + # OG URL (canonical or provided URL) + og_url = open_graph.get('url') or seo_metadata.get('canonical_url') + if og_url: + tags_list.append({ + 'type': 'meta', + 'props': { + 'property': 'og:url', + 'content': str(og_url) + }, + 'custom': True, + 'disabled': False + }) + + # Twitter Card tags + twitter_card = seo_metadata.get('twitter_card', {}) + if isinstance(twitter_card, dict): + twitter_title = twitter_card.get('title') or seo_title + if twitter_title: + tags_list.append({ + 'type': 'meta', + 'props': { + 'name': 'twitter:title', + 'content': str(twitter_title) + }, + 'custom': True, + 'disabled': False + }) + + twitter_description = twitter_card.get('description') or meta_description + if twitter_description: + tags_list.append({ + 'type': 'meta', + 'props': { + 'name': 'twitter:description', + 'content': str(twitter_description) + }, + 'custom': True, + 'disabled': False + }) + + twitter_image = twitter_card.get('image') + if twitter_image and isinstance(twitter_image, str) and (twitter_image.startswith('http://') or twitter_image.startswith('https://')): + tags_list.append({ + 'type': 'meta', + 'props': { + 'name': 'twitter:image', + 'content': twitter_image + }, + 'custom': True, + 'disabled': False + }) + + twitter_card_type = twitter_card.get('card', 'summary_large_image') + tags_list.append({ + 'type': 'meta', + 'props': { + 'name': 'twitter:card', + 'content': str(twitter_card_type) + }, + 'custom': True, + 'disabled': False + }) + + # Canonical URL as link tag + canonical_url = seo_metadata.get('canonical_url') + if canonical_url: + tags_list.append({ + 'type': 'link', + 'props': { + 'rel': 'canonical', + 'href': str(canonical_url) + }, + 'custom': True, + 'disabled': False + }) + + # Validate all tags have required fields before adding + validated_tags = [] + for tag in tags_list: + if not isinstance(tag, dict): + logger.warning(f"Skipping invalid tag (not a dict): {type(tag)}") + continue + # Ensure required fields exist + if 'type' not in tag: + logger.warning("Skipping tag missing 'type' field") + continue + # Ensure 'custom' and 'disabled' fields exist + if 'custom' not in tag: + tag['custom'] = True + if 'disabled' not in tag: + tag['disabled'] = False + # Validate tag structure based on type + tag_type = tag.get('type') + if tag_type == 'title': + if 'children' not in tag or not tag['children']: + logger.warning("Skipping title tag with missing/invalid 'children' field") + continue + elif tag_type == 'meta': + if 'props' not in tag or not isinstance(tag['props'], dict): + logger.warning("Skipping meta tag with missing/invalid 'props' field") + continue + if 'name' not in tag['props'] and 'property' not in tag['props']: + logger.warning("Skipping meta tag with missing 'name' or 'property' in props") + continue + # Ensure 'content' exists and is not empty + if 'content' not in tag['props'] or not str(tag['props'].get('content', '')).strip(): + logger.warning(f"Skipping meta tag with missing/empty 'content': {tag.get('props', {})}") + continue + elif tag_type == 'link': + if 'props' not in tag or not isinstance(tag['props'], dict): + logger.warning("Skipping link tag with missing/invalid 'props' field") + continue + # Ensure 'href' exists and is not empty for link tags + if 'href' not in tag['props'] or not str(tag['props'].get('href', '')).strip(): + logger.warning(f"Skipping link tag with missing/empty 'href': {tag.get('props', {})}") + continue + validated_tags.append(tag) + + seo_data['tags'] = validated_tags + + # Final validation: ensure seoData structure is complete + if not isinstance(seo_data['settings'], dict): + logger.error("seoData.settings is not a dict, creating default") + seo_data['settings'] = {'keywords': []} + if not isinstance(seo_data['settings'].get('keywords'), list): + logger.error("seoData.settings.keywords is not a list, creating empty list") + seo_data['settings']['keywords'] = [] + if not isinstance(seo_data['tags'], list): + logger.error("seoData.tags is not a list, creating empty list") + seo_data['tags'] = [] + + # CRITICAL: Per Wix API patterns, omit empty structures instead of including them as {} + # If keywords is empty, omit settings entirely + if not seo_data['settings'].get('keywords'): + logger.debug("No keywords found, omitting settings from seoData") + seo_data.pop('settings', None) + + logger.debug(f"Built SEO data: {len(validated_tags)} tags, {len(keywords_list)} keywords") + + # Only return seoData if we have at least keywords or tags + if keywords_list or validated_tags: + return seo_data + + return None + diff --git a/backend/services/llm_providers/main_text_generation.py b/backend/services/llm_providers/main_text_generation.py index e7fccae6..01333265 100644 --- a/backend/services/llm_providers/main_text_generation.py +++ b/backend/services/llm_providers/main_text_generation.py @@ -9,6 +9,7 @@ import json from typing import Optional, Dict, Any from datetime import datetime from loguru import logger +from fastapi import HTTPException from ..onboarding.api_key_manager import APIKeyManager from .gemini_provider import gemini_text_response, gemini_structured_json_response @@ -129,11 +130,16 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: pricing_service = PricingService(db) # Estimate tokens from prompt (input tokens) - # Note: We estimate output tokens conservatively (assume response is similar length to prompt) - # This prevents underestimating total token usage + # CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse + # This ensures we block requests that would exceed limits even if response is longer than expected input_tokens = int(len(prompt.split()) * 1.3) - # Conservative estimate: assume output tokens ā input tokens * 1.0 (can be up to max_tokens) - estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8) + # Worst-case estimate: assume maximum possible output tokens (max_tokens if specified) + # This prevents abuse where actual response tokens exceed the estimate + if max_tokens: + estimated_output_tokens = max_tokens # Use maximum allowed output tokens + else: + # If max_tokens not specified, use conservative estimate (input * 1.5) + estimated_output_tokens = int(input_tokens * 1.5) estimated_total_tokens = input_tokens + estimated_output_tokens # Check limits using sync method from pricing service (strict enforcement) @@ -146,7 +152,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: if not can_proceed: logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}") - raise RuntimeError(f"Subscription limit exceeded: {message}") + # Raise HTTPException(429) with usage info so frontend can display subscription modal + error_detail = { + 'error': message, + 'message': message, + 'provider': actual_provider_name or provider_enum.value, + 'usage_info': usage_info if usage_info else {} + } + raise HTTPException(status_code=429, detail=error_detail) # Get current usage for limit checking only current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") @@ -159,6 +172,9 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: finally: db.close() + except HTTPException: + # Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details + raise except RuntimeError: # Re-raise subscription limit errors raise @@ -244,7 +260,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: db_track = next(get_db()) try: # Estimate tokens from prompt and response - tokens_input = estimated_tokens # Already calculated above + # Recalculate input tokens from prompt (consistent with pre-flight estimation) + tokens_input = int(len(prompt.split()) * 1.3) tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens tokens_total = tokens_input + tokens_output @@ -259,45 +276,186 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}") + # Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate) + provider_name = provider_enum.value + limits = pricing.get_user_limits(user_id) + token_limit = 0 + if limits and limits.get('limits'): + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0 + + # CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache + # This ensures we always get the absolute latest committed values, even across different sessions + from sqlalchemy import text + current_calls_before = 0 + current_tokens_before = 0 + record_count = 0 # Initialize to ensure it's always defined + + # CRITICAL: First check if record exists using COUNT query + try: + check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period") + record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar() + logger.debug(f"[llm_text_gen] š DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}") + except Exception as count_error: + logger.error(f"[llm_text_gen] ā COUNT query failed: {count_error}", exc_info=True) + record_count = 0 + + if record_count and record_count > 0: + # Record exists - read current values with raw SQL + try: + # Validate provider_name to prevent SQL injection (whitelist approach) + valid_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + if provider_name not in valid_providers: + raise ValueError(f"Invalid provider_name for SQL query: {provider_name}") + + # Read current values directly from database using raw SQL + # CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values + sql_query = text(f""" + SELECT {provider_name}_calls, {provider_name}_tokens + FROM usage_summaries + WHERE user_id = :user_id AND billing_period = :period + LIMIT 1 + """) + logger.debug(f"[llm_text_gen] š Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}") + result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first() + if result: + raw_calls = result[0] if result[0] is not None else 0 + raw_tokens = result[1] if result[1] is not None else 0 + current_calls_before = raw_calls + current_tokens_before = raw_tokens + logger.debug(f"[llm_text_gen] ā Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)") + logger.debug(f"[llm_text_gen] š Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}") + else: + logger.error(f"[llm_text_gen] ā CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}") + # Fallback: Use ORM to get values + summary_fallback = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + if summary_fallback: + db_track.refresh(summary_fallback) + current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0 + current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0 + logger.warning(f"[llm_text_gen] ā ļø Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}") + except Exception as sql_error: + logger.error(f"[llm_text_gen] ā Raw SQL query failed: {sql_error}", exc_info=True) + # Fallback: Use ORM to get values + summary_fallback = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + if summary_fallback: + db_track.refresh(summary_fallback) + current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0 + current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0 + else: + logger.debug(f"[llm_text_gen] ā¹ļø No record exists yet (will create new) - user={user_id}, period={current_period}") + + # Get or create usage summary object (needed for ORM update) summary = db_track.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period == current_period ).first() if not summary: - logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}") + logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}") summary = UsageSummary( user_id=user_id, billing_period=current_period ) db_track.add(summary) db_track.flush() # Ensure summary is persisted before updating + # New record - values are already 0, no need to set + logger.debug(f"[llm_text_gen] ā New summary created - starting from 0") + else: + # CRITICAL: Update the ORM object with values from raw SQL query + # This ensures the ORM object reflects the actual database state before we increment + logger.debug(f"[llm_text_gen] š Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}") + setattr(summary, f"{provider_name}_calls", current_calls_before) + if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + setattr(summary, f"{provider_name}_tokens", current_tokens_before) + logger.debug(f"[llm_text_gen] ā Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}") - # Get "before" state for unified log - provider_name = provider_enum.value - current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0 + logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}") # Update provider-specific counters (sync operation) new_calls = current_calls_before + 1 - setattr(summary, f"{provider_name}_calls", new_calls) - logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}") - # Update token usage for LLM providers + # CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes + # SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes + # Using raw SQL UPDATE ensures the change is persisted + from sqlalchemy import text + update_calls_query = text(f""" + UPDATE usage_summaries + SET {provider_name}_calls = :new_calls + WHERE user_id = :user_id AND billing_period = :period + """) + db_track.execute(update_calls_query, { + 'new_calls': new_calls, + 'user_id': user_id, + 'period': current_period + }) + logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}") + + # Update token usage for LLM providers with safety check + # CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object) + # The ORM object may have stale values, but raw SQL always has the latest committed values if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: - current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 - new_tokens = current_tokens_before + tokens_total - setattr(summary, f"{provider_name}_tokens", new_tokens) - logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}") + logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}") + + # SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate + # This prevents abuse where actual response tokens exceed pre-flight validation estimate + projected_new_tokens = current_tokens_before + tokens_total + + # If limit is set (> 0) and would be exceeded, cap at limit + if token_limit > 0 and projected_new_tokens > token_limit: + logger.warning( + f"[llm_text_gen] ā ļø ACTUAL token usage ({tokens_total}) exceeded estimate. " + f"Would exceed limit: {projected_new_tokens} > {token_limit}. " + f"Capping tracked tokens at limit to prevent abuse." + ) + # Cap at limit to prevent abuse + new_tokens = token_limit + # Adjust tokens_total for accurate total tracking + tokens_total = token_limit - current_tokens_before + if tokens_total < 0: + tokens_total = 0 + else: + new_tokens = projected_new_tokens + + # CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes + update_tokens_query = text(f""" + UPDATE usage_summaries + SET {provider_name}_tokens = :new_tokens + WHERE user_id = :user_id AND billing_period = :period + """) + db_track.execute(update_tokens_query, { + 'new_tokens': new_tokens, + 'user_id': user_id, + 'period': current_period + }) + logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}") else: current_tokens_before = 0 new_tokens = 0 - # Update totals + # Update totals using SQL UPDATE old_total_calls = summary.total_calls or 0 old_total_tokens = summary.total_tokens or 0 - summary.total_calls = old_total_calls + 1 - summary.total_tokens = old_total_tokens + tokens_total - logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}") + new_total_calls = old_total_calls + 1 + new_total_tokens = old_total_tokens + tokens_total + + update_totals_query = text(""" + UPDATE usage_summaries + SET total_calls = :total_calls, total_tokens = :total_tokens + WHERE user_id = :user_id AND billing_period = :period + """) + db_track.execute(update_totals_query, { + 'total_calls': new_total_calls, + 'total_tokens': new_total_tokens, + 'user_id': user_id, + 'period': current_period + }) + logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}") # Get plan details for unified log limits = pricing.get_user_limits(user_id) @@ -310,12 +468,52 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: current_images_before = getattr(summary, "stability_calls", 0) or 0 image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 - db_track.commit() - logger.info(f"[llm_text_gen] ā Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") + # CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately) + import sys + debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}" + print(debug_msg, flush=True) + sys.stdout.flush() + logger.debug(f"[llm_text_gen] {debug_msg}") + + # CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions + db_track.flush() # Flush to ensure changes are in DB (not just in transaction) + db_track.commit() # Commit transaction to make changes visible to other sessions + logger.debug(f"[llm_text_gen] ā Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)") + logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)") + + # CRITICAL: Verify commit worked by reading back from DB immediately after commit + try: + verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1") + verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first() + if verify_result: + verified_calls = verify_result[0] if verify_result[0] is not None else 0 + verified_tokens = verify_result[1] if verify_result[1] is not None else 0 + logger.debug(f"[llm_text_gen] ā VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})") + if verified_calls != new_calls or verified_tokens != new_tokens: + logger.error(f"[llm_text_gen] ā CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}") + # Force another commit attempt + db_track.commit() + verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first() + if verify_result2: + verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0 + verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0 + logger.debug(f"[llm_text_gen] š After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}") + else: + logger.debug(f"[llm_text_gen] ā COMMIT VERIFICATION PASSED: Values match expected values") + else: + logger.error(f"[llm_text_gen] ā CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!") + except Exception as verify_error: + logger.error(f"[llm_text_gen] ā Error verifying commit: {verify_error}", exc_info=True) # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") # Include image stats in the log + # DEBUG: Log the actual values being used + logger.debug(f"[llm_text_gen] š FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}") + + # CRITICAL DEBUG: Print diagnostic info to stdout (always visible) + print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}") + print(f""" [SUBSCRIPTION] LLM Text Generation āā User: {user_id} @@ -407,7 +605,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: db_track = next(get_db()) try: # Estimate tokens from prompt and response - tokens_input = estimated_tokens + # Recalculate input tokens from prompt (consistent with pre-flight estimation) + tokens_input = int(len(prompt.split()) * 1.3) tokens_output = int(len(str(response_text).split()) * 1.3) tokens_total = tokens_input + tokens_output @@ -418,6 +617,49 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: pricing = PricingService(db_track) current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + # Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate) + provider_name = provider_enum.value + limits = pricing.get_user_limits(user_id) + token_limit = 0 + if limits and limits.get('limits'): + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0 + + # CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache + from sqlalchemy import text + current_calls_before = 0 + current_tokens_before = 0 + + try: + # Validate provider_name to prevent SQL injection + valid_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + if provider_name not in valid_providers: + raise ValueError(f"Invalid provider_name for SQL query: {provider_name}") + + # Read current values directly from database using raw SQL + sql_query = text(f""" + SELECT {provider_name}_calls, {provider_name}_tokens + FROM usage_summaries + WHERE user_id = :user_id AND billing_period = :period + LIMIT 1 + """) + result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first() + if result: + current_calls_before = result[0] if result[0] is not None else 0 + current_tokens_before = result[1] if result[1] is not None else 0 + logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}") + except Exception as sql_error: + logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}") + # Fallback to ORM query if raw SQL fails + summary = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + if summary: + db_track.refresh(summary) + current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0 + current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 + + # Get or create usage summary object (needed for ORM update) summary = db_track.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period == current_period @@ -430,41 +672,68 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: ) db_track.add(summary) db_track.flush() # Ensure summary is persisted before updating + else: + # CRITICAL: Update the ORM object with values from raw SQL query + # This ensures the ORM object reflects the actual database state before we increment + setattr(summary, f"{provider_name}_calls", current_calls_before) + if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + setattr(summary, f"{provider_name}_tokens", current_tokens_before) + logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}") - # Get "before" state for unified log - provider_name = provider_enum.value - current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0 + # Get "before" state for unified log (from raw SQL query) + logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}") # Update provider-specific counters (sync operation) new_calls = current_calls_before + 1 setattr(summary, f"{provider_name}_calls", new_calls) - # Update token usage for LLM providers + # Update token usage for LLM providers with safety check + # Use current_tokens_before from raw SQL query (most reliable) if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: - current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 - new_tokens = current_tokens_before + tokens_total + logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}") + + # SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate + # This prevents abuse where actual response tokens exceed pre-flight validation estimate + projected_new_tokens = current_tokens_before + tokens_total + + # If limit is set (> 0) and would be exceeded, cap at limit + if token_limit > 0 and projected_new_tokens > token_limit: + logger.warning( + f"[llm_text_gen] ā ļø ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. " + f"Would exceed limit: {projected_new_tokens} > {token_limit}. " + f"Capping tracked tokens at limit to prevent abuse." + ) + # Cap at limit to prevent abuse + new_tokens = token_limit + # Adjust tokens_total for accurate total tracking + tokens_total = token_limit - current_tokens_before + if tokens_total < 0: + tokens_total = 0 + else: + new_tokens = projected_new_tokens + setattr(summary, f"{provider_name}_tokens", new_tokens) else: current_tokens_before = 0 new_tokens = 0 - # Update totals + # Update totals (using potentially capped tokens_total from safety check) summary.total_calls = (summary.total_calls or 0) + 1 summary.total_tokens = (summary.total_tokens or 0) + tokens_total - # Get plan details for unified log - limits = pricing.get_user_limits(user_id) + # Get plan details for unified log (limits already retrieved above) plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' tier = limits.get('tier', 'unknown') if limits else 'unknown' call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0 - token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0 # Get image stats for unified log current_images_before = getattr(summary, "stability_calls", 0) or 0 image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 - db_track.commit() - logger.info(f"[llm_text_gen] ā Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") + # CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions + db_track.flush() # Flush to ensure changes are in DB (not just in transaction) + db_track.commit() # Commit transaction to make changes visible to other sessions + logger.info(f"[llm_text_gen] ā Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)") # UNIFIED SUBSCRIPTION LOG for fallback # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") diff --git a/backend/services/subscription/limit_validation.py b/backend/services/subscription/limit_validation.py new file mode 100644 index 00000000..da4c678d --- /dev/null +++ b/backend/services/subscription/limit_validation.py @@ -0,0 +1,725 @@ +""" +Limit Validation Module +Handles subscription limit checking and validation logic. +Extracted from pricing_service.py for better modularity. +""" + +from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING +from datetime import datetime, timedelta +from sqlalchemy import text +from loguru import logger + +from models.subscription_models import ( + UserSubscription, UsageSummary, SubscriptionPlan, + APIProvider, SubscriptionTier +) + +if TYPE_CHECKING: + from .pricing_service import PricingService + + +class LimitValidator: + """Validates subscription limits for API usage.""" + + def __init__(self, pricing_service: 'PricingService'): + """ + Initialize limit validator with reference to PricingService. + + Args: + pricing_service: Instance of PricingService to access helper methods and cache + """ + self.pricing_service = pricing_service + self.db = pricing_service.db + + def check_usage_limits(self, user_id: str, provider: APIProvider, + tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]: + """Check if user can make an API call within their limits. + + Args: + user_id: User ID + provider: APIProvider enum (may be MISTRAL for HuggingFace) + tokens_requested: Estimated tokens for the request + actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL) + + Returns: + (can_proceed, error_message, usage_info) + """ + try: + # Use actual_provider_name if provided, otherwise use enum value + # This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors + display_provider_name = actual_provider_name or provider.value + + logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}") + + # Short TTL cache to reduce DB reads under sustained traffic + cache_key = f"{user_id}:{provider.value}" + now = datetime.utcnow() + cached = self.pricing_service._limits_cache.get(cache_key) + if cached and cached.get('expires_at') and cached['expires_at'] > now: + logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}") + return tuple(cached['result']) # type: ignore + + # Get user subscription first to check expiration + subscription = self.db.query(UserSubscription).filter( + UserSubscription.user_id == user_id, + UserSubscription.is_active == True + ).first() + + if subscription: + logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}") + else: + logger.debug(f"[Subscription Check] No active subscription found for user {user_id}") + + # Check subscription expiration (STRICT: deny if expired) + if subscription: + if subscription.current_period_end < now: + logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}") + # Subscription expired - check if auto_renew is enabled + if not getattr(subscription, 'auto_renew', False): + # Expired and no auto-renew - deny access + logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access") + result = (False, "Subscription expired. Please renew your subscription to continue using the service.", { + 'expired': True, + 'period_end': subscription.current_period_end.isoformat() + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + # Try to auto-renew + if not self.pricing_service._ensure_subscription_current(subscription): + # Auto-renew failed - deny access + result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", { + 'expired': True, + 'auto_renew_failed': True + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + + # Get user limits with error handling (STRICT: fail on errors) + # CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal + try: + # Force expire subscription and plan objects to avoid stale cache + if subscription and subscription.plan_id: + plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first() + if plan_obj: + self.db.expire(plan_obj) + logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal") + + limits = self.pricing_service.get_user_limits(user_id) + if limits: + logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}") + # Log token limits for debugging + token_limits = limits.get('limits', {}) + logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}") + else: + logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier") + except Exception as e: + logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True) + # STRICT: Fail closed - deny request if we can't check limits + return False, f"Failed to retrieve subscription limits: {str(e)}", {} + + if not limits: + # No subscription found - check for free tier + free_plan = self.db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.FREE, + SubscriptionPlan.is_active == True + ).first() + if free_plan: + logger.info(f"[Subscription Check] Assigning free tier to user {user_id}") + limits = self.pricing_service._plan_to_limits_dict(free_plan) + else: + # No subscription and no free tier - deny access + logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access") + return False, "No subscription plan found. Please subscribe to a plan.", {} + + # Get current usage for this billing period with error handling + # CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal + try: + current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + + # Expire all objects to force fresh read from DB (critical after renewal) + self.db.expire_all() + + # Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails + usage = None + try: + from sqlalchemy import text + sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1") + result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first() + if result: + # Map result to UsageSummary object + usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + if usage: + self.db.refresh(usage) # Ensure fresh data + except Exception as sql_error: + logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}") + # Fallback to ORM query + usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + if usage: + self.db.refresh(usage) # Ensure fresh data + + if not usage: + # First usage this period, create summary + try: + usage = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + self.db.add(usage) + self.db.commit() + except Exception as create_error: + logger.error(f"Error creating usage summary: {create_error}") + self.db.rollback() + # STRICT: Fail closed on DB error + return False, f"Failed to create usage summary: {str(create_error)}", {} + except Exception as e: + logger.error(f"Error getting usage summary for {user_id}: {e}") + self.db.rollback() + # STRICT: Fail closed on DB error + return False, f"Failed to retrieve usage summary: {str(e)}", {} + + # Check call limits with error handling + # NOTE: call_limit = 0 means UNLIMITED (Enterprise plans) + try: + # Use display_provider_name for error messages, but provider.value for DB queries + provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens") + + # For LLM text generation providers, check against unified total_calls limit + llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + is_llm_provider = provider_name in llm_providers + + if is_llm_provider: + # Use unified AI text generation limit (total_calls across all LLM providers) + ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0 + + # If unified limit not set, fall back to provider-specific limit for backwards compatibility + if ai_text_gen_limit == 0: + ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 + + # Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral) + current_total_llm_calls = ( + (usage.gemini_calls or 0) + + (usage.openai_calls or 0) + + (usage.anthropic_calls or 0) + + (usage.mistral_calls or 0) + ) + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit: + logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})") + result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", { + 'current_calls': current_total_llm_calls, + 'limit': ai_text_gen_limit, + 'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0, + 'provider': display_provider_name, # Use display name for consistency + 'usage_info': { + 'provider': display_provider_name, # Use display name for user-facing info + 'current_calls': current_total_llm_calls, + 'limit': ai_text_gen_limit, + 'type': 'ai_text_generation', + 'breakdown': { + 'gemini': usage.gemini_calls or 0, + 'openai': usage.openai_calls or 0, + 'anthropic': usage.anthropic_calls or 0, + 'mistral': usage.mistral_calls or 0 # DB field name (not display name) + } + } + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})") + else: + # For non-LLM providers, check provider-specific limit + current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0 + call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if call_limit > 0 and current_calls >= call_limit: + logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}") + result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", { + 'current_calls': current_calls, + 'limit': call_limit, + 'usage_percentage': 100.0, + 'provider': display_provider_name # Use display name for consistency + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}") + except Exception as e: + logger.error(f"Error checking call limits: {e}") + # Continue to next check + + # Check token limits for LLM providers with error handling + # NOTE: token_limit = 0 means UNLIMITED (Enterprise plans) + try: + if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0 + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0 + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if token_limit > 0 and (current_tokens + tokens_requested) > token_limit: + result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", { + 'current_tokens': current_tokens, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100, + 'provider': display_provider_name, # Use display name in error details + 'usage_info': { + 'provider': display_provider_name, + 'current_tokens': current_tokens, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'type': 'tokens' + } + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + except Exception as e: + logger.error(f"Error checking token limits: {e}") + # Continue to next check + + # Check cost limits with error handling + # NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans) + try: + cost_limit = limits['limits'].get('monthly_cost', 0) or 0 + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if cost_limit > 0 and usage.total_cost >= cost_limit: + result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", { + 'current_cost': usage.total_cost, + 'limit': cost_limit, + 'usage_percentage': 100.0 + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + except Exception as e: + logger.error(f"Error checking cost limits: {e}") + # Continue to success case + + # Calculate usage percentages for warnings + try: + # Determine which call variables to use based on provider type + if is_llm_provider: + # Use unified LLM call tracking + current_call_count = current_total_llm_calls + call_limit_value = ai_text_gen_limit + else: + # Use provider-specific call tracking + current_call_count = current_calls + call_limit_value = call_limit + + call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0 + cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0 + result = (True, "Within limits", { + 'current_calls': current_call_count, + 'call_limit': call_limit_value, + 'call_usage_percentage': call_usage_pct, + 'current_cost': usage.total_cost, + 'cost_limit': cost_limit, + 'cost_usage_percentage': cost_usage_pct + }) + self.pricing_service._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + except Exception as e: + logger.error(f"Error calculating usage percentages: {e}") + # Return basic success + return True, "Within limits", {} + + except Exception as e: + logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}") + # STRICT: Fail closed - deny requests if subscription system fails + return False, f"Subscription check error: {str(e)}", {} + + def check_comprehensive_limits( + self, + user_id: str, + operations: List[Dict[str, Any]] + ) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]: + """ + Comprehensive pre-flight validation that checks ALL limits before making ANY API calls. + + This prevents wasteful API calls by validating that ALL subsequent operations will succeed + before making the first external API call. + + Args: + user_id: User ID + operations: List of operations to validate, each with: + - 'provider': APIProvider enum + - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM) + - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL) + - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation") + + Returns: + (can_proceed, error_message, error_details) + If can_proceed is False, error_message explains which limit would be exceeded + """ + try: + logger.info(f"[Pre-flight Check] š Starting comprehensive validation for user {user_id}") + logger.info(f"[Pre-flight Check] š Validating {len(operations)} operation(s) before making any API calls") + + # Get current usage and limits once + current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + + logger.info(f"[Pre-flight Check] š Billing Period: {current_period} (for user {user_id})") + + # Explicitly expire any cached objects and refresh from DB to ensure fresh data + self.db.expire_all() + + usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + # CRITICAL: Explicitly refresh from database to get latest values (clears SQLAlchemy cache) + if usage: + self.db.refresh(usage) + + # Log what we actually read from database + if usage: + logger.info(f"[Pre-flight Check] š Usage Summary from DB (Period: {current_period}):") + logger.info(f" āā Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls") + logger.info(f" āā Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls") + logger.info(f" āā Total Tokens: {usage.total_tokens or 0}") + logger.info(f" āā Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}") + else: + logger.info(f"[Pre-flight Check] š No usage summary found for period {current_period} (will create new)") + + if not usage: + # First usage this period, create summary + try: + usage = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + self.db.add(usage) + self.db.commit() + except Exception as create_error: + logger.error(f"Error creating usage summary: {create_error}") + self.db.rollback() + return False, f"Failed to create usage summary: {str(create_error)}", {} + + # Get user limits + limits_dict = self.pricing_service.get_user_limits(user_id) + if not limits_dict: + # No subscription found - check for free tier + free_plan = self.db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.FREE, + SubscriptionPlan.is_active == True + ).first() + if free_plan: + limits_dict = self.pricing_service._plan_to_limits_dict(free_plan) + else: + return False, "No subscription plan found. Please subscribe to a plan.", {} + + limits = limits_dict.get('limits', {}) + + # Track cumulative usage across all operations + total_llm_calls = ( + (usage.gemini_calls or 0) + + (usage.openai_calls or 0) + + (usage.anthropic_calls or 0) + + (usage.mistral_calls or 0) + ) + total_llm_tokens = {} + total_images = usage.stability_calls or 0 + + # Log current usage summary + logger.info(f"[Pre-flight Check] š Current Usage Summary:") + logger.info(f" āā Total LLM Calls: {total_llm_calls}") + logger.info(f" āā Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}") + logger.info(f" āā Image Calls: {total_images}") + + # Validate each operation + for op_idx, operation in enumerate(operations): + provider = operation.get('provider') + provider_name = provider.value if hasattr(provider, 'value') else str(provider) + tokens_requested = operation.get('tokens_requested', 0) + actual_provider_name = operation.get('actual_provider_name') + operation_type = operation.get('operation_type', 'unknown') + + display_provider_name = actual_provider_name or provider_name + + logger.error(f"[Pre-flight Check] ā Operation {op_idx + 1}/{len(operations)}: {operation_type}") + logger.error(f" āā Provider: {display_provider_name} (enum: {provider_name})") + logger.error(f" āā Operation Index: {op_idx}") + logger.error(f" āā Estimated Tokens Requested: {tokens_requested}") + + # Check if this is an LLM provider + llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + is_llm_provider = provider_name in llm_providers + + # Check unified AI text generation limit for LLM providers + if is_llm_provider: + ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0 + if ai_text_gen_limit == 0: + # Fallback to provider-specific limit + ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0 + + # Count this operation as an LLM call + projected_total_llm_calls = total_llm_calls + 1 + + if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit: + error_info = { + 'current_calls': total_llm_calls, + 'limit': ai_text_gen_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", { + 'error_type': 'call_limit', + 'usage_info': error_info + } + + # Check token limits for this provider + # CRITICAL: Always query fresh from DB for each operation to avoid SQLAlchemy cache issues + # This ensures we get the latest values after subscription renewal, even for cumulative tracking + provider_tokens_key = f"{provider_name}_tokens" + + # Try to get fresh value from DB with comprehensive error handling + base_current_tokens = 0 + query_succeeded = False + + try: + # Validate column name is safe (only allow known provider token columns) + valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens'] + + if provider_tokens_key not in valid_token_columns: + logger.error(f" āā Invalid provider tokens key: {provider_tokens_key}") + query_succeeded = True # Treat as success with 0 value + else: + # Method 1: Try raw SQL query to completely bypass ORM cache + try: + logger.debug(f" āā Attempting raw SQL query for {provider_tokens_key}") + sql_query = text(f""" + SELECT {provider_tokens_key} + FROM usage_summaries + WHERE user_id = :user_id + AND billing_period = :period + LIMIT 1 + """) + + logger.debug(f" āā SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}") + + result = self.db.execute(sql_query, { + 'user_id': user_id, + 'period': current_period + }).first() + + if result: + base_current_tokens = result[0] if result[0] is not None else 0 + logger.error(f"[Pre-flight Check] ā Raw SQL query returned result: {result[0]} -> {base_current_tokens}") + else: + base_current_tokens = 0 + logger.error(f"[Pre-flight Check] ā ļø Raw SQL query returned None (no rows found)") + + query_succeeded = True + logger.error(f"[Pre-flight Check] ā Raw SQL query succeeded for {provider_tokens_key}: {base_current_tokens}") + + except Exception as sql_error: + logger.error(f" āā Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True) + query_succeeded = False # Will try ORM fallback + + # Method 2: Fallback to fresh ORM query if raw SQL fails + if not query_succeeded: + try: + # Expire all cached objects and do fresh query + self.db.expire_all() + fresh_usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if fresh_usage: + # Explicitly refresh to get latest from DB + self.db.refresh(fresh_usage) + base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0 + else: + base_current_tokens = 0 + + query_succeeded = True + logger.info(f"[Pre-flight Check] ā ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}") + + except Exception as orm_error: + logger.error(f" āā ORM query also failed: {orm_error}", exc_info=True) + query_succeeded = False + + except Exception as e: + logger.error(f" āā Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True) + base_current_tokens = 0 # Fail safe - assume 0 if we can't query + + if not query_succeeded: + logger.warning(f" āā Both query methods failed, using 0 as fallback") + + # CRITICAL LOG: Always log what we got from DB - this helps debug renewal issues + # Use ERROR level to ensure it shows even if INFO is filtered + logger.error(f"[Pre-flight Check] š Fresh DB Query for {display_provider_name}:") + logger.error(f" āā Column: {provider_tokens_key}") + logger.error(f" āā Billing Period: {current_period}") + logger.error(f" āā User ID: {user_id}") + logger.error(f" āā Method: {'Raw SQL' if query_succeeded and base_current_tokens >= 0 else 'ORM' if query_succeeded else 'Failed - using 0'}") + logger.error(f" āā Value from DB: {base_current_tokens}") + + # Add any projected tokens from previous operations in this validation run + # Note: total_llm_tokens tracks ONLY projected tokens from this run, not base DB value + projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0) + + # Current tokens = base from DB + projected from previous operations in this run + current_provider_tokens = base_current_tokens + projected_from_previous + + # Use ERROR level to ensure visibility + logger.error(f"[Pre-flight Check] š Token Calculation for {display_provider_name}:") + logger.error(f" āā Base from DB (fresh query): {base_current_tokens}") + logger.error(f" āā Projected from previous ops in this run: {projected_from_previous}") + logger.error(f" āā Total current tokens (base + projected): {current_provider_tokens}") + + # Also check the initial usage object to see if it's being used incorrectly + if usage and hasattr(usage, provider_tokens_key): + initial_usage_value = getattr(usage, provider_tokens_key, 0) or 0 + logger.error(f" ā ļø Initial usage object value: {initial_usage_value} (this should NOT be used for fresh query)") + + token_limit = limits.get(provider_tokens_key, 0) or 0 + + if token_limit > 0 and tokens_requested > 0: + projected_tokens = current_provider_tokens + tokens_requested + logger.info(f" āā Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)") + + if projected_tokens > token_limit: + usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0 + error_info = { + 'current_tokens': current_provider_tokens, + 'base_tokens_from_db': base_current_tokens, + 'projected_from_previous_ops': projected_from_previous, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + # Make error message clearer: show actual DB usage vs projected + if projected_from_previous > 0: + error_msg = ( + f"Token limit exceeded for {display_provider_name} " + f"({operation_type}). " + f"Base usage: {base_current_tokens}/{token_limit}, " + f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, " + f"This operation would add: {tokens_requested}, " + f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)" + ) + else: + error_msg = ( + f"Token limit exceeded for {display_provider_name} " + f"({operation_type}). " + f"Current: {current_provider_tokens}/{token_limit}, " + f"Requested: {tokens_requested}, " + f"Would exceed by: {projected_tokens - token_limit} tokens " + f"({usage_percentage:.1f}% of limit)" + ) + logger.error(f"[Pre-flight Check] ā BLOCKED: {error_msg}") + return False, error_msg, { + 'error_type': 'token_limit', + 'usage_info': error_info + } + else: + logger.info(f" āā ā Token limit check passed: {projected_tokens} <= {token_limit}") + + # Update cumulative counts for next operation + total_llm_calls = projected_total_llm_calls + # Update cumulative projected tokens from this validation run + # This represents only projected tokens from previous operations in this run + # Base DB value is always queried fresh, so we only track the projection delta + old_projected = total_llm_tokens.get(provider_tokens_key, 0) + if tokens_requested > 0: + # Add this operation's tokens to cumulative projected tokens + total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested + logger.error(f"[Pre-flight Check] š Updated cumulative projected tokens for {display_provider_name}:") + logger.error(f" āā Previous projected: {projected_from_previous}") + logger.error(f" āā This operation requested: {tokens_requested}") + logger.error(f" āā New cumulative projected: {total_llm_tokens[provider_tokens_key]}") + logger.error(f" āā Old value in dict was: {old_projected}") + else: + # No tokens requested, keep existing projected tokens (or 0 if first operation) + total_llm_tokens[provider_tokens_key] = projected_from_previous + logger.error(f"[Pre-flight Check] š No tokens requested, keeping projected at: {projected_from_previous}") + + # Check image generation limits + elif provider == APIProvider.STABILITY: + image_limit = limits.get('stability_calls', 0) or 0 + projected_images = total_images + 1 + + if image_limit > 0 and projected_images > image_limit: + error_info = { + 'current_images': total_images, + 'limit': image_limit, + 'provider': 'stability', + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", { + 'error_type': 'image_limit', + 'usage_info': error_info + } + + total_images = projected_images + + # Check other provider-specific limits + else: + provider_calls_key = f"{provider_name}_calls" + current_provider_calls = getattr(usage, provider_calls_key, 0) or 0 + call_limit = limits.get(provider_calls_key, 0) or 0 + + if call_limit > 0: + projected_calls = current_provider_calls + 1 + if projected_calls > call_limit: + error_info = { + 'current_calls': current_provider_calls, + 'limit': call_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", { + 'error_type': 'call_limit', + 'usage_info': error_info + } + + # All checks passed + logger.info(f"[Pre-flight Check] ā All {len(operations)} operation(s) validated successfully") + logger.info(f"[Pre-flight Check] ā User {user_id} is cleared to proceed with API calls") + return True, None, None + + except Exception as e: + error_type = type(e).__name__ + error_message = str(e) + logger.error(f"[Pre-flight Check] ā Error during comprehensive limit check: {error_type}: {error_message}", exc_info=True) + logger.error(f"[Pre-flight Check] ā User: {user_id}, Operations count: {len(operations) if operations else 0}") + return False, f"Failed to validate limits: {error_type}: {error_message}", {} + diff --git a/backend/services/subscription/preflight_validator.py b/backend/services/subscription/preflight_validator.py index 0526cc3c..96f24f7d 100644 --- a/backend/services/subscription/preflight_validator.py +++ b/backend/services/subscription/preflight_validator.py @@ -44,15 +44,17 @@ def validate_research_operations( llm_provider_name = "gemini" # Estimate tokens for each operation in research workflow - # Google Grounding call: ~2000 tokens (input + output) + # Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results) # Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) # Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) # Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles) + # Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation + # to prevent wasteful API calls if the workflow would exceed limits. operations_to_validate = [ { 'provider': APIProvider.GEMINI, # Google Grounding uses Gemini - 'tokens_requested': 2000, + 'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate 'actual_provider_name': 'gemini', 'operation_type': 'google_grounding' }, @@ -126,6 +128,120 @@ def validate_research_operations( ) +def validate_exa_research_operations( + pricing_service: PricingService, + user_id: str, + gpt_provider: str = "google" +) -> None: + """ + Validate all operations for an Exa research workflow before making ANY API calls. + + This prevents wasteful external API calls (like Exa search) if subsequent + LLM calls would fail due to token or call limits. + + Args: + pricing_service: PricingService instance + user_id: User ID for subscription checking + gpt_provider: GPT provider from env var (defaults to "google") + + Returns: + None + If validation fails, raises HTTPException with 429 status + """ + try: + # Determine actual provider for LLM calls based on GPT_PROVIDER env var + gpt_provider_lower = gpt_provider.lower() + if gpt_provider_lower == "huggingface": + llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace + llm_provider_name = "huggingface" + else: + llm_provider_enum = APIProvider.GEMINI + llm_provider_name = "gemini" + + # Estimate tokens for each operation in Exa research workflow + # Exa Search call: 1 Exa API call (not token-based) + # Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON) + # Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON) + # Content angle generator: ~1000 tokens (input: research results, output: list of angles) + # Note: These are conservative estimates for pre-flight validation + + operations_to_validate = [ + { + 'provider': APIProvider.EXA, # Exa API call + 'tokens_requested': 0, + 'actual_provider_name': 'exa', + 'operation_type': 'exa_neural_search' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'keyword_analysis' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'competitor_analysis' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'content_angle_generation' + } + ] + + logger.info(f"[Pre-flight Validator] š Starting Exa Research Workflow Validation") + logger.info(f" āā User: {user_id}") + logger.info(f" āā LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})") + logger.info(f" āā Operations to validate: {len(operations_to_validate)}") + + can_proceed, message, error_details = pricing_service.check_comprehensive_limits( + user_id=user_id, + operations=operations_to_validate + ) + + if not can_proceed: + usage_info = error_details.get('usage_info', {}) if error_details else {} + provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name + operation_type = usage_info.get('operation_type', 'unknown') + + logger.error(f"[Pre-flight Validator] ā EXA RESEARCH WORKFLOW BLOCKED") + logger.error(f" āā User: {user_id}") + logger.error(f" āā Blocked at: {operation_type}") + logger.error(f" āā Provider: {provider}") + logger.error(f" āā Reason: {message}") + + # Raise HTTPException immediately - frontend gets immediate response, no API calls made + raise HTTPException( + status_code=429, + detail={ + 'error': message, + 'message': message, + 'provider': provider, + 'usage_info': usage_info if usage_info else error_details + } + ) + + logger.info(f"[Pre-flight Validator] ā EXA RESEARCH WORKFLOW APPROVED") + logger.info(f" āā User: {user_id}") + logger.info(f" āā All {len(operations_to_validate)} operations validated - proceeding with API calls") + # Validation passed - no return needed (function raises HTTPException if validation fails) + + except HTTPException: + raise + except Exception as e: + logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail={ + 'error': f"Failed to validate operations: {str(e)}", + 'message': f"Failed to validate operations: {str(e)}" + } + ) + + def validate_image_generation_operations( pricing_service: PricingService, user_id: str diff --git a/backend/services/subscription/pricing_service.py b/backend/services/subscription/pricing_service.py index 629759d5..a72912c5 100644 --- a/backend/services/subscription/pricing_service.py +++ b/backend/services/subscription/pricing_service.py @@ -258,6 +258,12 @@ class PricingService: "model_name": "stable-diffusion", "cost_per_image": 0.04, # $0.04 per image "description": "Stability AI Image Generation" + }, + { + "provider": APIProvider.EXA, + "model_name": "exa-search", + "cost_per_request": 0.005, # $0.005 per search (1-25 results) + "description": "Exa Neural Search API" } ] @@ -296,6 +302,7 @@ class PricingService: "metaphor_calls_limit": 10, "firecrawl_calls_limit": 10, "stability_calls_limit": 5, + "exa_calls_limit": 100, "gemini_tokens_limit": 100000, "monthly_cost_limit": 0.0, "features": ["basic_content_generation", "limited_research"], @@ -316,10 +323,11 @@ class PricingService: "metaphor_calls_limit": 100, "firecrawl_calls_limit": 100, "stability_calls_limit": 5, - "gemini_tokens_limit": 2000, - "openai_tokens_limit": 2000, - "anthropic_tokens_limit": 2000, - "mistral_tokens_limit": 2000, + "exa_calls_limit": 500, + "gemini_tokens_limit": 20000, # Increased from 5000 for better stability + "openai_tokens_limit": 20000, # Increased from 5000 for better stability + "anthropic_tokens_limit": 20000, # Increased from 5000 for better stability + "mistral_tokens_limit": 20000, # Increased from 5000 for better stability "monthly_cost_limit": 50.0, "features": ["full_content_generation", "advanced_research", "basic_analytics"], "description": "Great for individuals and small teams" @@ -338,6 +346,7 @@ class PricingService: "metaphor_calls_limit": 500, "firecrawl_calls_limit": 500, "stability_calls_limit": 200, + "exa_calls_limit": 2000, "gemini_tokens_limit": 5000000, "openai_tokens_limit": 2500000, "anthropic_tokens_limit": 1000000, @@ -360,6 +369,7 @@ class PricingService: "metaphor_calls_limit": 0, "firecrawl_calls_limit": 0, "stability_calls_limit": 0, + "exa_calls_limit": 0, # Unlimited "gemini_tokens_limit": 0, "openai_tokens_limit": 0, "anthropic_tokens_limit": 0, @@ -423,11 +433,14 @@ class PricingService: def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]: """Get usage limits for a user based on their subscription.""" + # CRITICAL: Expire all objects first to ensure fresh data after renewal + self.db.expire_all() + subscription = self.db.query(UserSubscription).filter( UserSubscription.user_id == user_id, UserSubscription.is_active == True ).first() - + if not subscription: # Return free tier limits free_plan = self.db.query(SubscriptionPlan).filter( @@ -439,7 +452,23 @@ class PricingService: # Ensure current period before returning limits self._ensure_subscription_current(subscription) - return self._plan_to_limits_dict(subscription.plan) + + # CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship + self.db.refresh(subscription) + + # Re-query plan directly to ensure fresh data (bypass relationship cache) + plan = self.db.query(SubscriptionPlan).filter( + SubscriptionPlan.id == subscription.plan_id + ).first() + + if not plan: + logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}") + return None + + # Refresh plan to ensure fresh limits + self.db.refresh(plan) + + return self._plan_to_limits_dict(plan) def _ensure_ai_text_gen_column_detection(self) -> None: """Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result.""" @@ -508,290 +537,20 @@ class PricingService: tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]: """Check if user can make an API call within their limits. + Delegates to LimitValidator for actual validation logic. + Args: user_id: User ID provider: APIProvider enum (may be MISTRAL for HuggingFace) tokens_requested: Estimated tokens for the request actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL) - """ - try: - # Use actual_provider_name if provided, otherwise use enum value - # This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors - display_provider_name = actual_provider_name or provider.value - - logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}") - - # Short TTL cache to reduce DB reads under sustained traffic - cache_key = f"{user_id}:{provider.value}" - now = datetime.utcnow() - cached = self._limits_cache.get(cache_key) - if cached and cached.get('expires_at') and cached['expires_at'] > now: - logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}") - return tuple(cached['result']) # type: ignore - - # Get user subscription first to check expiration - subscription = self.db.query(UserSubscription).filter( - UserSubscription.user_id == user_id, - UserSubscription.is_active == True - ).first() - - if subscription: - logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}") - else: - logger.debug(f"[Subscription Check] No active subscription found for user {user_id}") - - # Check subscription expiration (STRICT: deny if expired) - if subscription: - if subscription.current_period_end < now: - logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}") - # Subscription expired - check if auto_renew is enabled - if not getattr(subscription, 'auto_renew', False): - # Expired and no auto-renew - deny access - logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access") - result = (False, "Subscription expired. Please renew your subscription to continue using the service.", { - 'expired': True, - 'period_end': subscription.current_period_end.isoformat() - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - else: - # Try to auto-renew - if not self._ensure_subscription_current(subscription): - # Auto-renew failed - deny access - result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", { - 'expired': True, - 'auto_renew_failed': True - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - - # Get user limits with error handling (STRICT: fail on errors) - try: - limits = self.get_user_limits(user_id) - if limits: - logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}") - else: - logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier") - except Exception as e: - logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True) - # STRICT: Fail closed - deny request if we can't check limits - return False, f"Failed to retrieve subscription limits: {str(e)}", {} - - if not limits: - # No subscription found - check for free tier - free_plan = self.db.query(SubscriptionPlan).filter( - SubscriptionPlan.tier == SubscriptionTier.FREE, - SubscriptionPlan.is_active == True - ).first() - if free_plan: - logger.info(f"[Subscription Check] Assigning free tier to user {user_id}") - limits = self._plan_to_limits_dict(free_plan) - else: - # No subscription and no free tier - deny access - logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access") - return False, "No subscription plan found. Please subscribe to a plan.", {} - - # Get current usage for this billing period with error handling - try: - current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") - usage = self.db.query(UsageSummary).filter( - UsageSummary.user_id == user_id, - UsageSummary.billing_period == current_period - ).first() - - if not usage: - # First usage this period, create summary - try: - usage = UsageSummary( - user_id=user_id, - billing_period=current_period - ) - self.db.add(usage) - self.db.commit() - except Exception as create_error: - logger.error(f"Error creating usage summary: {create_error}") - self.db.rollback() - # STRICT: Fail closed on DB error - return False, f"Failed to create usage summary: {str(create_error)}", {} - except Exception as e: - logger.error(f"Error getting usage summary for {user_id}: {e}") - self.db.rollback() - # STRICT: Fail closed on DB error - return False, f"Failed to retrieve usage summary: {str(e)}", {} - - # Check call limits with error handling - # NOTE: call_limit = 0 means UNLIMITED (Enterprise plans) - try: - # Use display_provider_name for error messages, but provider.value for DB queries - provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens") - - # For LLM text generation providers, check against unified total_calls limit - llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] - is_llm_provider = provider_name in llm_providers - - if is_llm_provider: - # Use unified AI text generation limit (total_calls across all LLM providers) - ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0 - - # If unified limit not set, fall back to provider-specific limit for backwards compatibility - if ai_text_gen_limit == 0: - ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 - - # Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral) - current_total_llm_calls = ( - (usage.gemini_calls or 0) + - (usage.openai_calls or 0) + - (usage.anthropic_calls or 0) + - (usage.mistral_calls or 0) - ) - - # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) - if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit: - logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})") - result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", { - 'current_calls': current_total_llm_calls, - 'limit': ai_text_gen_limit, - 'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0, - 'provider': display_provider_name, # Use display name for consistency - 'usage_info': { - 'provider': display_provider_name, # Use display name for user-facing info - 'current_calls': current_total_llm_calls, - 'limit': ai_text_gen_limit, - 'type': 'ai_text_generation', - 'breakdown': { - 'gemini': usage.gemini_calls or 0, - 'openai': usage.openai_calls or 0, - 'anthropic': usage.anthropic_calls or 0, - 'mistral': usage.mistral_calls or 0 # DB field name (not display name) - } - } - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - else: - logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})") - else: - # For non-LLM providers, check provider-specific limit - current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0 - call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 - - # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) - if call_limit > 0 and current_calls >= call_limit: - logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}") - result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", { - 'current_calls': current_calls, - 'limit': call_limit, - 'usage_percentage': 100.0, - 'provider': display_provider_name # Use display name for consistency - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - else: - logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}") - except Exception as e: - logger.error(f"Error checking call limits: {e}") - # Continue to next check - - # Check token limits for LLM providers with error handling - # NOTE: token_limit = 0 means UNLIMITED (Enterprise plans) - try: - if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: - current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0 - token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0 - - # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) - if token_limit > 0 and (current_tokens + tokens_requested) > token_limit: - result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", { - 'current_tokens': current_tokens, - 'requested_tokens': tokens_requested, - 'limit': token_limit, - 'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100, - 'provider': display_provider_name, # Use display name in error details - 'usage_info': { - 'provider': display_provider_name, - 'current_tokens': current_tokens, - 'requested_tokens': tokens_requested, - 'limit': token_limit, - 'type': 'tokens' - } - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - except Exception as e: - logger.error(f"Error checking token limits: {e}") - # Continue to next check - - # Check cost limits with error handling - # NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans) - try: - cost_limit = limits['limits'].get('monthly_cost', 0) or 0 - # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) - if cost_limit > 0 and usage.total_cost >= cost_limit: - result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", { - 'current_cost': usage.total_cost, - 'limit': cost_limit, - 'usage_percentage': 100.0 - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - except Exception as e: - logger.error(f"Error checking cost limits: {e}") - # Continue to success case - - # Calculate usage percentages for warnings - try: - # Determine which call variables to use based on provider type - if is_llm_provider: - # Use unified LLM call tracking - current_call_count = current_total_llm_calls - call_limit_value = ai_text_gen_limit - else: - # Use provider-specific call tracking - current_call_count = current_calls - call_limit_value = call_limit - - call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0 - cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0 - result = (True, "Within limits", { - 'current_calls': current_call_count, - 'call_limit': call_limit_value, - 'call_usage_percentage': call_usage_pct, - 'current_cost': usage.total_cost, - 'cost_limit': cost_limit, - 'cost_usage_percentage': cost_usage_pct - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - except Exception as e: - logger.error(f"Error calculating usage percentages: {e}") - # Return basic success - return True, "Within limits", {} - except Exception as e: - logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}") - # STRICT: Fail closed - deny requests if subscription system fails - return False, f"Subscription check error: {str(e)}", {} + Returns: + (can_proceed, error_message, usage_info) + """ + from .limit_validation import LimitValidator + validator = LimitValidator(self) + return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name) def estimate_tokens(self, text: str, provider: APIProvider) -> int: """Estimate token count for text based on provider.""" @@ -827,6 +586,16 @@ class PricingService: if not pricing: return None + # Return pricing info as dict + return { + 'provider': pricing.provider.value, + 'model_name': pricing.model_name, + 'cost_per_input_token': pricing.cost_per_input_token, + 'cost_per_output_token': pricing.cost_per_output_token, + 'cost_per_request': pricing.cost_per_request, + 'description': pricing.description + } + def check_comprehensive_limits( self, user_id: str, @@ -835,6 +604,7 @@ class PricingService: """ Comprehensive pre-flight validation that checks ALL limits before making ANY API calls. + Delegates to LimitValidator for actual validation logic. This prevents wasteful API calls by validating that ALL subsequent operations will succeed before making the first external API call. @@ -850,202 +620,9 @@ class PricingService: (can_proceed, error_message, error_details) If can_proceed is False, error_message explains which limit would be exceeded """ - try: - logger.info(f"[Pre-flight Check] š Starting comprehensive validation for user {user_id}") - logger.info(f"[Pre-flight Check] š Validating {len(operations)} operation(s) before making any API calls") - - # Get current usage and limits once - current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") - usage = self.db.query(UsageSummary).filter( - UsageSummary.user_id == user_id, - UsageSummary.billing_period == current_period - ).first() - - if not usage: - # First usage this period, create summary - try: - usage = UsageSummary( - user_id=user_id, - billing_period=current_period - ) - self.db.add(usage) - self.db.commit() - except Exception as create_error: - logger.error(f"Error creating usage summary: {create_error}") - self.db.rollback() - return False, f"Failed to create usage summary: {str(create_error)}", {} - - # Get user limits - limits_dict = self.get_user_limits(user_id) - if not limits_dict: - # No subscription found - check for free tier - free_plan = self.db.query(SubscriptionPlan).filter( - SubscriptionPlan.tier == SubscriptionTier.FREE, - SubscriptionPlan.is_active == True - ).first() - if free_plan: - limits_dict = self._plan_to_limits_dict(free_plan) - else: - return False, "No subscription plan found. Please subscribe to a plan.", {} - - limits = limits_dict.get('limits', {}) - - # Track cumulative usage across all operations - total_llm_calls = ( - (usage.gemini_calls or 0) + - (usage.openai_calls or 0) + - (usage.anthropic_calls or 0) + - (usage.mistral_calls or 0) - ) - total_llm_tokens = {} - total_images = usage.stability_calls or 0 - - # Log current usage summary - logger.info(f"[Pre-flight Check] š Current Usage Summary:") - logger.info(f" āā Total LLM Calls: {total_llm_calls}") - logger.info(f" āā Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}") - logger.info(f" āā Image Calls: {total_images}") - - # Validate each operation - for op_idx, operation in enumerate(operations): - provider = operation.get('provider') - provider_name = provider.value if hasattr(provider, 'value') else str(provider) - tokens_requested = operation.get('tokens_requested', 0) - actual_provider_name = operation.get('actual_provider_name') - operation_type = operation.get('operation_type', 'unknown') - - display_provider_name = actual_provider_name or provider_name - - logger.info(f"[Pre-flight Check] ā Operation {op_idx + 1}/{len(operations)}: {operation_type}") - logger.info(f" āā Provider: {display_provider_name} (enum: {provider_name})") - logger.info(f" āā Estimated Tokens: {tokens_requested}") - - # Check if this is an LLM provider - llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] - is_llm_provider = provider_name in llm_providers - - # Check unified AI text generation limit for LLM providers - if is_llm_provider: - ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0 - if ai_text_gen_limit == 0: - # Fallback to provider-specific limit - ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0 - - # Count this operation as an LLM call - projected_total_llm_calls = total_llm_calls + 1 - - if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit: - error_info = { - 'current_calls': total_llm_calls, - 'limit': ai_text_gen_limit, - 'provider': display_provider_name, - 'operation_type': operation_type, - 'operation_index': op_idx - } - return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", { - 'error_type': 'call_limit', - 'usage_info': error_info - } - - # Check token limits for this provider - # Use cumulative projected tokens from previous operations, or current from DB if first operation - provider_tokens_key = f"{provider_name}_tokens" - if provider_tokens_key in total_llm_tokens: - # Use cumulative projected tokens from previous operations - current_provider_tokens = total_llm_tokens[provider_tokens_key] - logger.info(f" āā Using cumulative projected tokens: {current_provider_tokens}") - else: - # First operation for this provider - get current from database - current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0 - total_llm_tokens[provider_tokens_key] = current_provider_tokens - logger.info(f" āā Current tokens from DB: {current_provider_tokens}") - - token_limit = limits.get(provider_tokens_key, 0) or 0 - - if token_limit > 0 and tokens_requested > 0: - projected_tokens = current_provider_tokens + tokens_requested - logger.info(f" āā Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)") - - if projected_tokens > token_limit: - usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0 - error_info = { - 'current_tokens': current_provider_tokens, - 'requested_tokens': tokens_requested, - 'limit': token_limit, - 'provider': display_provider_name, - 'operation_type': operation_type, - 'operation_index': op_idx - } - error_msg = ( - f"Token limit exceeded for {display_provider_name} " - f"({operation_type}). " - f"Current: {current_provider_tokens}/{token_limit}, " - f"Requested: {tokens_requested}, " - f"Would exceed by: {projected_tokens - token_limit} tokens " - f"({usage_percentage:.1f}% of limit)" - ) - logger.error(f"[Pre-flight Check] ā BLOCKED: {error_msg}") - return False, error_msg, { - 'error_type': 'token_limit', - 'usage_info': error_info - } - else: - logger.info(f" āā ā Token limit check passed: {projected_tokens} <= {token_limit}") - - # Update cumulative counts for next operation - total_llm_calls = projected_total_llm_calls - total_llm_tokens[provider_tokens_key] += tokens_requested - logger.info(f" āā Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}") - - # Check image generation limits - elif provider == APIProvider.STABILITY: - image_limit = limits.get('stability_calls', 0) or 0 - projected_images = total_images + 1 - - if image_limit > 0 and projected_images > image_limit: - error_info = { - 'current_images': total_images, - 'limit': image_limit, - 'provider': 'stability', - 'operation_type': operation_type, - 'operation_index': op_idx - } - return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", { - 'error_type': 'image_limit', - 'usage_info': error_info - } - - total_images = projected_images - - # Check other provider-specific limits - else: - provider_calls_key = f"{provider_name}_calls" - current_provider_calls = getattr(usage, provider_calls_key, 0) or 0 - call_limit = limits.get(provider_calls_key, 0) or 0 - - if call_limit > 0: - projected_calls = current_provider_calls + 1 - if projected_calls > call_limit: - error_info = { - 'current_calls': current_provider_calls, - 'limit': call_limit, - 'provider': display_provider_name, - 'operation_type': operation_type, - 'operation_index': op_idx - } - return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", { - 'error_type': 'call_limit', - 'usage_info': error_info - } - - # All checks passed - logger.info(f"[Pre-flight Check] ā All {len(operations)} operation(s) validated successfully") - logger.info(f"[Pre-flight Check] ā User {user_id} is cleared to proceed with API calls") - return True, None, None - - except Exception as e: - logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True) - return False, f"Failed to validate limits: {str(e)}", {} + from .limit_validation import LimitValidator + validator = LimitValidator(self) + return validator.check_comprehensive_limits(user_id, operations) def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]: """Get pricing configuration for a specific provider and model.""" diff --git a/backend/services/subscription/schema_utils.py b/backend/services/subscription/schema_utils.py new file mode 100644 index 00000000..ac5bdcbc --- /dev/null +++ b/backend/services/subscription/schema_utils.py @@ -0,0 +1,39 @@ +from typing import Set +from sqlalchemy.orm import Session + + +_checked_subscription_plan_columns: bool = False + + +def ensure_subscription_plan_columns(db: Session) -> None: + """Ensure required columns exist on subscription_plans for runtime safety. + + This is a defensive guard for environments where migrations have not yet + been applied. If columns are missing (e.g., exa_calls_limit), we add them + with a safe default so ORM queries do not fail. + """ + global _checked_subscription_plan_columns + if _checked_subscription_plan_columns: + return + + try: + # Discover existing columns + result = db.execute("PRAGMA table_info(subscription_plans)") + cols: Set[str] = {row[1] for row in result} + + # Columns we may reference in models but might be missing in older DBs + required_columns = { + "exa_calls_limit": "INTEGER DEFAULT 0", + } + + for col_name, ddl in required_columns.items(): + if col_name not in cols: + db.execute(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}") + db.commit() + except Exception: + # Do not block app if pragma/alter fails; let normal errors surface + db.rollback() + finally: + _checked_subscription_plan_columns = True + + diff --git a/backend/services/wix_service.py b/backend/services/wix_service.py index 237d1238..29761455 100644 --- a/backend/services/wix_service.py +++ b/backend/services/wix_service.py @@ -7,6 +7,7 @@ Handles authentication, permission checking, and blog publishing to Wix websites import os import json import requests +import uuid from typing import Dict, Any, Optional, List from loguru import logger from datetime import datetime, timedelta @@ -19,6 +20,9 @@ from services.integrations.wix.media import WixMediaService from services.integrations.wix.utils import extract_meta_from_token, normalize_token_string, extract_member_id_from_access_token as utils_extract_member from services.integrations.wix.content import convert_content_to_ricos as ricos_builder from services.integrations.wix.auth import WixAuthService +from services.integrations.wix.seo import build_seo_data +from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api +from services.integrations.wix.blog_publisher import create_blog_post as publish_blog_post class WixService: """Service for interacting with Wix APIs""" @@ -237,13 +241,35 @@ class WixService: logger.error(f"Failed to import image to Wix: {e}") raise - def convert_content_to_ricos(self, content: str, images: List[str] = None) -> Dict[str, Any]: + def convert_content_to_ricos(self, content: str, images: List[str] = None, + use_wix_api: bool = False, access_token: str = None) -> Dict[str, Any]: + """ + Convert markdown content to Ricos JSON format. + + Args: + content: Markdown content to convert + images: Optional list of image URLs + use_wix_api: If True, use Wix's official Ricos Documents API (requires access_token) + access_token: Wix access token (required if use_wix_api=True) + + Returns: + Ricos JSON document + """ + if use_wix_api and access_token: + try: + return convert_via_wix_api(content, access_token, self.base_url) + except Exception as e: + logger.warning(f"Failed to convert via Wix API, falling back to custom parser: {e}") + # Fall back to custom parser + + # Use custom parser (current implementation) return ricos_builder(content, images) + def create_blog_post(self, access_token: str, title: str, content: str, cover_image_url: str = None, category_ids: List[str] = None, tag_ids: List[str] = None, publish: bool = True, - member_id: str = None) -> Dict[str, Any]: + member_id: str = None, seo_metadata: Dict[str, Any] = None) -> Dict[str, Any]: """ Create and optionally publish a blog post on Wix @@ -256,101 +282,33 @@ class WixService: tag_ids: Optional list of tag IDs publish: Whether to publish immediately or save as draft member_id: Required for third-party apps - the member ID of the post author + seo_metadata: Optional SEO metadata dict with fields like: + - seo_title: SEO optimized title + - meta_description: Meta description + - focus_keyword: Main keyword + - blog_tags: List of tag strings (for keywords) + - open_graph: Open Graph data + - canonical_url: Canonical URL Returns: Created blog post information """ - if not member_id: - raise ValueError("memberId is required for third-party apps creating blog posts") - - headers = { - 'Authorization': f'Bearer {access_token}', - 'Content-Type': 'application/json' - } - - # Build valid Ricos rich content (minimum: one paragraph with text) - ricos_content = self.convert_content_to_ricos(content or "This is a post from ALwrity.", None) - - # Minimal payload per Wix docs: title, memberId, and richContent - blog_data = { - 'draftPost': { - 'title': title, - 'memberId': member_id, # Required for third-party apps - 'richContent': ricos_content, - 'excerpt': (content or '').strip()[:200] - }, - 'publish': publish, - 'fieldsets': ['URL'] # Simplified fieldsets - } - - # Add cover image if provided - if cover_image_url: - try: - media_id = self.import_image_to_wix(access_token, cover_image_url, f'Cover: {title}') - blog_data['draftPost']['media'] = { - 'wixMedia': { - 'image': {'id': media_id} - }, - 'displayed': True, - 'custom': True - } - except Exception as e: - logger.warning(f"Failed to import cover image: {e}") - - # Add categories if provided - if category_ids: - blog_data['draftPost']['categoryIds'] = category_ids - - # Add tags if provided - if tag_ids: - blog_data['draftPost']['tagIds'] = tag_ids - - try: - # Check what permissions we have in the token - logger.info("DEBUG: Checking token permissions...") - try: - import jwt - # Extract token string manually since _normalize_access_token doesn't exist - token_str = str(access_token) - if token_str and token_str.startswith('OauthNG.JWS.'): - jwt_part = token_str[12:] - payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False}) - logger.info(f"DEBUG: Full token payload: {payload}") - - # Check for permissions in various possible locations - data_payload = payload.get('data', {}) - if isinstance(data_payload, str): - try: - data_payload = json.loads(data_payload) - except: - pass - - instance_data = data_payload.get('instance', {}) - permissions = instance_data.get('permissions', '') - scopes = instance_data.get('scopes', []) - meta_site_id = instance_data.get('metaSiteId') - if isinstance(meta_site_id, str) and meta_site_id: - headers['wix-site-id'] = meta_site_id - logger.info(f"DEBUG: Added wix-site-id header: {meta_site_id}") - logger.info(f"DEBUG: Token permissions: {permissions}") - logger.info(f"DEBUG: Token scopes: {scopes}") - else: - logger.info("DEBUG: Could not decode token for permission check") - except Exception as perm_e: - logger.warning(f"DEBUG: Failed to check permissions: {perm_e}") - - logger.info(f"DEBUG: Sending simplified blog data: {json.dumps(blog_data, indent=2)}") - extra_headers = {} - if 'wix-site-id' in headers: - extra_headers['wix-site-id'] = headers['wix-site-id'] - result = self.blog_service.create_draft_post(access_token, blog_data, extra_headers or None) - logger.info(f"DEBUG: Create draft result: {result}") - return result - except requests.RequestException as e: - logger.error(f"Failed to create blog post: {e}") - if hasattr(e, 'response') and e.response is not None: - logger.error(f"Response body: {e.response.text}") - raise + return publish_blog_post( + blog_service=self.blog_service, + access_token=access_token, + title=title, + content=content, + member_id=member_id, + cover_image_url=cover_image_url, + category_ids=category_ids, + tag_ids=tag_ids, + publish=publish, + seo_metadata=seo_metadata, + import_image_func=self.import_image_to_wix, + lookup_categories_func=self.lookup_or_create_categories, + lookup_tags_func=self.lookup_or_create_tags, + base_url=self.base_url + ) def get_blog_categories(self, access_token: str) -> List[Dict[str, Any]]: """ @@ -383,6 +341,138 @@ class WixService: except requests.RequestException as e: logger.error(f"Failed to get blog tags: {e}") raise + + def lookup_or_create_categories(self, access_token: str, category_names: List[str], + extra_headers: Optional[Dict[str, str]] = None) -> List[str]: + """ + Lookup existing categories by name or create new ones, return their IDs. + + Args: + access_token: Valid access token + category_names: List of category name strings + extra_headers: Optional extra headers (e.g., wix-site-id) + + Returns: + List of category UUIDs + """ + if not category_names: + return [] + + try: + # Get existing categories + existing_categories = self.blog_service.list_categories(access_token, extra_headers) + + # Create name -> ID mapping (case-insensitive) + category_map = {} + for cat in existing_categories: + cat_label = cat.get('label', '').strip() + cat_id = cat.get('id') + if cat_label and cat_id: + category_map[cat_label.lower()] = cat_id + + category_ids = [] + for category_name in category_names: + category_name_clean = str(category_name).strip() + if not category_name_clean: + continue + + # Lookup existing category (case-insensitive) + category_id = category_map.get(category_name_clean.lower()) + + if not category_id: + # Create new category + try: + logger.info(f"Creating new category: {category_name_clean}") + result = self.blog_service.create_category( + access_token, + label=category_name_clean, + extra_headers=extra_headers + ) + new_category = result.get('category', {}) + category_id = new_category.get('id') + if category_id: + category_ids.append(category_id) + # Update map to avoid duplicate creates + category_map[category_name_clean.lower()] = category_id + logger.info(f"Created category '{category_name_clean}' with ID: {category_id}") + except Exception as create_error: + logger.warning(f"Failed to create category '{category_name_clean}': {create_error}") + # Continue with other categories + else: + category_ids.append(category_id) + logger.info(f"Found existing category '{category_name_clean}' with ID: {category_id}") + + return category_ids + + except requests.RequestException as e: + logger.error(f"Failed to lookup/create categories: {e}") + return [] + + def lookup_or_create_tags(self, access_token: str, tag_names: List[str], + extra_headers: Optional[Dict[str, str]] = None) -> List[str]: + """ + Lookup existing tags by name or create new ones, return their IDs. + + Args: + access_token: Valid access token + tag_names: List of tag name strings + extra_headers: Optional extra headers (e.g., wix-site-id) + + Returns: + List of tag UUIDs + """ + if not tag_names: + return [] + + try: + # Get existing tags + existing_tags = self.blog_service.list_tags(access_token, extra_headers) + + # Create name -> ID mapping (case-insensitive) + tag_map = {} + for tag in existing_tags: + tag_label = tag.get('label', '').strip() + tag_id = tag.get('id') + if tag_label and tag_id: + tag_map[tag_label.lower()] = tag_id + + tag_ids = [] + for tag_name in tag_names: + tag_name_clean = str(tag_name).strip() + if not tag_name_clean: + continue + + # Lookup existing tag (case-insensitive) + tag_id = tag_map.get(tag_name_clean.lower()) + + if not tag_id: + # Create new tag + try: + logger.info(f"Creating new tag: {tag_name_clean}") + result = self.blog_service.create_tag( + access_token, + label=tag_name_clean, + extra_headers=extra_headers + ) + new_tag = result.get('tag', {}) + tag_id = new_tag.get('id') + if tag_id: + tag_ids.append(tag_id) + # Update map to avoid duplicate creates + tag_map[tag_name_clean.lower()] = tag_id + logger.info(f"Created tag '{tag_name_clean}' with ID: {tag_id}") + except Exception as create_error: + logger.warning(f"Failed to create tag '{tag_name_clean}': {create_error}") + # Continue with other tags + else: + tag_ids.append(tag_id) + logger.info(f"Found existing tag '{tag_name_clean}' with ID: {tag_id}") + + return tag_ids + + except requests.RequestException as e: + logger.error(f"Failed to lookup/create tags: {e}") + return [] def publish_draft_post(self, access_token: str, draft_post_id: str) -> Dict[str, Any]: """ diff --git a/docs/RESEARCH_COMPONENT_INTEGRATION.md b/docs/RESEARCH_COMPONENT_INTEGRATION.md new file mode 100644 index 00000000..8467fd2d --- /dev/null +++ b/docs/RESEARCH_COMPONENT_INTEGRATION.md @@ -0,0 +1,335 @@ +# Research Component Integration Guide + +## Overview + +The modular Research component has been implemented as a standalone, testable wizard that can be integrated into the blog writer or used independently. This document outlines the architecture, usage, and integration steps. + +## Architecture + +### Backend Strategy Pattern + +The research service now supports multiple research modes through a strategy pattern: + +```python +# Research modes +- Basic: Quick keyword-focused analysis +- Comprehensive: Full analysis with all components +- Targeted: Customizable components based on config + +# Strategy implementation +backend/services/blog_writer/research/research_strategies.py +- ResearchStrategy (base class) +- BasicResearchStrategy +- ComprehensiveResearchStrategy +- TargetedResearchStrategy +``` + +### Frontend Component Structure + +``` +frontend/src/components/Research/ +āāā index.tsx # Main exports +āāā ResearchWizard.tsx # Main wizard container +āāā steps/ +ā āāā StepKeyword.tsx # Step 1: Keyword input +ā āāā StepOptions.tsx # Step 2: Mode selection +ā āāā StepProgress.tsx # Step 3: Progress display +ā āāā StepResults.tsx # Step 4: Results display +āāā hooks/ +ā āāā useResearchWizard.ts # Wizard state management +ā āāā useResearchExecution.ts # API calls and polling +āāā types/ +ā āāā research.types.ts # TypeScript interfaces +āāā utils/ + āāā researchUtils.ts # Utility functions +``` + +## Test Page + +A dedicated test page is available at `/research-test` for testing the research wizard independently. + +**Features:** +- Quick preset keywords for testing +- Debug panel with JSON export +- Performance metrics display +- Cache state visualization + +## Usage + +### Standalone Usage + +```typescript +import { ResearchWizard } from '../components/Research'; + +Use the copilot to begin researching your blog topic.
-Use the copilot to begin researching your blog topic.
+Review and confirm your outline before generating content.
-Review and confirm your outline before generating content.
++ Generate full content for all sections in your confirmed outline. +
+ + : null} + sx={{ + minWidth: 200, + py: 1.5, + px: 4, + }} + > + {isGenerating ? 'Generating Content...' : 'š Generate Content'} + + + {error && ( ++ {error} +
+ )} ++ Generate an AI-powered outline based on your research. +
+ + : null} + sx={{ + minWidth: 200, + py: 1.5, + px: 4, + }} + > + {isGenerating ? 'Generating Outline...' : 'š§© Generate Outline'} + + + {error && ( ++ {error} +
+ )} ++ What keywords and information would you like to use for your research? Please also specify the desired length of the blog post. +
+ ++ Step {wizard.state.currentStep} of {wizard.maxSteps} +
++ Enter your keywords, industry, and target audience to start research. +
+ + {/* Keywords Input */} ++ Separate multiple keywords with commas +
++ Select the type of research that best fits your needs. +
+ ++ {card.description} +
++ Gathering insights from Google Search grounding +
+ + {/* Status Display */} +{error}
++ {isExecuting ? 'Analyzing sources and generating insights...' : 'Finalizing results...'} +
+ > + )} +No results available
++ Test the modular research wizard component +
+
+ {JSON.stringify(results, null, 2)}
+
+