From de4328175dc91d6f92e31be75a1e44b15dc89f19 Mon Sep 17 00:00:00 2001 From: ajaysi Date: Sat, 1 Nov 2025 18:01:14 +0530 Subject: [PATCH] Subscription dashboard improvements, AI text generation limit, and other fixes. --- backend/api/blog_writer/router.py | 115 ++- backend/api/blog_writer/task_manager.py | 28 +- .../api/content_planning/monitoring_routes.py | 155 +++- backend/api/images.py | 109 ++- backend/api/subscription_api.py | 96 ++- backend/api/wix_routes.py | 33 + backend/app.py | 9 + .../add_user_id_to_task_execution_logs.sql | 20 + backend/models/linkedin_models.py | 12 +- backend/models/monitoring_models.py | 3 +- backend/models/subscription_models.py | 18 +- backend/requirements.txt | 3 + .../add_ai_text_generation_limit_column.py | 146 ++++ backend/scripts/cap_basic_plan_usage.py | 210 ++++++ backend/scripts/reset_basic_plan_usage.py | 168 +++++ backend/scripts/update_basic_plan_limits.py | 279 ++++++++ backend/services/active_strategy_service.py | 52 ++ .../blog_writer/core/blog_writer_service.py | 23 +- backend/services/blog_writer/logger_config.py | 9 +- .../research/competitor_analyzer.py | 5 +- .../research/content_angle_generator.py | 5 +- .../blog_writer/research/keyword_analyzer.py | 5 +- .../blog_writer/research/research_service.py | 59 +- .../seo/blog_content_seo_analyzer.py | 15 +- .../seo/blog_seo_metadata_generator.py | 26 +- .../seo/blog_seo_recommendation_applier.py | 6 +- .../llm_providers/gemini_grounded_provider.py | 43 +- .../llm_providers/main_image_generation.py | 33 +- .../llm_providers/main_text_generation.py | 297 +++++++- backend/services/monitoring_data_service.py | 85 +++ backend/services/scheduler/__init__.py | 59 ++ backend/services/scheduler/core/__init__.py | 4 + .../scheduler/core/exception_handler.py | 395 +++++++++++ .../scheduler/core/executor_interface.py | 75 ++ backend/services/scheduler/core/scheduler.py | 628 ++++++++++++++++ .../services/scheduler/core/task_registry.py | 59 ++ .../services/scheduler/executors/__init__.py | 4 + .../executors/monitoring_task_executor.py | 266 +++++++ backend/services/scheduler/utils/__init__.py | 4 + .../scheduler/utils/frequency_calculator.py | 33 + .../services/scheduler/utils/task_loader.py | 60 ++ .../subscription/preflight_validator.py | 189 +++++ .../services/subscription/pricing_service.py | 668 +++++++++++++++--- .../subscription/usage_tracking_service.py | 47 +- frontend/src/App.tsx | 105 ++- frontend/src/api/client.ts | 84 ++- .../src/components/BlogWriter/BlogWriter.tsx | 4 + .../BlogWriterUtils/PhaseContent.tsx | 10 +- .../BlogWriterUtils/WixConnectModal.tsx | 168 +++++ .../BlogWriter/EnhancedOutlineEditor.tsx | 11 +- .../src/components/BlogWriter/Publisher.tsx | 276 +++++++- .../BlogWriter/SuggestionsGenerator.tsx | 36 +- .../BlogWriter/WYSIWYG/BlogEditor.tsx | 34 +- .../BlogWriter/WYSIWYG/BlogSection.tsx | 29 +- .../components/SystemStatusIndicator.tsx | 49 +- .../components/ImageGen/useImageGeneration.ts | 16 +- .../src/components/Pricing/PricingPage.tsx | 154 +++- .../components/SubscriptionExpiredModal.tsx | 186 ++++- .../src/components/shared/DashboardHeader.tsx | 18 +- frontend/src/contexts/SubscriptionContext.tsx | 207 ++++-- frontend/src/hooks/useBlogWriterState.ts | 5 + frontend/src/hooks/usePolling.ts | 70 ++ frontend/src/services/blogWriterApi.ts | 35 +- frontend/src/utils/wixTokenUtils.ts | 198 ++++++ 64 files changed, 5809 insertions(+), 444 deletions(-) create mode 100644 backend/database/migrations/add_user_id_to_task_execution_logs.sql create mode 100644 backend/scripts/add_ai_text_generation_limit_column.py create mode 100644 backend/scripts/cap_basic_plan_usage.py create mode 100644 backend/scripts/reset_basic_plan_usage.py create mode 100644 backend/scripts/update_basic_plan_limits.py create mode 100644 backend/services/scheduler/__init__.py create mode 100644 backend/services/scheduler/core/__init__.py create mode 100644 backend/services/scheduler/core/exception_handler.py create mode 100644 backend/services/scheduler/core/executor_interface.py create mode 100644 backend/services/scheduler/core/scheduler.py create mode 100644 backend/services/scheduler/core/task_registry.py create mode 100644 backend/services/scheduler/executors/__init__.py create mode 100644 backend/services/scheduler/executors/monitoring_task_executor.py create mode 100644 backend/services/scheduler/utils/__init__.py create mode 100644 backend/services/scheduler/utils/frequency_calculator.py create mode 100644 backend/services/scheduler/utils/task_loader.py create mode 100644 backend/services/subscription/preflight_validator.py create mode 100644 frontend/src/components/BlogWriter/BlogWriterUtils/WixConnectModal.tsx create mode 100644 frontend/src/utils/wixTokenUtils.ts diff --git a/backend/api/blog_writer/router.py b/backend/api/blog_writer/router.py index b6093538..0d3b60fe 100644 --- a/backend/api/blog_writer/router.py +++ b/backend/api/blog_writer/router.py @@ -5,10 +5,11 @@ Main router for blog writing operations including research, outline generation, content creation, SEO analysis, and publishing. """ -from fastapi import APIRouter, HTTPException -from typing import Any, Dict, List +from fastapi import APIRouter, HTTPException, Depends +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from loguru import logger +from middleware.auth_middleware import get_current_user from models.blog_models import ( BlogResearchRequest, @@ -64,10 +65,21 @@ class SEOApplyRecommendationsRequest(BaseModel): @router.post("/seo/apply-recommendations") -async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]: +async def apply_seo_recommendations( + request: SEOApplyRecommendationsRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: """Apply actionable SEO recommendations and return updated content.""" try: - result = await recommendation_applier.apply_recommendations(request.dict()) + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + + result = await recommendation_applier.apply_recommendations(request.dict(), user_id=user_id) if not result.get("success"): raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations")) return result @@ -87,13 +99,24 @@ async def health() -> Dict[str, Any]: # Research Endpoints @router.post("/research/start") -async def start_research(request: BlogResearchRequest) -> Dict[str, Any]: +async def start_research( + request: BlogResearchRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: """Start a research operation and return a task ID for polling.""" try: - # TODO: Get user_id from authentication context - user_id = "anonymous" # This should come from auth middleware + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + task_id = await task_manager.start_research_task(request, user_id) return {"task_id": task_id, "status": "started"} + except HTTPException: + raise except Exception as e: logger.error(f"Failed to start research: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -107,6 +130,50 @@ async def get_research_status(task_id: str) -> Dict[str, Any]: if status is None: raise HTTPException(status_code=404, detail="Task not found") + # If task failed with subscription error, return HTTP error so frontend interceptor can catch it + if status.get('status') == 'failed' and status.get('error_status') in [429, 402]: + error_data = status.get('error_data', {}) or {} + error_status = status.get('error_status', 429) + + if not isinstance(error_data, dict): + logger.warning(f"Research task {task_id} error_data not dict: {error_data}") + error_data = {'error': str(error_data)} + + # Determine provider and usage info + stored_error_message = status.get('error', error_data.get('error')) + provider = error_data.get('provider', 'unknown') + usage_info = error_data.get('usage_info') + + if not usage_info: + usage_info = { + 'provider': provider, + 'message': stored_error_message, + 'error_type': error_data.get('error_type', 'unknown') + } + # Include any known fields from error_data + for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']: + if key in error_data: + usage_info[key] = error_data[key] + + # Build error message for detail + error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded') + + # Log the subscription error with all context + logger.warning(f"Research task {task_id} failed with subscription error {error_status}: {error_msg}") + logger.warning(f" Provider: {provider}, Usage Info: {usage_info}") + + # Use JSONResponse to ensure detail is returned as-is, not wrapped in an array + from fastapi.responses import JSONResponse + return JSONResponse( + status_code=error_status, + content={ + 'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'), + 'message': error_msg, + 'provider': provider, + 'usage_info': usage_info + } + ) + logger.info(f"Research status request for {task_id}: {status['status']} with {len(status.get('progress_messages', []))} progress messages") return status except HTTPException: @@ -310,20 +377,46 @@ async def hallucination_check(request: HallucinationCheckRequest) -> Hallucinati # SEO Endpoints @router.post("/seo/analyze", response_model=BlogSEOAnalyzeResponse) -async def seo_analyze(request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse: +async def seo_analyze( + request: BlogSEOAnalyzeRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> BlogSEOAnalyzeResponse: """Analyze content for SEO optimization opportunities.""" try: - return await service.seo_analyze(request) + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + + return await service.seo_analyze(request, user_id=user_id) + except HTTPException: + raise except Exception as e: logger.error(f"Failed to perform SEO analysis: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/seo/metadata", response_model=BlogSEOMetadataResponse) -async def seo_metadata(request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse: +async def seo_metadata( + request: BlogSEOMetadataRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> BlogSEOMetadataResponse: """Generate SEO metadata for the blog post.""" try: - return await service.seo_metadata(request) + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + + return await service.seo_metadata(request, user_id=user_id) + except HTTPException: + raise except Exception as e: logger.error(f"Failed to generate SEO metadata: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/api/blog_writer/task_manager.py b/backend/api/blog_writer/task_manager.py index 9c393367..acbc15f0 100644 --- a/backend/api/blog_writer/task_manager.py +++ b/backend/api/blog_writer/task_manager.py @@ -10,6 +10,7 @@ import asyncio import uuid from datetime import datetime from typing import Any, Dict, List +from fastapi import HTTPException from loguru import logger from models.blog_models import ( @@ -85,6 +86,10 @@ class TaskManager: response["result"] = task["result"] elif task["status"] == "failed": response["error"] = task["error"] + if "error_status" in task: + response["error_status"] = task["error_status"] + if "error_data" in task: + response["error_data"] = task["error_data"] return response @@ -109,14 +114,17 @@ class TaskManager: logger.info(f"Progress update for task {task_id}: {message}") - async def start_research_task(self, request: BlogResearchRequest, user_id: str = "anonymous") -> str: + async def start_research_task(self, request: BlogResearchRequest, user_id: str) -> str: """Start a research operation and return a task ID.""" if self.use_database: return await self.db_manager.start_research_task(request, user_id) else: task_id = self.create_task("research") + # Store user_id in task for subscription checks + if task_id in self.task_storage: + self.task_storage[task_id]["user_id"] = user_id # Start the research operation in the background - asyncio.create_task(self._run_research_task(task_id, request)) + asyncio.create_task(self._run_research_task(task_id, request, user_id)) return task_id def start_outline_task(self, request: BlogOutlineRequest) -> str: @@ -144,7 +152,7 @@ class TaskManager: asyncio.create_task(self._run_medium_generation_task(task_id, request)) return task_id - async def _run_research_task(self, task_id: str, request: BlogResearchRequest): + async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str): """Background task to run research and update status with progress messages.""" try: # Update status to running @@ -157,8 +165,8 @@ class TaskManager: # Check cache first await self.update_progress(task_id, "📋 Checking cache for existing research...") - # Run the actual research with progress updates - result = await self.service.research_with_progress(request, task_id) + # Run the actual research with progress updates (pass user_id for subscription checks) + result = await self.service.research_with_progress(request, task_id, user_id) # Check if research failed gracefully if not result.success: @@ -171,6 +179,16 @@ class TaskManager: self.task_storage[task_id]["status"] = "completed" self.task_storage[task_id]["result"] = result.dict() + except HTTPException as http_error: + # Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend + error_detail = http_error.detail + error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail) + await self.update_progress(task_id, f"❌ {error_message}") + self.task_storage[task_id]["status"] = "failed" + self.task_storage[task_id]["error"] = error_message + # Store HTTP error details for frontend modal + self.task_storage[task_id]["error_status"] = http_error.status_code + self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)} except Exception as e: await self.update_progress(task_id, f"❌ Research failed with error: {str(e)}") # Update status to failed diff --git a/backend/api/content_planning/monitoring_routes.py b/backend/api/content_planning/monitoring_routes.py index 9495f9bd..956bda27 100644 --- a/backend/api/content_planning/monitoring_routes.py +++ b/backend/api/content_planning/monitoring_routes.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, HTTPException, Depends, Query, Body -from typing import Dict, Any +from typing import Dict, Any, Optional import logging from datetime import datetime, timedelta from sqlalchemy.orm import Session @@ -64,6 +64,15 @@ async def activate_strategy_with_monitoring( if not monitoring_success: logger.warning(f"Failed to save monitoring data for strategy {strategy_id}") + # Trigger scheduler interval adjustment (scheduler will check more frequently now) + try: + from services.scheduler import get_scheduler + scheduler = get_scheduler() + await scheduler.trigger_interval_adjustment() + logger.info(f"Triggered scheduler interval adjustment after strategy {strategy_id} activation") + except Exception as e: + logger.warning(f"Could not trigger scheduler interval adjustment: {e}") + logger.info(f"Successfully activated strategy {strategy_id} with monitoring") return { "success": True, @@ -396,6 +405,150 @@ async def get_monitoring_tasks( logger.error(f"Error retrieving monitoring tasks: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") +@router.get("/user/{user_id}/monitoring-tasks") +async def get_user_monitoring_tasks( + user_id: int, + db: Session = Depends(get_db), + status: Optional[str] = Query(None, description="Filter by task status"), + limit: int = Query(50, description="Maximum number of tasks to return"), + offset: int = Query(0, description="Number of tasks to skip") +): + """ + Get all monitoring tasks for a specific user with their execution status. + + Uses the scheduler's task loader to get tasks filtered by user_id for proper user isolation. + """ + try: + logger.info(f"Getting monitoring tasks for user {user_id}") + + # Use scheduler task loader for user-specific tasks + from services.scheduler.utils.task_loader import load_due_monitoring_tasks + + # Load all tasks for user (not just due tasks - we want all user tasks) + # Join with strategy to filter by user + tasks_query = db.query(MonitoringTask).join( + EnhancedContentStrategy, + MonitoringTask.strategy_id == EnhancedContentStrategy.id + ).filter( + EnhancedContentStrategy.user_id == user_id + ) + + # Apply status filter if provided + if status: + tasks_query = tasks_query.filter(MonitoringTask.status == status) + + # Get tasks with pagination + tasks = tasks_query.order_by(desc(MonitoringTask.created_at)).offset(offset).limit(limit).all() + + tasks_data = [] + for task in tasks: + # Get latest execution log + latest_log = db.query(TaskExecutionLog).filter( + TaskExecutionLog.task_id == task.id + ).order_by(desc(TaskExecutionLog.execution_date)).first() + + # Get strategy info + strategy = db.query(EnhancedContentStrategy).filter( + EnhancedContentStrategy.id == task.strategy_id + ).first() + + task_data = { + "id": task.id, + "strategy_id": task.strategy_id, + "strategy_name": strategy.name if strategy else None, + "title": task.task_title, + "description": task.task_description, + "assignee": task.assignee, + "frequency": task.frequency, + "metric": task.metric, + "measurementMethod": task.measurement_method, + "successCriteria": task.success_criteria, + "alertThreshold": task.alert_threshold, + "status": task.status, + "lastExecuted": latest_log.execution_date.isoformat() if latest_log else None, + "nextExecution": task.next_execution.isoformat() if task.next_execution else None, + "executionCount": db.query(TaskExecutionLog).filter( + TaskExecutionLog.task_id == task.id + ).count(), + "created_at": task.created_at.isoformat() if task.created_at else None + } + tasks_data.append(task_data) + + # Get total count for pagination + total_count = db.query(MonitoringTask).join( + EnhancedContentStrategy, + MonitoringTask.strategy_id == EnhancedContentStrategy.id + ).filter( + EnhancedContentStrategy.user_id == user_id + ) + if status: + total_count = total_count.filter(MonitoringTask.status == status) + total_count = total_count.count() + + return { + "success": True, + "data": tasks_data, + "pagination": { + "total": total_count, + "limit": limit, + "offset": offset, + "has_more": (offset + len(tasks_data)) < total_count + }, + "message": f"Retrieved {len(tasks_data)} monitoring tasks for user {user_id}" + } + + except Exception as e: + logger.error(f"Error retrieving user monitoring tasks: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to retrieve monitoring tasks: {str(e)}") + +@router.get("/user/{user_id}/execution-logs") +async def get_user_execution_logs( + user_id: int, + db: Session = Depends(get_db), + status: Optional[str] = Query(None, description="Filter by execution status"), + limit: int = Query(50, description="Maximum number of logs to return"), + offset: int = Query(0, description="Number of logs to skip") +): + """ + Get execution logs for a specific user. + + Provides user isolation by filtering execution logs by user_id. + """ + try: + logger.info(f"Getting execution logs for user {user_id}") + + monitoring_service = MonitoringDataService(db) + logs_data = monitoring_service.get_user_execution_logs( + user_id=user_id, + limit=limit, + offset=offset, + status_filter=status + ) + + # Get total count for pagination + count_query = db.query(TaskExecutionLog).filter( + TaskExecutionLog.user_id == user_id + ) + if status: + count_query = count_query.filter(TaskExecutionLog.status == status) + total_count = count_query.count() + + return { + "success": True, + "data": logs_data, + "pagination": { + "total": total_count, + "limit": limit, + "offset": offset, + "has_more": (offset + len(logs_data)) < total_count + }, + "message": f"Retrieved {len(logs_data)} execution logs for user {user_id}" + } + + except Exception as e: + logger.error(f"Error retrieving execution logs for user {user_id}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to retrieve execution logs: {str(e)}") + @router.get("/{strategy_id}/data-freshness") async def get_data_freshness( strategy_id: int, diff --git a/backend/api/images.py b/backend/api/images.py index 36c9be01..d6095ec3 100644 --- a/backend/api/images.py +++ b/backend/api/images.py @@ -3,13 +3,18 @@ from __future__ import annotations import base64 import os from typing import Optional, Dict, Any +from datetime import datetime -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends from pydantic import BaseModel, Field from services.llm_providers.main_image_generation import generate_image from services.llm_providers.main_text_generation import llm_text_gen from utils.logger_utils import get_service_logger +from middleware.auth_middleware import get_current_user +from services.database import get_db +from services.subscription import UsageTrackingService, PricingService +from models.subscription_models import APIProvider, UsageSummary router = APIRouter(prefix="/api/images", tags=["images"]) @@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel): @router.post("/generate", response_model=ImageGenerateResponse) -def generate(req: ImageGenerateRequest) -> ImageGenerateResponse: +def generate( + req: ImageGenerateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> ImageGenerateResponse: + """Generate image with subscription checking.""" try: + # Extract Clerk user ID (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get('id', '')) + if not user_id: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + + # Validation is now handled inside generate_image function last_error: Optional[Exception] = None + result = None for attempt in range(2): # simple single retry try: result = generate_image( @@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse: "steps": req.steps, "seed": req.seed, }, + user_id=user_id, # Pass user_id for validation inside generate_image ) image_b64 = base64.b64encode(result.image_bytes).decode("utf-8") + + # TRACK USAGE after successful image generation + if result: + logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}") + try: + db_track = next(get_db()) + try: + # Get or create usage summary + pricing = PricingService(db_track) + current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + + logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}") + + summary = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if not summary: + logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}") + summary = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + db_track.add(summary) + db_track.flush() # Ensure summary is persisted before updating + + # Get "before" state for unified log + current_calls_before = getattr(summary, "stability_calls", 0) or 0 + + # Update provider-specific counters (stability for image generation) + # Note: All image generation goes through STABILITY provider enum regardless of actual provider + new_calls = current_calls_before + 1 + setattr(summary, "stability_calls", new_calls) + logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}") + + # Update totals + old_total_calls = summary.total_calls or 0 + summary.total_calls = old_total_calls + 1 + logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}") + + # Get plan details for unified log + limits = pricing.get_user_limits(user_id) + plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' + tier = limits.get('tier', 'unknown') if limits else 'unknown' + call_limit = limits['limits'].get("stability_calls", 0) if limits else 0 + + db_track.commit() + logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls") + + # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message + print(f""" +[SUBSCRIPTION] Image Generation +├─ User: {user_id} +├─ Plan: {plan_name} ({tier}) +├─ Provider: stability +├─ Actual Provider: {result.provider} +├─ Model: {result.model or 'default'} +├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'} +└─ Status: ✅ Allowed & Tracked +""") + except Exception as track_error: + logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True) + db_track.rollback() + finally: + db_track.close() + except Exception as usage_error: + # Non-blocking: log error but don't fail the request + logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True) + return ImageGenerateResponse( image_base64=image_b64, width=result.width, @@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel): @router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse) -def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse: +def suggest_prompts( + req: ImagePromptSuggestRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> ImagePromptSuggestResponse: try: provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower() section = req.section or {} @@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons If including on-image text, return it in overlay_text (short: <= 8 words). """ - raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema) + # Get user_id for llm_text_gen subscription check (required) + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id_for_llm = str(current_user.get('id', '')) + if not user_id_for_llm: + raise HTTPException(status_code=401, detail="Invalid user ID in authentication token") + + raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm) data = raw if isinstance(raw, dict) else {} suggestions = data.get("suggestions") or [] # basic fallback if provider returns string diff --git a/backend/api/subscription_api.py b/backend/api/subscription_api.py index 2cd2b704..074a77d6 100644 --- a/backend/api/subscription_api.py +++ b/backend/api/subscription_api.py @@ -94,6 +94,7 @@ async def get_subscription_plans( "description": plan.description, "features": plan.features or [], "limits": { + "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0, "gemini_calls": plan.gemini_calls_limit, "openai_calls": plan.openai_calls_limit, "anthropic_calls": plan.anthropic_calls_limit, @@ -162,6 +163,7 @@ async def get_user_subscription( }, "status": "free", "limits": { + "ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0, "gemini_calls": free_plan.gemini_calls_limit, "openai_calls": free_plan.openai_calls_limit, "anthropic_calls": free_plan.anthropic_calls_limit, @@ -200,6 +202,7 @@ async def get_user_subscription( "is_free": False }, "limits": { + "ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0, "gemini_calls": subscription.plan.gemini_calls_limit, "openai_calls": subscription.plan.openai_calls_limit, "anthropic_calls": subscription.plan.anthropic_calls_limit, @@ -252,6 +255,7 @@ async def get_subscription_status( "tier": "free", "can_use_api": True, "limits": { + "ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0, "gemini_calls": free_plan.gemini_calls_limit, "openai_calls": free_plan.openai_calls_limit, "anthropic_calls": free_plan.anthropic_calls_limit, @@ -309,6 +313,7 @@ async def get_subscription_status( "tier": subscription.plan.tier.value, "can_use_api": True, "limits": { + "ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0, "gemini_calls": subscription.plan.gemini_calls_limit, "openai_calls": subscription.plan.openai_calls_limit, "anthropic_calls": subscription.plan.anthropic_calls_limit, @@ -331,9 +336,14 @@ async def get_subscription_status( async def subscribe_to_plan( user_id: str, subscription_data: dict, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: Dict[str, Any] = Depends(get_current_user) ) -> Dict[str, Any]: - """Create or update a user's subscription.""" + """Create or update a user's subscription (renewal).""" + + # Verify user can only subscribe/renew their own subscription + if current_user.get('id') != user_id: + raise HTTPException(status_code=403, detail="Access denied") try: plan_id = subscription_data.get('plan_id') @@ -388,12 +398,75 @@ async def subscribe_to_plan( db.commit() + # Get current usage BEFORE reset for logging + current_period = datetime.utcnow().strftime("%Y-%m") + usage_before = db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + # Log renewal request details + logger.info("=" * 80) + logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request") + logger.info(f" ├─ User: {user_id}") + logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})") + logger.info(f" ├─ Billing Cycle: {billing_cycle}") + logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}") + logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}") + + if usage_before: + logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):") + logger.info(f" ├─ Gemini: {usage_before.gemini_tokens or 0} tokens / {usage_before.gemini_calls or 0} calls") + logger.info(f" ├─ Mistral/HF: {usage_before.mistral_tokens or 0} tokens / {usage_before.mistral_calls or 0} calls") + logger.info(f" ├─ OpenAI: {usage_before.openai_tokens or 0} tokens / {usage_before.openai_calls or 0} calls") + logger.info(f" ├─ Stability (Images): {usage_before.stability_calls or 0} calls") + logger.info(f" ├─ Total Tokens: {usage_before.total_tokens or 0}") + logger.info(f" ├─ Total Calls: {usage_before.total_calls or 0}") + logger.info(f" └─ Usage Status: {usage_before.usage_status.value}") + else: + logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)") + + # Clear subscription limits cache to force refresh on next check + try: + from services.subscription import PricingService + # Clear cache for this specific user (class-level cache shared across all instances) + cleared_count = PricingService.clear_user_cache(user_id) + logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}") + except Exception as cache_err: + logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}") + # Reset usage status for current billing period so new plan takes effect immediately + reset_result = None try: usage_service = UsageTrackingService(db) - await usage_service.reset_current_billing_period(user_id) + reset_result = await usage_service.reset_current_billing_period(user_id) + + # Re-query usage summary from DB after reset to get fresh data + usage_after = db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if reset_result.get('reset'): + logger.info(f" ✅ Usage counters RESET successfully") + if usage_after: + logger.info(f" 📊 New Usage AFTER Reset:") + logger.info(f" ├─ Gemini: {usage_after.gemini_tokens or 0} tokens / {usage_after.gemini_calls or 0} calls") + logger.info(f" ├─ Mistral/HF: {usage_after.mistral_tokens or 0} tokens / {usage_after.mistral_calls or 0} calls") + logger.info(f" ├─ OpenAI: {usage_after.openai_tokens or 0} tokens / {usage_after.openai_calls or 0} calls") + logger.info(f" ├─ Stability (Images): {usage_after.stability_calls or 0} calls") + logger.info(f" ├─ Total Tokens: {usage_after.total_tokens or 0}") + logger.info(f" ├─ Total Calls: {usage_after.total_calls or 0}") + logger.info(f" └─ Usage Status: {usage_after.usage_status.value}") + else: + logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call") + else: + logger.warning(f" ⚠️ Reset returned: {reset_result.get('reason', 'unknown')}") except Exception as reset_err: - logger.error(f"Failed to reset usage after subscribe: {reset_err}") + logger.error(f" ❌ Failed to reset usage after subscribe: {reset_err}", exc_info=True) + + logger.info(f" ✅ Renewal completed: User {user_id} → {plan.name} ({billing_cycle})") + logger.info("=" * 80) return { "success": True, @@ -404,7 +477,20 @@ async def subscribe_to_plan( "billing_cycle": billing_cycle, "current_period_start": subscription.current_period_start.isoformat(), "current_period_end": subscription.current_period_end.isoformat(), - "status": subscription.status.value + "status": subscription.status.value, + "limits": { + "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0, + "gemini_calls": plan.gemini_calls_limit, + "openai_calls": plan.openai_calls_limit, + "anthropic_calls": plan.anthropic_calls_limit, + "mistral_calls": plan.mistral_calls_limit, + "tavily_calls": plan.tavily_calls_limit, + "serper_calls": plan.serper_calls_limit, + "metaphor_calls": plan.metaphor_calls_limit, + "firecrawl_calls": plan.firecrawl_calls_limit, + "stability_calls": plan.stability_calls_limit, + "monthly_cost": plan.monthly_cost_limit + } } } diff --git a/backend/api/wix_routes.py b/backend/api/wix_routes.py index 364cfd53..9e35f25d 100644 --- a/backend/api/wix_routes.py +++ b/backend/api/wix_routes.py @@ -477,6 +477,39 @@ async def test_publish_to_wix(request: WixPublishRequest) -> Dict[str, Any]: raise HTTPException(status_code=500, detail=str(e)) +@router.post("/refresh-token") +async def refresh_wix_token(request: Dict[str, Any]) -> Dict[str, Any]: + """ + Refresh Wix access token using refresh token + + Args: + request: Dict containing refresh_token + + Returns: + New token information with access_token, refresh_token, expires_in + """ + try: + refresh_token = request.get("refresh_token") + if not refresh_token: + raise HTTPException(status_code=400, detail="Missing refresh_token") + + # Refresh the token + new_tokens = wix_service.refresh_access_token(refresh_token) + + return { + "success": True, + "access_token": new_tokens.get("access_token"), + "refresh_token": new_tokens.get("refresh_token"), + "expires_in": new_tokens.get("expires_in"), + "token_type": new_tokens.get("token_type", "Bearer") + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Wix token: {e}") + raise HTTPException(status_code=500, detail=f"Failed to refresh token: {str(e)}") + + @router.post("/test/publish/real") async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/backend/app.py b/backend/app.py index a3294a9a..e4a9cdca 100644 --- a/backend/app.py +++ b/backend/app.py @@ -298,6 +298,11 @@ async def startup_event(): try: # Initialize database init_database() + + # Start task scheduler + from services.scheduler import get_scheduler + await get_scheduler().start() + logger.info("ALwrity backend started successfully") except Exception as e: logger.error(f"Error during startup: {e}") @@ -307,6 +312,10 @@ async def startup_event(): async def shutdown_event(): """Cleanup on shutdown.""" try: + # Stop task scheduler + from services.scheduler import get_scheduler + await get_scheduler().stop() + # Close database connections close_database() logger.info("ALwrity backend shutdown successfully") diff --git a/backend/database/migrations/add_user_id_to_task_execution_logs.sql b/backend/database/migrations/add_user_id_to_task_execution_logs.sql new file mode 100644 index 00000000..3f2645ee --- /dev/null +++ b/backend/database/migrations/add_user_id_to_task_execution_logs.sql @@ -0,0 +1,20 @@ +-- Migration: Add user_id column to task_execution_logs for user isolation +-- Date: 2025-01-XX +-- Purpose: Enable user isolation tracking in scheduler task execution logs + +-- Add user_id column (nullable for backward compatibility with existing records) +ALTER TABLE task_execution_logs +ADD COLUMN user_id INTEGER NULL; + +-- Create index for efficient user filtering and queries +CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_id +ON task_execution_logs(user_id); + +-- Create composite index for common query patterns (user_id + status + execution_date) +CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_status_date +ON task_execution_logs(user_id, status, execution_date); + +-- Note: Backfilling existing records would require joining with monitoring_tasks +-- and enhanced_content_strategies tables. This can be done in a separate migration +-- or during a maintenance window. For now, existing records will have user_id = NULL. + diff --git a/backend/models/linkedin_models.py b/backend/models/linkedin_models.py index bd6ce5bc..544ce037 100644 --- a/backend/models/linkedin_models.py +++ b/backend/models/linkedin_models.py @@ -65,7 +65,7 @@ class LinkedInPostRequest(BaseModel): persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving") class Config: - schema_extra = { + json_schema_extra = { "example": { "topic": "AI in healthcare transformation", "industry": "Healthcare", @@ -102,7 +102,7 @@ class LinkedInArticleRequest(BaseModel): persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving") class Config: - schema_extra = { + json_schema_extra = { "example": { "topic": "Digital transformation in manufacturing", "industry": "Manufacturing", @@ -135,7 +135,7 @@ class LinkedInCarouselRequest(BaseModel): include_citations: bool = Field(default=True, description="Whether to include inline citations") class Config: - schema_extra = { + json_schema_extra = { "example": { "topic": "Future of remote work", "industry": "Technology", @@ -167,7 +167,7 @@ class LinkedInVideoScriptRequest(BaseModel): include_citations: bool = Field(default=True, description="Whether to include inline citations") class Config: - schema_extra = { + json_schema_extra = { "example": { "topic": "Cybersecurity best practices", "industry": "Technology", @@ -197,7 +197,7 @@ class LinkedInCommentResponseRequest(BaseModel): grounding_level: GroundingLevel = Field(default=GroundingLevel.BASIC, description="Level of content grounding") class Config: - schema_extra = { + json_schema_extra = { "example": { "original_comment": "Great insights on AI implementation!", "post_context": "Post about AI transformation in healthcare", @@ -353,7 +353,7 @@ class LinkedInPostResponse(BaseModel): grounding_status: Optional[Dict[str, Any]] = Field(None, description="Grounding operation status") class Config: - schema_extra = { + json_schema_extra = { "example": { "success": True, "data": { diff --git a/backend/models/monitoring_models.py b/backend/models/monitoring_models.py index 884ca7e3..7f992f08 100644 --- a/backend/models/monitoring_models.py +++ b/backend/models/monitoring_models.py @@ -48,8 +48,9 @@ class TaskExecutionLog(Base): id = Column(Integer, primary_key=True, index=True) task_id = Column(Integer, ForeignKey("monitoring_tasks.id"), nullable=False) + user_id = Column(Integer, nullable=True) # User ID for user isolation (nullable for backward compatibility) execution_date = Column(DateTime, default=datetime.utcnow) - status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped' + status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped', 'running' result_data = Column(JSON, nullable=True) error_message = Column(Text, nullable=True) execution_time_ms = Column(Integer, nullable=True) diff --git a/backend/models/subscription_models.py b/backend/models/subscription_models.py index db184e67..0e770ca3 100644 --- a/backend/models/subscription_models.py +++ b/backend/models/subscription_models.py @@ -50,16 +50,22 @@ class SubscriptionPlan(Base): price_monthly = Column(Float, nullable=False, default=0.0) price_yearly = Column(Float, nullable=False, default=0.0) - # API Call Limits - gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited - openai_calls_limit = Column(Integer, default=0) - anthropic_calls_limit = Column(Integer, default=0) - mistral_calls_limit = Column(Integer, default=0) + # Unified AI Text Generation Call Limit (applies to all LLM providers: gemini, openai, anthropic, mistral) + # Note: This column may not exist in older databases - use getattr() when accessing + ai_text_generation_calls_limit = Column(Integer, default=0, nullable=True) # 0 = unlimited, None if column doesn't exist + + # Legacy per-provider limits (kept for backwards compatibility and analytics) + gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited (deprecated, use ai_text_generation_calls_limit) + openai_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit) + anthropic_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit) + mistral_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit) + + # Other API Call Limits (non-LLM) tavily_calls_limit = Column(Integer, default=0) serper_calls_limit = Column(Integer, default=0) metaphor_calls_limit = Column(Integer, default=0) firecrawl_calls_limit = Column(Integer, default=0) - stability_calls_limit = Column(Integer, default=0) + stability_calls_limit = Column(Integer, default=0) # Image generation # Token Limits (for LLM providers) gemini_tokens_limit = Column(Integer, default=0) diff --git a/backend/requirements.txt b/backend/requirements.txt index bc3a9844..edfa663f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -63,6 +63,9 @@ pytest-asyncio>=0.21.0 pydantic>=2.5.2,<3.0.0 typing-extensions>=4.8.0 +# Task scheduling +apscheduler>=3.10.0 + # Optional dependencies (for enhanced features) redis>=5.0.0 schedule>=1.2.0 \ No newline at end of file diff --git a/backend/scripts/add_ai_text_generation_limit_column.py b/backend/scripts/add_ai_text_generation_limit_column.py new file mode 100644 index 00000000..2255038a --- /dev/null +++ b/backend/scripts/add_ai_text_generation_limit_column.py @@ -0,0 +1,146 @@ +""" +Migration Script: Add ai_text_generation_calls_limit column to subscription_plans table. + +This adds the unified AI text generation limit column that applies to all LLM providers +(gemini, openai, anthropic, mistral) instead of per-provider limits. +""" + +import sys +import os +from pathlib import Path +from datetime import datetime, timezone + +# Add the backend directory to Python path +backend_dir = Path(__file__).parent.parent +sys.path.insert(0, str(backend_dir)) + +from sqlalchemy import create_engine, text, inspect +from sqlalchemy.orm import sessionmaker +from loguru import logger + +from models.subscription_models import SubscriptionPlan, SubscriptionTier +from services.database import DATABASE_URL + +def add_ai_text_generation_limit_column(): + """Add ai_text_generation_calls_limit column to subscription_plans table.""" + + try: + engine = create_engine(DATABASE_URL, echo=False) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = SessionLocal() + + try: + # Check if column already exists + inspector = inspect(engine) + columns = [col['name'] for col in inspector.get_columns('subscription_plans')] + + if 'ai_text_generation_calls_limit' in columns: + logger.info("✅ Column 'ai_text_generation_calls_limit' already exists. Skipping migration.") + return True + + logger.info("📋 Adding 'ai_text_generation_calls_limit' column to subscription_plans table...") + + # Add the column (SQLite compatible) + alter_query = text(""" + ALTER TABLE subscription_plans + ADD COLUMN ai_text_generation_calls_limit INTEGER DEFAULT 0 + """) + + db.execute(alter_query) + db.commit() + + logger.info("✅ Column added successfully!") + + # Update existing plans with unified limits based on their current limits + logger.info("\n🔄 Updating existing subscription plans with unified limits...") + + plans = db.query(SubscriptionPlan).all() + updated_count = 0 + + for plan in plans: + # Use the first non-zero LLM provider limit as the unified limit + # Or use gemini_calls_limit as default + unified_limit = ( + plan.ai_text_generation_calls_limit or + plan.gemini_calls_limit or + plan.openai_calls_limit or + plan.anthropic_calls_limit or + plan.mistral_calls_limit or + 0 + ) + + # For Basic plan, ensure it's set to 10 (from our recent update) + if plan.tier == SubscriptionTier.BASIC: + unified_limit = 10 + + if plan.ai_text_generation_calls_limit != unified_limit: + plan.ai_text_generation_calls_limit = unified_limit + plan.updated_at = datetime.now(timezone.utc) + updated_count += 1 + + logger.info(f" ✅ Updated {plan.name} ({plan.tier.value}): ai_text_generation_calls_limit = {unified_limit}") + else: + logger.info(f" ℹ️ {plan.name} ({plan.tier.value}): already set to {unified_limit}") + + if updated_count > 0: + db.commit() + logger.info(f"\n✅ Updated {updated_count} subscription plan(s)") + else: + logger.info("\nℹ️ No plans needed updating") + + # Display summary + logger.info("\n" + "="*60) + logger.info("MIGRATION SUMMARY") + logger.info("="*60) + + all_plans = db.query(SubscriptionPlan).all() + for plan in all_plans: + logger.info(f"\n{plan.name} ({plan.tier.value}):") + logger.info(f" Unified AI Text Gen Limit: {plan.ai_text_generation_calls_limit if plan.ai_text_generation_calls_limit else 'Not set'}") + logger.info(f" Legacy Limits: gemini={plan.gemini_calls_limit}, mistral={plan.mistral_calls_limit}") + + logger.info("\n" + "="*60) + logger.info("✅ Migration completed successfully!") + + return True + + except Exception as e: + db.rollback() + logger.error(f"❌ Error during migration: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + finally: + db.close() + + except Exception as e: + logger.error(f"❌ Failed to connect to database: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + +if __name__ == "__main__": + logger.info("🚀 Starting ai_text_generation_calls_limit column migration...") + logger.info("="*60) + logger.info("This will add the unified AI text generation limit column") + logger.info("and update existing plans with appropriate values.") + logger.info("="*60) + + try: + success = add_ai_text_generation_limit_column() + + if success: + logger.info("\n✅ Script completed successfully!") + sys.exit(0) + else: + logger.error("\n❌ Script failed!") + sys.exit(1) + + except KeyboardInterrupt: + logger.info("\n⚠️ Script cancelled by user") + sys.exit(1) + except Exception as e: + logger.error(f"\n❌ Unexpected error: {e}") + sys.exit(1) + diff --git a/backend/scripts/cap_basic_plan_usage.py b/backend/scripts/cap_basic_plan_usage.py new file mode 100644 index 00000000..1baa5769 --- /dev/null +++ b/backend/scripts/cap_basic_plan_usage.py @@ -0,0 +1,210 @@ +""" +Standalone script to cap usage counters at new Basic plan limits. + +This preserves historical usage data but caps it at the new limits so users +can continue making new calls within their limits. +""" + +import sys +import os +from pathlib import Path +from datetime import datetime, timezone + +# Add the backend directory to Python path +backend_dir = Path(__file__).parent.parent +sys.path.insert(0, str(backend_dir)) + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from loguru import logger + +from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus +from services.database import DATABASE_URL +from services.subscription import PricingService + +def cap_basic_plan_usage(): + """Cap usage counters at new Basic plan limits.""" + + try: + engine = create_engine(DATABASE_URL, echo=False) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = SessionLocal() + + try: + # Find Basic plan + basic_plan = db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.BASIC + ).first() + + if not basic_plan: + logger.error("❌ Basic plan not found in database!") + return False + + # New limits + new_call_limit = basic_plan.gemini_calls_limit # Should be 10 + new_token_limit = basic_plan.gemini_tokens_limit # Should be 2000 + new_image_limit = basic_plan.stability_calls_limit # Should be 5 + + logger.info(f"📋 Basic Plan Limits:") + logger.info(f" Calls: {new_call_limit}") + logger.info(f" Tokens: {new_token_limit}") + logger.info(f" Images: {new_image_limit}") + + # Get all Basic plan users + user_subscriptions = db.query(UserSubscription).filter( + UserSubscription.plan_id == basic_plan.id, + UserSubscription.is_active == True + ).all() + + logger.info(f"\n👥 Found {len(user_subscriptions)} Basic plan user(s)") + + pricing_service = PricingService(db) + capped_count = 0 + + for sub in user_subscriptions: + try: + # Get current billing period for this user + current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m") + + # Find usage summary for current period + usage_summary = db.query(UsageSummary).filter( + UsageSummary.user_id == sub.user_id, + UsageSummary.billing_period == current_period + ).first() + + if usage_summary: + # Store old values for logging + old_gemini = usage_summary.gemini_calls or 0 + old_mistral = usage_summary.mistral_calls or 0 + old_openai = usage_summary.openai_calls or 0 + old_anthropic = usage_summary.anthropic_calls or 0 + old_tokens = max( + usage_summary.gemini_tokens or 0, + usage_summary.openai_tokens or 0, + usage_summary.anthropic_tokens or 0, + usage_summary.mistral_tokens or 0 + ) + old_images = usage_summary.stability_calls or 0 + + # Check if capping is needed + needs_cap = ( + old_gemini > new_call_limit or + old_mistral > new_call_limit or + old_openai > new_call_limit or + old_anthropic > new_call_limit or + old_images > new_image_limit or + old_tokens > new_token_limit + ) + + if needs_cap: + # Cap LLM provider counters at new limits + usage_summary.gemini_calls = min(old_gemini, new_call_limit) + usage_summary.mistral_calls = min(old_mistral, new_call_limit) + usage_summary.openai_calls = min(old_openai, new_call_limit) + usage_summary.anthropic_calls = min(old_anthropic, new_call_limit) + + # Cap token counters at new limits + usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit) + usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit) + usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit) + usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit) + + # Cap image counter at new limit + usage_summary.stability_calls = min(old_images, new_image_limit) + + # Recalculate totals based on capped values + total_capped_calls = ( + usage_summary.gemini_calls + + usage_summary.mistral_calls + + usage_summary.openai_calls + + usage_summary.anthropic_calls + + usage_summary.stability_calls + ) + total_capped_tokens = ( + usage_summary.gemini_tokens + + usage_summary.mistral_tokens + + usage_summary.openai_tokens + + usage_summary.anthropic_tokens + ) + + usage_summary.total_calls = total_capped_calls + usage_summary.total_tokens = total_capped_tokens + + # Reset status to active to allow new calls + usage_summary.usage_status = UsageStatus.ACTIVE + usage_summary.updated_at = datetime.now(timezone.utc) + + db.commit() + capped_count += 1 + + logger.info(f"\n✅ Capped usage for user {sub.user_id} (period {current_period}):") + logger.info(f" Gemini Calls: {old_gemini} → {usage_summary.gemini_calls} (limit: {new_call_limit})") + logger.info(f" Mistral Calls: {old_mistral} → {usage_summary.mistral_calls} (limit: {new_call_limit})") + logger.info(f" OpenAI Calls: {old_openai} → {usage_summary.openai_calls} (limit: {new_call_limit})") + logger.info(f" Anthropic Calls: {old_anthropic} → {usage_summary.anthropic_calls} (limit: {new_call_limit})") + logger.info(f" Tokens: {old_tokens} → {max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})") + logger.info(f" Images: {old_images} → {usage_summary.stability_calls} (limit: {new_image_limit})") + else: + logger.info(f" ℹ️ User {sub.user_id} usage is within limits - no capping needed") + else: + logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period})") + + except Exception as cap_error: + logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}") + import traceback + logger.error(traceback.format_exc()) + db.rollback() + + if capped_count > 0: + logger.info(f"\n✅ Successfully capped usage for {capped_count} user(s)") + logger.info(" Historical usage preserved, but capped at new limits") + logger.info(" Users can now make new calls within their limits") + else: + logger.info("\nℹ️ No usage counters needed capping") + + logger.info("\n" + "="*60) + logger.info("CAPPING COMPLETE") + logger.info("="*60) + + return True + + except Exception as e: + db.rollback() + logger.error(f"❌ Error capping usage: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + finally: + db.close() + + except Exception as e: + logger.error(f"❌ Failed to connect to database: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + +if __name__ == "__main__": + logger.info("🚀 Starting Basic plan usage capping...") + logger.info("="*60) + logger.info("This will cap usage counters at new Basic plan limits") + logger.info("while preserving historical usage data.") + logger.info("="*60) + + try: + success = cap_basic_plan_usage() + + if success: + logger.info("\n✅ Script completed successfully!") + sys.exit(0) + else: + logger.error("\n❌ Script failed!") + sys.exit(1) + + except KeyboardInterrupt: + logger.info("\n⚠️ Script cancelled by user") + sys.exit(1) + except Exception as e: + logger.error(f"\n❌ Unexpected error: {e}") + sys.exit(1) + diff --git a/backend/scripts/reset_basic_plan_usage.py b/backend/scripts/reset_basic_plan_usage.py new file mode 100644 index 00000000..3ec0bc73 --- /dev/null +++ b/backend/scripts/reset_basic_plan_usage.py @@ -0,0 +1,168 @@ +""" +Quick script to reset usage counters for Basic plan users. + +This fixes the issue where plan limits were updated but old usage data remained. +Resets all usage counters (calls, tokens, images) to 0 for the current billing period. +""" + +import sys +import os +from pathlib import Path +from datetime import datetime, timezone + +# Add the backend directory to Python path +backend_dir = Path(__file__).parent.parent +sys.path.insert(0, str(backend_dir)) + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from loguru import logger + +from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus +from services.database import DATABASE_URL +from services.subscription import PricingService + +def reset_basic_plan_usage(): + """Reset usage counters for all Basic plan users.""" + + try: + engine = create_engine(DATABASE_URL, echo=False) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = SessionLocal() + + try: + # Find Basic plan + basic_plan = db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.BASIC + ).first() + + if not basic_plan: + logger.error("❌ Basic plan not found in database!") + return False + + # Get all Basic plan users + user_subscriptions = db.query(UserSubscription).filter( + UserSubscription.plan_id == basic_plan.id, + UserSubscription.is_active == True + ).all() + + logger.info(f"Found {len(user_subscriptions)} Basic plan user(s)") + + pricing_service = PricingService(db) + reset_count = 0 + + for sub in user_subscriptions: + try: + # Get current billing period for this user + current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m") + + # Find usage summary for current period + usage_summary = db.query(UsageSummary).filter( + UsageSummary.user_id == sub.user_id, + UsageSummary.billing_period == current_period + ).first() + + if usage_summary: + # Store old values for logging + old_gemini = usage_summary.gemini_calls or 0 + old_mistral = usage_summary.mistral_calls or 0 + old_tokens = (usage_summary.mistral_tokens or 0) + (usage_summary.gemini_tokens or 0) + old_images = usage_summary.stability_calls or 0 + old_total_calls = usage_summary.total_calls or 0 + old_total_tokens = usage_summary.total_tokens or 0 + + # Reset all LLM provider counters + usage_summary.gemini_calls = 0 + usage_summary.openai_calls = 0 + usage_summary.anthropic_calls = 0 + usage_summary.mistral_calls = 0 + + # Reset all token counters + usage_summary.gemini_tokens = 0 + usage_summary.openai_tokens = 0 + usage_summary.anthropic_tokens = 0 + usage_summary.mistral_tokens = 0 + + # Reset image counter + usage_summary.stability_calls = 0 + + # Reset totals + usage_summary.total_calls = 0 + usage_summary.total_tokens = 0 + usage_summary.total_cost = 0.0 + + # Reset status to active + usage_summary.usage_status = UsageStatus.ACTIVE + usage_summary.updated_at = datetime.now(timezone.utc) + + db.commit() + reset_count += 1 + + logger.info(f"\n✅ Reset usage for user {sub.user_id} (period {current_period}):") + logger.info(f" Calls: {old_gemini + old_mistral} (gemini: {old_gemini}, mistral: {old_mistral}) → 0") + logger.info(f" Tokens: {old_tokens} → 0") + logger.info(f" Images: {old_images} → 0") + logger.info(f" Total Calls: {old_total_calls} → 0") + logger.info(f" Total Tokens: {old_total_tokens} → 0") + else: + logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period}) - nothing to reset") + + except Exception as reset_error: + logger.error(f" ❌ Error resetting usage for user {sub.user_id}: {reset_error}") + import traceback + logger.error(traceback.format_exc()) + db.rollback() + + if reset_count > 0: + logger.info(f"\n✅ Successfully reset usage counters for {reset_count} user(s)") + else: + logger.info("\nℹ️ No usage counters to reset") + + logger.info("\n" + "="*60) + logger.info("RESET COMPLETE") + logger.info("="*60) + logger.info("\n💡 Usage counters have been reset. Users can now use their new limits.") + logger.info(" Next API call will start counting from 0.") + + return True + + except Exception as e: + db.rollback() + logger.error(f"❌ Error resetting usage: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + finally: + db.close() + + except Exception as e: + logger.error(f"❌ Failed to connect to database: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + +if __name__ == "__main__": + logger.info("🚀 Starting Basic plan usage counter reset...") + logger.info("="*60) + logger.info("This will reset all usage counters (calls, tokens, images) to 0") + logger.info("for all Basic plan users in their current billing period.") + logger.info("="*60) + + try: + success = reset_basic_plan_usage() + + if success: + logger.info("\n✅ Script completed successfully!") + sys.exit(0) + else: + logger.error("\n❌ Script failed!") + sys.exit(1) + + except KeyboardInterrupt: + logger.info("\n⚠️ Script cancelled by user") + sys.exit(1) + except Exception as e: + logger.error(f"\n❌ Unexpected error: {e}") + sys.exit(1) + diff --git a/backend/scripts/update_basic_plan_limits.py b/backend/scripts/update_basic_plan_limits.py new file mode 100644 index 00000000..e88350dc --- /dev/null +++ b/backend/scripts/update_basic_plan_limits.py @@ -0,0 +1,279 @@ +""" +Script to update Basic plan subscription limits for testing rate limits and renewal flows. + +Updates: +- LLM Calls (all providers): 10 calls (was 500-1000) +- LLM Tokens (all providers): 2000 tokens (was 200k-1M) +- Images: 5 images (was 50) + +This script updates the SubscriptionPlan table, which automatically applies to all users +who have a Basic plan subscription via the plan_id foreign key. +""" + +import sys +import os +from pathlib import Path +from datetime import datetime, timezone + +# Add the backend directory to Python path +backend_dir = Path(__file__).parent.parent +sys.path.insert(0, str(backend_dir)) + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from loguru import logger + +from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageStatus +from services.database import DATABASE_URL + +def update_basic_plan_limits(): + """Update Basic plan limits for testing rate limits and renewal.""" + + try: + engine = create_engine(DATABASE_URL, echo=False) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = SessionLocal() + + try: + # Find Basic plan + basic_plan = db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.BASIC + ).first() + + if not basic_plan: + logger.error("❌ Basic plan not found in database!") + return False + + # Store old values for logging + old_limits = { + 'gemini_calls': basic_plan.gemini_calls_limit, + 'mistral_calls': basic_plan.mistral_calls_limit, + 'gemini_tokens': basic_plan.gemini_tokens_limit, + 'mistral_tokens': basic_plan.mistral_tokens_limit, + 'stability_calls': basic_plan.stability_calls_limit, + } + + logger.info(f"📋 Current Basic plan limits:") + logger.info(f" Gemini Calls: {old_limits['gemini_calls']}") + logger.info(f" Mistral Calls: {old_limits['mistral_calls']}") + logger.info(f" Gemini Tokens: {old_limits['gemini_tokens']}") + logger.info(f" Mistral Tokens: {old_limits['mistral_tokens']}") + logger.info(f" Images (Stability): {old_limits['stability_calls']}") + + # Update unified AI text generation limit to 10 + basic_plan.ai_text_generation_calls_limit = 10 + + # Legacy per-provider limits (kept for backwards compatibility, but not used for enforcement) + basic_plan.gemini_calls_limit = 1000 + basic_plan.openai_calls_limit = 500 + basic_plan.anthropic_calls_limit = 200 + basic_plan.mistral_calls_limit = 500 + + # Update all LLM provider token limits to 2000 + basic_plan.gemini_tokens_limit = 2000 + basic_plan.openai_tokens_limit = 2000 + basic_plan.anthropic_tokens_limit = 2000 + basic_plan.mistral_tokens_limit = 2000 + + # Update image generation limit to 5 + basic_plan.stability_calls_limit = 5 + + # Update timestamp + basic_plan.updated_at = datetime.now(timezone.utc) + + logger.info("\n📝 New Basic plan limits:") + logger.info(f" LLM Calls (all providers): 10") + logger.info(f" LLM Tokens (all providers): 2000") + logger.info(f" Images: 5") + + # Count and get affected users + user_subscriptions = db.query(UserSubscription).filter( + UserSubscription.plan_id == basic_plan.id, + UserSubscription.is_active == True + ).all() + + affected_users = len(user_subscriptions) + + logger.info(f"\n👥 Users affected: {affected_users}") + + if affected_users > 0: + logger.info("\n📋 Affected user IDs:") + for sub in user_subscriptions: + logger.info(f" - {sub.user_id}") + else: + logger.info(" (No active Basic plan subscriptions found)") + + # Commit plan limit changes first + db.commit() + logger.info("\n✅ Basic plan limits updated successfully!") + + # Cap usage at new limits for all affected users (preserve historical data, but cap enforcement) + logger.info("\n🔄 Capping usage counters at new limits for Basic plan users...") + logger.info(" (Historical usage preserved, but capped to allow new calls within limits)") + from models.subscription_models import UsageSummary + from services.subscription import PricingService + + pricing_service = PricingService(db) + capped_count = 0 + + # New limits - use unified AI text generation limit if available + new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit + new_token_limit = basic_plan.gemini_tokens_limit # 2000 + new_image_limit = basic_plan.stability_calls_limit # 5 + + for sub in user_subscriptions: + try: + # Get current billing period for this user + current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m") + + # Find usage summary for current period + usage_summary = db.query(UsageSummary).filter( + UsageSummary.user_id == sub.user_id, + UsageSummary.billing_period == current_period + ).first() + + if usage_summary: + # Store old values for logging + old_gemini = usage_summary.gemini_calls or 0 + old_mistral = usage_summary.mistral_calls or 0 + old_openai = usage_summary.openai_calls or 0 + old_anthropic = usage_summary.anthropic_calls or 0 + old_tokens = max( + usage_summary.gemini_tokens or 0, + usage_summary.openai_tokens or 0, + usage_summary.anthropic_tokens or 0, + usage_summary.mistral_tokens or 0 + ) + old_images = usage_summary.stability_calls or 0 + + # Cap LLM provider counters at new limits (don't reset, just cap) + # This allows historical data to remain but prevents blocking from old usage + usage_summary.gemini_calls = min(old_gemini, new_call_limit) + usage_summary.mistral_calls = min(old_mistral, new_call_limit) + usage_summary.openai_calls = min(old_openai, new_call_limit) + usage_summary.anthropic_calls = min(old_anthropic, new_call_limit) + + # Cap token counters at new limits + usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit) + usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit) + usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit) + usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit) + + # Cap image counter at new limit + usage_summary.stability_calls = min(old_images, new_image_limit) + + # Update totals based on capped values (approximate) + # Recalculate total_calls and total_tokens based on capped provider values + total_capped_calls = ( + usage_summary.gemini_calls + + usage_summary.mistral_calls + + usage_summary.openai_calls + + usage_summary.anthropic_calls + + usage_summary.stability_calls + ) + total_capped_tokens = ( + usage_summary.gemini_tokens + + usage_summary.mistral_tokens + + usage_summary.openai_tokens + + usage_summary.anthropic_tokens + ) + + usage_summary.total_calls = total_capped_calls + usage_summary.total_tokens = total_capped_tokens + + # Reset status to active to allow new calls + usage_summary.usage_status = UsageStatus.ACTIVE + usage_summary.updated_at = datetime.now(timezone.utc) + + db.commit() + capped_count += 1 + + logger.info(f" ✅ Capped usage for user {sub.user_id}:") + logger.info(f" Gemini Calls: {old_gemini} → {usage_summary.gemini_calls} (limit: {new_call_limit})") + logger.info(f" Mistral Calls: {old_mistral} → {usage_summary.mistral_calls} (limit: {new_call_limit})") + logger.info(f" Tokens: {old_tokens} → {max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})") + logger.info(f" Images: {old_images} → {usage_summary.stability_calls} (limit: {new_image_limit})") + else: + logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period})") + + except Exception as cap_error: + logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}") + import traceback + logger.error(traceback.format_exc()) + db.rollback() + + if capped_count > 0: + logger.info(f"\n✅ Capped usage counters for {capped_count} user(s)") + logger.info(" Historical usage preserved, but capped at new limits to allow new calls") + else: + logger.info("\nℹ️ No usage counters to cap") + + # Note about cache clearing + logger.info("\n🔄 Cache Information:") + logger.info(" The subscription limits cache is per-instance and will refresh on next request.") + logger.info(" No manual cache clearing needed - limits will be read from database on next check.") + + # Display final summary + logger.info("\n" + "="*60) + logger.info("BASIC PLAN UPDATE SUMMARY") + logger.info("="*60) + logger.info(f"\nPlan: {basic_plan.name} ({basic_plan.tier.value})") + logger.info(f"Price: ${basic_plan.price_monthly}/mo, ${basic_plan.price_yearly}/yr") + logger.info(f"\nUpdated Limits:") + logger.info(f" LLM Calls (gemini/openai/anthropic/mistral): {basic_plan.gemini_calls_limit}") + logger.info(f" LLM Tokens (gemini/openai/anthropic/mistral): {basic_plan.gemini_tokens_limit}") + logger.info(f" Images (stability): {basic_plan.stability_calls_limit}") + logger.info(f"\nUsers Affected: {affected_users}") + logger.info("\n" + "="*60) + logger.info("\n💡 Note: These limits apply immediately to all Basic plan users.") + logger.info(" Historical usage has been preserved but capped at new limits.") + logger.info(" Users can continue making new calls up to the new limits.") + logger.info(" Users will hit rate limits faster for testing purposes.") + + return True + + except Exception as e: + db.rollback() + logger.error(f"❌ Error updating Basic plan: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + finally: + db.close() + + except Exception as e: + logger.error(f"❌ Failed to connect to database: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + +if __name__ == "__main__": + logger.info("🚀 Starting Basic plan limits update...") + logger.info("="*60) + logger.info("This will update Basic plan limits for testing rate limits:") + logger.info(" - LLM Calls: 10 (all providers)") + logger.info(" - LLM Tokens: 2000 (all providers)") + logger.info(" - Images: 5") + logger.info("="*60) + + # Ask for confirmation in non-interactive mode, proceed directly + # In interactive mode, you can add: input("\nPress Enter to continue or Ctrl+C to cancel...") + + try: + success = update_basic_plan_limits() + + if success: + logger.info("\n✅ Script completed successfully!") + sys.exit(0) + else: + logger.error("\n❌ Script failed!") + sys.exit(1) + + except KeyboardInterrupt: + logger.info("\n⚠️ Script cancelled by user") + sys.exit(1) + except Exception as e: + logger.error(f"\n❌ Unexpected error: {e}") + sys.exit(1) + diff --git a/backend/services/active_strategy_service.py b/backend/services/active_strategy_service.py index 7b25c38b..84c92c6c 100644 --- a/backend/services/active_strategy_service.py +++ b/backend/services/active_strategy_service.py @@ -295,3 +295,55 @@ class ActiveStrategyService: 'cached_users': list(self._memory_cache.keys()), 'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()} } + + def count_active_strategies_with_tasks(self) -> int: + """ + Count how many active strategies have monitoring tasks. + + This is used for intelligent scheduling - if there are no active strategies + with tasks, the scheduler can check less frequently. + + Returns: + Number of active strategies that have at least one active monitoring task + """ + try: + if not self.db_session: + logger.warning("Database session not available") + return 0 + + from sqlalchemy import func, and_ + from models.monitoring_models import MonitoringTask + + # Count distinct strategies that: + # 1. Have activation status = 'active' + # 2. Have at least one active monitoring task + count = self.db_session.query( + func.count(func.distinct(EnhancedContentStrategy.id)) + ).join( + StrategyActivationStatus, + EnhancedContentStrategy.id == StrategyActivationStatus.strategy_id + ).join( + MonitoringTask, + EnhancedContentStrategy.id == MonitoringTask.strategy_id + ).filter( + and_( + StrategyActivationStatus.status == 'active', + MonitoringTask.status == 'active' + ) + ).scalar() + + return count or 0 + + except Exception as e: + logger.error(f"Error counting active strategies with tasks: {e}") + # On error, assume there are active strategies (safer to check more frequently) + return 1 + + def has_active_strategies_with_tasks(self) -> bool: + """ + Check if there are any active strategies with monitoring tasks. + + Returns: + True if there are active strategies with tasks, False otherwise + """ + return self.count_active_strategies_with_tasks() > 0 \ No newline at end of file diff --git a/backend/services/blog_writer/core/blog_writer_service.py b/backend/services/blog_writer/core/blog_writer_service.py index e5aef900..b8324420 100644 --- a/backend/services/blog_writer/core/blog_writer_service.py +++ b/backend/services/blog_writer/core/blog_writer_service.py @@ -96,13 +96,13 @@ class BlogWriterService: self.blog_rewriter = BlogRewriter(self.task_manager) # Research Methods - async def research(self, request: BlogResearchRequest) -> BlogResearchResponse: + async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse: """Conduct comprehensive research using Google Search grounding.""" - return await self.research_service.research(request) + return await self.research_service.research(request, user_id) - async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse: + async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse: """Conduct research with real-time progress updates.""" - return await self.research_service.research_with_progress(request, task_id) + return await self.research_service.research_with_progress(request, task_id, user_id) # Outline Methods async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: @@ -204,11 +204,14 @@ class BlogWriterService: except Exception as e: return {"success": False, "error": str(e)} - async def seo_analyze(self, request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse: + async def seo_analyze(self, request: BlogSEOAnalyzeRequest, user_id: str = None) -> BlogSEOAnalyzeResponse: """Analyze content for SEO optimization using comprehensive blog-specific analyzer.""" try: from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") + content = request.content or "" target_keywords = request.keywords or [] @@ -231,7 +234,7 @@ class BlogWriterService: # Use our comprehensive SEO analyzer analyzer = BlogContentSEOAnalyzer() - analysis_results = await analyzer.analyze_blog_content(content, research_data) + analysis_results = await analyzer.analyze_blog_content(content, research_data, user_id=user_id) # Convert results to response format recommendations = analysis_results.get('actionable_recommendations', []) @@ -267,11 +270,14 @@ class BlogWriterService: recommendations=[f"SEO analysis failed: {str(e)}"] ) - async def seo_metadata(self, request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse: + async def seo_metadata(self, request: BlogSEOMetadataRequest, user_id: str = None) -> BlogSEOMetadataResponse: """Generate comprehensive SEO metadata for content.""" try: from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") + # Initialize metadata generator metadata_generator = BlogSEOMetadataGenerator() @@ -285,7 +291,8 @@ class BlogWriterService: blog_title=request.title or "Untitled Blog Post", research_data=request.research_data or {}, outline=outline, - seo_analysis=seo_analysis + seo_analysis=seo_analysis, + user_id=user_id ) # Convert to BlogSEOMetadataResponse format diff --git a/backend/services/blog_writer/logger_config.py b/backend/services/blog_writer/logger_config.py index f681697b..a295ce33 100644 --- a/backend/services/blog_writer/logger_config.py +++ b/backend/services/blog_writer/logger_config.py @@ -163,13 +163,18 @@ class BlogWriterLogger: context: Optional[Dict[str, Any]] = None ): """Log error with full context.""" + # Safely format error message to avoid KeyError on format strings in error messages + error_str = str(error) + # Replace any curly braces that might be in the error message to avoid format string issues + safe_error_str = error_str.replace('{', '{{').replace('}', '}}') + logger.error( - f"Error in {operation}: {str(error)}", + f"Error in {operation}: {safe_error_str}", extra={ "event_type": "error", "operation": operation, "error_type": type(error).__name__, - "error_message": str(error), + "error_message": error_str, # Keep original in extra, but use safe version in format string "context": context or {} }, exc_info=True diff --git a/backend/services/blog_writer/research/competitor_analyzer.py b/backend/services/blog_writer/research/competitor_analyzer.py index 0e085b0c..2146cbca 100644 --- a/backend/services/blog_writer/research/competitor_analyzer.py +++ b/backend/services/blog_writer/research/competitor_analyzer.py @@ -11,7 +11,7 @@ from loguru import logger class CompetitorAnalyzer: """Analyzes competitors and market intelligence from research content.""" - def analyze(self, content: str) -> Dict[str, Any]: + def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]: """Parse comprehensive competitor analysis from the research content using AI.""" competitor_prompt = f""" Analyze the following research content and extract competitor insights: @@ -57,7 +57,8 @@ class CompetitorAnalyzer: competitor_analysis = llm_text_gen( prompt=competitor_prompt, - json_struct=competitor_schema + json_struct=competitor_schema, + user_id=user_id ) if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis: diff --git a/backend/services/blog_writer/research/content_angle_generator.py b/backend/services/blog_writer/research/content_angle_generator.py index 69f32555..bb25405f 100644 --- a/backend/services/blog_writer/research/content_angle_generator.py +++ b/backend/services/blog_writer/research/content_angle_generator.py @@ -11,7 +11,7 @@ from loguru import logger class ContentAngleGenerator: """Generates strategic content angles from research content.""" - def generate(self, content: str, topic: str, industry: str) -> List[str]: + def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]: """Parse strategic content angles from the research content using AI.""" angles_prompt = f""" Analyze the following research content and create strategic content angles for: {topic} in {industry} @@ -65,7 +65,8 @@ class ContentAngleGenerator: angles_result = llm_text_gen( prompt=angles_prompt, - json_struct=angles_schema + json_struct=angles_schema, + user_id=user_id ) if isinstance(angles_result, dict) and 'content_angles' in angles_result: diff --git a/backend/services/blog_writer/research/keyword_analyzer.py b/backend/services/blog_writer/research/keyword_analyzer.py index 598d6fcc..6e29d592 100644 --- a/backend/services/blog_writer/research/keyword_analyzer.py +++ b/backend/services/blog_writer/research/keyword_analyzer.py @@ -11,7 +11,7 @@ from loguru import logger class KeywordAnalyzer: """Analyzes keywords from research content using AI-powered extraction.""" - def analyze(self, content: str, original_keywords: List[str]) -> Dict[str, Any]: + def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]: """Parse comprehensive keyword analysis from the research content using AI.""" # Use AI to extract and analyze keywords from the rich research content keyword_prompt = f""" @@ -64,7 +64,8 @@ class KeywordAnalyzer: keyword_analysis = llm_text_gen( prompt=keyword_prompt, - json_struct=keyword_schema + json_struct=keyword_schema, + user_id=user_id ) if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis: diff --git a/backend/services/blog_writer/research/research_service.py b/backend/services/blog_writer/research/research_service.py index 4eac9977..f7c0e2d4 100644 --- a/backend/services/blog_writer/research/research_service.py +++ b/backend/services/blog_writer/research/research_service.py @@ -4,7 +4,8 @@ Research Service - Core research functionality for AI Blog Writer. Handles Google Search grounding, caching, and research orchestration. """ -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional +from datetime import datetime from loguru import logger from models.blog_models import ( @@ -17,6 +18,7 @@ from models.blog_models import ( Citation, ) from services.blog_writer.logger_config import blog_writer_logger, log_function_call +from fastapi import HTTPException from .keyword_analyzer import KeywordAnalyzer from .competitor_analyzer import CompetitorAnalyzer @@ -34,7 +36,7 @@ class ResearchService: self.data_filter = ResearchDataFilter() @log_function_call("research_operation") - async def research(self, request: BlogResearchRequest) -> BlogResearchResponse: + async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse: """ Stage 1: Research & Strategy (AI Orchestration) Uses ONLY Gemini's native Google Search grounding - ONE API call for everything. @@ -71,6 +73,10 @@ class ResearchService: blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True) return BlogResearchResponse(**cached_result) + # User ID validation (validation logic is now in Google Grounding provider) + if not user_id: + raise ValueError("user_id is required for research operation. Please provide Clerk user ID.") + # Cache miss - proceed with API call logger.info(f"Cache miss - making API call for keywords: {request.keywords}") blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research") @@ -96,12 +102,15 @@ class ResearchService: """ # Single Gemini call with native Google Search grounding - no fallbacks + # Validation is handled inside generate_grounded_content when validate_subsequent_operations=True import time api_start_time = time.time() gemini_result = await gemini.generate_grounded_content( prompt=research_prompt, content_type="research", - max_tokens=2000 + max_tokens=2000, + user_id=user_id, + validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls ) api_duration_ms = (time.time() - api_start_time) * 1000 @@ -126,9 +135,9 @@ class ResearchService: # Parse the comprehensive response for different analysis components content = gemini_result.get("content", "") - keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords) - competitor_analysis = self.competitor_analyzer.analyze(content) - suggested_angles = self.content_angle_generator.generate(content, topic, industry) + keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) + competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) + suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries") @@ -179,6 +188,9 @@ class ResearchService: return filtered_response + except HTTPException: + # Re-raise HTTPException (subscription errors) - let task manager handle it + raise except Exception as e: error_message = str(e) logger.error(f"Research failed: {error_message}") @@ -244,7 +256,7 @@ class ResearchService: ) @log_function_call("research_with_progress") - async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse: + async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse: """ Research method with progress updates for real-time feedback. """ @@ -281,6 +293,11 @@ class ResearchService: logger.info(f"Returning cached research result for keywords: {request.keywords}") return BlogResearchResponse(**cached_result) + # User ID validation (validation logic is now in Google Grounding provider) + if not user_id: + await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation") + raise ValueError("user_id is required for research operation. Please provide Clerk user ID.") + # Cache miss - proceed with API call await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...") logger.info(f"Cache miss - making API call for keywords: {request.keywords}") @@ -307,11 +324,20 @@ class ResearchService: await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...") # Single Gemini call with native Google Search grounding - no fallbacks - gemini_result = await gemini.generate_grounded_content( - prompt=research_prompt, - content_type="research", - max_tokens=2000 - ) + # Validation is handled inside generate_grounded_content when validate_subsequent_operations=True + try: + gemini_result = await gemini.generate_grounded_content( + prompt=research_prompt, + content_type="research", + max_tokens=2000, + user_id=user_id, + validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls + ) + except HTTPException as http_error: + # Re-raise HTTPException so it can be properly handled by task manager + logger.error(f"Subscription limit exceeded for research: {http_error.detail}") + await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}") + raise # Re-raise HTTPException to preserve status code and error details await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...") # Extract sources from grounding metadata @@ -327,9 +353,9 @@ class ResearchService: await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...") # Parse the comprehensive response for different analysis components content = gemini_result.get("content", "") - keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords) - competitor_analysis = self.competitor_analyzer.analyze(content) - suggested_angles = self.content_angle_generator.generate(content, topic, industry) + keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) + competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) + suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) await task_manager.update_progress(task_id, "💾 Caching results for future use...") logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries") @@ -373,6 +399,9 @@ class ResearchService: return filtered_response + except HTTPException: + # Re-raise HTTPException (subscription errors) - let task manager handle it + raise except Exception as e: error_message = str(e) logger.error(f"Research failed: {error_message}") diff --git a/backend/services/blog_writer/seo/blog_content_seo_analyzer.py b/backend/services/blog_writer/seo/blog_content_seo_analyzer.py index 02611759..12eb44dd 100644 --- a/backend/services/blog_writer/seo/blog_content_seo_analyzer.py +++ b/backend/services/blog_writer/seo/blog_content_seo_analyzer.py @@ -34,17 +34,21 @@ class BlogContentSEOAnalyzer: logger.info("BlogContentSEOAnalyzer initialized") - async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None) -> Dict[str, Any]: + async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None, user_id: str = None) -> Dict[str, Any]: """ Main analysis method with parallel processing Args: blog_content: The blog content to analyze research_data: Research data containing keywords and other insights + blog_title: Optional blog title + user_id: Clerk user ID for subscription checking (required) Returns: Comprehensive SEO analysis results """ + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") try: logger.info("Starting blog content SEO analysis") @@ -58,7 +62,7 @@ class BlogContentSEOAnalyzer: # Phase 2: Single AI analysis for structured insights logger.info("Running AI analysis") - ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results) + ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results, user_id=user_id) # Phase 3: Compile and format results logger.info("Compiling results") @@ -599,8 +603,10 @@ class BlogContentSEOAnalyzer: return recommendations - async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]: + async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any], user_id: str = None) -> Dict[str, Any]: """Run single AI analysis for structured insights (provider-agnostic)""" + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") try: # Prepare context for AI analysis context = { @@ -658,7 +664,8 @@ class BlogContentSEOAnalyzer: ai_response = llm_text_gen( prompt=prompt, json_struct=schema, - system_prompt=None + system_prompt=None, + user_id=user_id # Pass user_id for subscription checking ) return ai_response diff --git a/backend/services/blog_writer/seo/blog_seo_metadata_generator.py b/backend/services/blog_writer/seo/blog_seo_metadata_generator.py index 0a7ac744..e431d93a 100644 --- a/backend/services/blog_writer/seo/blog_seo_metadata_generator.py +++ b/backend/services/blog_writer/seo/blog_seo_metadata_generator.py @@ -28,7 +28,8 @@ class BlogSEOMetadataGenerator: blog_title: str, research_data: Dict[str, Any], outline: Optional[List[Dict[str, Any]]] = None, - seo_analysis: Optional[Dict[str, Any]] = None + seo_analysis: Optional[Dict[str, Any]] = None, + user_id: str = None ) -> Dict[str, Any]: """ Generate comprehensive SEO metadata using maximum 2 AI calls @@ -39,10 +40,13 @@ class BlogSEOMetadataGenerator: research_data: Research data containing keywords and insights outline: Outline structure with sections and headings seo_analysis: SEO analysis results from previous phase + user_id: Clerk user ID for subscription checking (required) Returns: Comprehensive metadata including all SEO elements """ + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") try: logger.info("Starting comprehensive SEO metadata generation") @@ -53,13 +57,13 @@ class BlogSEOMetadataGenerator: # Call 1: Generate core SEO metadata (parallel with Call 2) logger.info("Generating core SEO metadata") core_metadata_task = self._generate_core_metadata( - blog_content, blog_title, keywords_data, outline, seo_analysis + blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id ) # Call 2: Generate social media and structured data (parallel with Call 1) logger.info("Generating social media and structured data") social_metadata_task = self._generate_social_metadata( - blog_content, blog_title, keywords_data, outline, seo_analysis + blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id ) # Wait for both calls to complete @@ -114,9 +118,12 @@ class BlogSEOMetadataGenerator: blog_title: str, keywords_data: Dict[str, Any], outline: Optional[List[Dict[str, Any]]] = None, - seo_analysis: Optional[Dict[str, Any]] = None + seo_analysis: Optional[Dict[str, Any]] = None, + user_id: str = None ) -> Dict[str, Any]: """Generate core SEO metadata (Call 1)""" + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") try: # Create comprehensive prompt for core metadata prompt = self._create_core_metadata_prompt( @@ -170,7 +177,8 @@ class BlogSEOMetadataGenerator: ai_response_raw = llm_text_gen( prompt=prompt, json_struct=schema, - system_prompt=None + system_prompt=None, + user_id=user_id # Pass user_id for subscription checking ) # Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing) @@ -215,9 +223,12 @@ class BlogSEOMetadataGenerator: blog_title: str, keywords_data: Dict[str, Any], outline: Optional[List[Dict[str, Any]]] = None, - seo_analysis: Optional[Dict[str, Any]] = None + seo_analysis: Optional[Dict[str, Any]] = None, + user_id: str = None ) -> Dict[str, Any]: """Generate social media and structured data (Call 2)""" + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") try: # Create comprehensive prompt for social metadata prompt = self._create_social_metadata_prompt( @@ -274,7 +285,8 @@ class BlogSEOMetadataGenerator: ai_response_raw = llm_text_gen( prompt=prompt, json_struct=schema, - system_prompt=None + system_prompt=None, + user_id=user_id # Pass user_id for subscription checking ) # Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing) diff --git a/backend/services/blog_writer/seo/blog_seo_recommendation_applier.py b/backend/services/blog_writer/seo/blog_seo_recommendation_applier.py index ed55779e..be7bdfe3 100644 --- a/backend/services/blog_writer/seo/blog_seo_recommendation_applier.py +++ b/backend/services/blog_writer/seo/blog_seo_recommendation_applier.py @@ -20,8 +20,11 @@ class BlogSEORecommendationApplier: def __init__(self): logger.debug("Initialized BlogSEORecommendationApplier") - async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]: + async def apply_recommendations(self, payload: Dict[str, Any], user_id: str = None) -> Dict[str, Any]: """Apply recommendations and return updated content.""" + + if not user_id: + raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.") title = payload.get("title", "Untitled Blog") sections: List[Dict[str, Any]] = payload.get("sections", []) @@ -88,6 +91,7 @@ class BlogSEORecommendationApplier: prompt, None, schema, + user_id, # Pass user_id for subscription checking ) if not result or result.get("error"): diff --git a/backend/services/llm_providers/gemini_grounded_provider.py b/backend/services/llm_providers/gemini_grounded_provider.py index 69ae349e..e6248b32 100644 --- a/backend/services/llm_providers/gemini_grounded_provider.py +++ b/backend/services/llm_providers/gemini_grounded_provider.py @@ -56,7 +56,9 @@ class GeminiGroundedProvider: temperature: float = 0.7, max_tokens: int = 2048, urls: Optional[List[str]] = None, - mode: str = "polished" + mode: str = "polished", + user_id: Optional[str] = None, + validate_subsequent_operations: bool = False ) -> Dict[str, Any]: """ Generate grounded content using native Google Search grounding. @@ -66,12 +68,49 @@ class GeminiGroundedProvider: content_type: Type of content to generate temperature: Creativity level (0.0-1.0) max_tokens: Maximum tokens in response + urls: Optional list of URLs for URL Context tool + mode: Content mode ("draft" or "polished") + user_id: User ID for subscription checking (required if validate_subsequent_operations=True) + validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow Returns: Dictionary containing generated content and grounding metadata """ try: - logger.info(f"Generating grounded content for {content_type} using native Google Search") + # PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations + # MUST happen BEFORE any API calls - return immediately if validation fails + if validate_subsequent_operations: + if not user_id: + raise ValueError("user_id is required when validate_subsequent_operations=True") + + from services.database import get_db + from services.subscription import PricingService + from services.subscription.preflight_validator import validate_research_operations + from fastapi import HTTPException + import os + + db = next(get_db()) + try: + pricing_service = PricingService(db) + gpt_provider = os.getenv("GPT_PROVIDER", "google") + + # Validate ALL research operations before making ANY API calls + # This prevents wasteful external API calls if subsequent LLM calls would fail + # Raises HTTPException immediately if validation fails - frontend gets immediate response + validate_research_operations( + pricing_service=pricing_service, + user_id=user_id, + gpt_provider=gpt_provider + ) + except HTTPException as http_ex: + # Re-raise immediately - don't proceed with API call + logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call") + raise + finally: + db.close() + + logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call") + logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search") # Build the grounded prompt grounded_prompt = self._build_grounded_prompt(prompt, content_type) diff --git a/backend/services/llm_providers/main_image_generation.py b/backend/services/llm_providers/main_image_generation.py index 9b058c65..93fb7a0c 100644 --- a/backend/services/llm_providers/main_image_generation.py +++ b/backend/services/llm_providers/main_image_generation.py @@ -40,7 +40,38 @@ def _get_provider(provider_name: str): raise ValueError(f"Unknown image provider: {provider_name}") -def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult: +def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult: + """Generate image with pre-flight validation. + + Args: + prompt: Image generation prompt + options: Image generation options (provider, model, width, height, etc.) + user_id: User ID for subscription checking (optional, but required for validation) + """ + # PRE-FLIGHT VALIDATION: Validate image generation before API call + # MUST happen BEFORE any API calls - return immediately if validation fails + if user_id: + from services.database import get_db + from services.subscription import PricingService + from services.subscription.preflight_validator import validate_image_generation_operations + from fastapi import HTTPException + + db = next(get_db()) + try: + pricing_service = PricingService(db) + # Raises HTTPException immediately if validation fails - frontend gets immediate response + validate_image_generation_operations( + pricing_service=pricing_service, + user_id=user_id + ) + except HTTPException as http_ex: + # Re-raise immediately - don't proceed with API call + logger.error(f"[Image Generation] ❌ Pre-flight validation failed - blocking API call") + raise + finally: + db.close() + + logger.info(f"[Image Generation] ✅ Pre-flight validation passed - proceeding with image generation") opts = options or {} provider_name = _select_provider(opts.get("provider")) diff --git a/backend/services/llm_providers/main_text_generation.py b/backend/services/llm_providers/main_text_generation.py index 0fe47867..e7fccae6 100644 --- a/backend/services/llm_providers/main_text_generation.py +++ b/backend/services/llm_providers/main_text_generation.py @@ -7,6 +7,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation. import os import json from typing import Optional, Dict, Any +from datetime import datetime from loguru import logger from ..onboarding.api_key_manager import APIKeyManager @@ -14,7 +15,7 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response -def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str: +def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str: """ Generate text using Language Model (LLM) based on the provided prompt. @@ -22,9 +23,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: prompt (str): The prompt to generate text from. system_prompt (str, optional): Custom system prompt to use instead of the default one. json_struct (dict, optional): JSON schema structure for structured responses. + user_id (str): Clerk user ID for subscription checking (required). Returns: str: Generated text based on the prompt. + + Raises: + RuntimeError: If subscription limits are exceeded or user_id is missing. """ try: logger.info("[llm_text_gen] Starting text generation") @@ -93,6 +98,75 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}") + # Map provider name to APIProvider enum (define at function scope for usage tracking) + from models.subscription_models import APIProvider + provider_enum = None + # Store actual provider name for logging (e.g., "huggingface", "gemini") + actual_provider_name = None + if gpt_provider == "google": + provider_enum = APIProvider.GEMINI + actual_provider_name = "gemini" # Use "gemini" for consistency in logs + elif gpt_provider == "huggingface": + provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking + actual_provider_name = "huggingface" # Keep actual provider name for logs + + if not provider_enum: + raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking") + + # SUBSCRIPTION CHECK - Required and strict enforcement + if not user_id: + raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.") + + try: + from services.database import get_db + from services.subscription import UsageTrackingService, PricingService + from models.subscription_models import UsageSummary + + db = next(get_db()) + try: + + usage_service = UsageTrackingService(db) + pricing_service = PricingService(db) + + # Estimate tokens from prompt (input tokens) + # Note: We estimate output tokens conservatively (assume response is similar length to prompt) + # This prevents underestimating total token usage + input_tokens = int(len(prompt.split()) * 1.3) + # Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens) + estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8) + estimated_total_tokens = input_tokens + estimated_output_tokens + + # Check limits using sync method from pricing service (strict enforcement) + can_proceed, message, usage_info = pricing_service.check_usage_limits( + user_id=user_id, + provider=provider_enum, + tokens_requested=estimated_total_tokens, + actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages + ) + + if not can_proceed: + logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}") + raise RuntimeError(f"Subscription limit exceeded: {message}") + + # Get current usage for limit checking only + current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + usage = db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + # No separate log here - we'll create unified log after API call and usage tracking + + finally: + db.close() + except RuntimeError: + # Re-raise subscription limit errors + raise + except Exception as sub_error: + # STRICT: Fail on subscription check errors + logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}") + raise RuntimeError(f"Subscription check failed: {str(sub_error)}") + # Construct the system prompt if not provided if system_prompt is None: system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content. @@ -117,10 +191,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: system_instructions = system_prompt # Generate response based on provider + response_text = None + actual_provider_used = gpt_provider try: if gpt_provider == "google": if json_struct: - return gemini_structured_json_response( + response_text = gemini_structured_json_response( prompt=prompt, schema=json_struct, temperature=temperature, @@ -130,7 +206,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: system_prompt=system_instructions ) else: - return gemini_text_response( + response_text = gemini_text_response( prompt=prompt, temperature=temperature, top_p=top_p, @@ -140,7 +216,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: ) elif gpt_provider == "huggingface": if json_struct: - return huggingface_structured_json_response( + response_text = huggingface_structured_json_response( prompt=prompt, schema=json_struct, model=model, @@ -149,7 +225,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: system_prompt=system_instructions ) else: - return huggingface_text_response( + response_text = huggingface_text_response( prompt=prompt, model=model, temperature=temperature, @@ -160,6 +236,107 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: else: logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}") raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface") + + # TRACK USAGE after successful API call + if response_text: + logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}") + try: + db_track = next(get_db()) + try: + # Estimate tokens from prompt and response + tokens_input = estimated_tokens # Already calculated above + tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens + tokens_total = tokens_input + tokens_output + + logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}") + + # Get or create usage summary + from models.subscription_models import UsageSummary + from services.subscription import PricingService + + pricing = PricingService(db_track) + current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + + logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}") + + summary = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if not summary: + logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}") + summary = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + db_track.add(summary) + db_track.flush() # Ensure summary is persisted before updating + + # Get "before" state for unified log + provider_name = provider_enum.value + current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0 + + # Update provider-specific counters (sync operation) + new_calls = current_calls_before + 1 + setattr(summary, f"{provider_name}_calls", new_calls) + logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}") + + # Update token usage for LLM providers + if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 + new_tokens = current_tokens_before + tokens_total + setattr(summary, f"{provider_name}_tokens", new_tokens) + logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}") + else: + current_tokens_before = 0 + new_tokens = 0 + + # Update totals + old_total_calls = summary.total_calls or 0 + old_total_tokens = summary.total_tokens or 0 + summary.total_calls = old_total_calls + 1 + summary.total_tokens = old_total_tokens + tokens_total + logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}") + + # Get plan details for unified log + limits = pricing.get_user_limits(user_id) + plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' + tier = limits.get('tier', 'unknown') if limits else 'unknown' + call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0 + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0 + + # Get image stats for unified log + current_images_before = getattr(summary, "stability_calls", 0) or 0 + image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 + + db_track.commit() + logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") + + # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message + # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") + # Include image stats in the log + print(f""" +[SUBSCRIPTION] LLM Text Generation +├─ User: {user_id} +├─ Plan: {plan_name} ({tier}) +├─ Provider: {actual_provider_name} +├─ Model: {model} +├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'} +├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'} +├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'} +└─ Status: ✅ Allowed & Tracked +""") + except Exception as track_error: + logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True) + db_track.rollback() + finally: + db_track.close() + except Exception as usage_error: + # Non-blocking: log error but don't fail the request + logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True) + + return response_text except Exception as provider_error: logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}") @@ -171,9 +348,21 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: fallback_provider = fallback_providers[0] # Only try the first available try: logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}") + actual_provider_used = fallback_provider + + # Update provider enum for fallback + if fallback_provider == "google": + provider_enum = APIProvider.GEMINI + actual_provider_name = "gemini" + fallback_model = "gemini-2.0-flash-lite" + elif fallback_provider == "huggingface": + provider_enum = APIProvider.MISTRAL + actual_provider_name = "huggingface" + fallback_model = "openai/gpt-oss-120b:groq" + if fallback_provider == "google": if json_struct: - return gemini_structured_json_response( + response_text = gemini_structured_json_response( prompt=prompt, schema=json_struct, temperature=temperature, @@ -183,7 +372,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: system_prompt=system_instructions ) else: - return gemini_text_response( + response_text = gemini_text_response( prompt=prompt, temperature=temperature, top_p=top_p, @@ -193,7 +382,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: ) elif fallback_provider == "huggingface": if json_struct: - return huggingface_structured_json_response( + response_text = huggingface_structured_json_response( prompt=prompt, schema=json_struct, model="openai/gpt-oss-120b:groq", @@ -202,7 +391,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: system_prompt=system_instructions ) else: - return huggingface_text_response( + response_text = huggingface_text_response( prompt=prompt, model="openai/gpt-oss-120b:groq", temperature=temperature, @@ -210,6 +399,96 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: top_p=top_p, system_prompt=system_instructions ) + + # TRACK USAGE after successful fallback call + if response_text: + logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}") + try: + db_track = next(get_db()) + try: + # Estimate tokens from prompt and response + tokens_input = estimated_tokens + tokens_output = int(len(str(response_text).split()) * 1.3) + tokens_total = tokens_input + tokens_output + + # Get or create usage summary + from models.subscription_models import UsageSummary + from services.subscription import PricingService + + pricing = PricingService(db_track) + current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + + summary = db_track.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if not summary: + summary = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + db_track.add(summary) + db_track.flush() # Ensure summary is persisted before updating + + # Get "before" state for unified log + provider_name = provider_enum.value + current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0 + + # Update provider-specific counters (sync operation) + new_calls = current_calls_before + 1 + setattr(summary, f"{provider_name}_calls", new_calls) + + # Update token usage for LLM providers + if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 + new_tokens = current_tokens_before + tokens_total + setattr(summary, f"{provider_name}_tokens", new_tokens) + else: + current_tokens_before = 0 + new_tokens = 0 + + # Update totals + summary.total_calls = (summary.total_calls or 0) + 1 + summary.total_tokens = (summary.total_tokens or 0) + tokens_total + + # Get plan details for unified log + limits = pricing.get_user_limits(user_id) + plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' + tier = limits.get('tier', 'unknown') if limits else 'unknown' + call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0 + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0 + + # Get image stats for unified log + current_images_before = getattr(summary, "stability_calls", 0) or 0 + image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 + + db_track.commit() + logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") + + # UNIFIED SUBSCRIPTION LOG for fallback + # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") + # Include image stats in the log + print(f""" +[SUBSCRIPTION] LLM Text Generation (Fallback) +├─ User: {user_id} +├─ Plan: {plan_name} ({tier}) +├─ Provider: {actual_provider_name} +├─ Model: {fallback_model} +├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'} +├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'} +├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'} +└─ Status: ✅ Allowed & Tracked +""") + except Exception as track_error: + logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True) + db_track.rollback() + finally: + db_track.close() + except Exception as usage_error: + logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True) + + return response_text except Exception as fallback_error: logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}") diff --git a/backend/services/monitoring_data_service.py b/backend/services/monitoring_data_service.py index b783a170..a7ccfa30 100644 --- a/backend/services/monitoring_data_service.py +++ b/backend/services/monitoring_data_service.py @@ -55,6 +55,14 @@ class MonitoringDataService: alert_threshold=task_data.get('alertThreshold', ''), status='active' ) + + # Initialize next_execution based on frequency + from services.scheduler.utils.frequency_calculator import calculate_next_execution + task.next_execution = calculate_next_execution( + frequency=task.frequency, + base_time=datetime.utcnow() + ) + self.db.add(task) # Save activation status @@ -357,3 +365,80 @@ class MonitoringDataService: logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}") self.db.rollback() return False + + def get_user_execution_logs( + self, + user_id: int, + limit: Optional[int] = 50, + offset: Optional[int] = 0, + status_filter: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Get execution logs for a specific user. + + Args: + user_id: User ID to filter execution logs + limit: Maximum number of logs to return + offset: Number of logs to skip (for pagination) + status_filter: Optional status filter ('success', 'failed', 'running', 'skipped') + + Returns: + List of execution log dictionaries with task details + """ + try: + logger.info(f"Getting execution logs for user {user_id}") + + # Build query for execution logs filtered by user_id + query = self.db.query(TaskExecutionLog).filter( + TaskExecutionLog.user_id == user_id + ) + + # Apply status filter if provided + if status_filter: + query = query.filter(TaskExecutionLog.status == status_filter) + + # Order by execution date (most recent first) + query = query.order_by(desc(TaskExecutionLog.execution_date)) + + # Apply pagination + if limit: + query = query.limit(limit) + if offset: + query = query.offset(offset) + + logs = query.all() + + # Convert to dictionaries with task details + logs_data = [] + for log in logs: + # Get task details if available + task = self.db.query(MonitoringTask).filter( + MonitoringTask.id == log.task_id + ).first() + + log_data = { + "id": log.id, + "task_id": log.task_id, + "user_id": log.user_id, + "execution_date": log.execution_date.isoformat() if log.execution_date else None, + "status": log.status, + "result_data": log.result_data, + "error_message": log.error_message, + "execution_time_ms": log.execution_time_ms, + "created_at": log.created_at.isoformat() if log.created_at else None, + "task": { + "title": task.task_title if task else None, + "description": task.task_description if task else None, + "assignee": task.assignee if task else None, + "frequency": task.frequency if task else None, + "strategy_id": task.strategy_id if task else None + } if task else None + } + logs_data.append(log_data) + + logger.info(f"Retrieved {len(logs_data)} execution logs for user {user_id}") + return logs_data + + except Exception as e: + logger.error(f"Error getting execution logs for user {user_id}: {e}") + return [] diff --git a/backend/services/scheduler/__init__.py b/backend/services/scheduler/__init__.py new file mode 100644 index 00000000..6628e173 --- /dev/null +++ b/backend/services/scheduler/__init__.py @@ -0,0 +1,59 @@ +""" +Task Scheduler Package +Modular, pluggable scheduler for ALwrity tasks. +""" + +from .core.scheduler import TaskScheduler +from .core.executor_interface import TaskExecutor, TaskExecutionResult +from .core.exception_handler import ( + SchedulerExceptionHandler, SchedulerException, SchedulerErrorType, SchedulerErrorSeverity, + TaskExecutionError, DatabaseError, TaskLoaderError, SchedulerConfigError +) +from .executors.monitoring_task_executor import MonitoringTaskExecutor +from .utils.task_loader import load_due_monitoring_tasks + +# Global scheduler instance (initialized on first access) +_scheduler_instance: TaskScheduler = None + + +def get_scheduler() -> TaskScheduler: + """ + Get global scheduler instance (singleton pattern). + + Returns: + TaskScheduler instance + """ + global _scheduler_instance + if _scheduler_instance is None: + _scheduler_instance = TaskScheduler( + check_interval_minutes=15, + max_concurrent_executions=10 + ) + + # Register monitoring task executor + monitoring_executor = MonitoringTaskExecutor() + _scheduler_instance.register_executor( + 'monitoring_task', + monitoring_executor, + load_due_monitoring_tasks + ) + + return _scheduler_instance + + +__all__ = [ + 'TaskScheduler', + 'TaskExecutor', + 'TaskExecutionResult', + 'MonitoringTaskExecutor', + 'get_scheduler', + # Exception handling + 'SchedulerExceptionHandler', + 'SchedulerException', + 'SchedulerErrorType', + 'SchedulerErrorSeverity', + 'TaskExecutionError', + 'DatabaseError', + 'TaskLoaderError', + 'SchedulerConfigError' +] diff --git a/backend/services/scheduler/core/__init__.py b/backend/services/scheduler/core/__init__.py new file mode 100644 index 00000000..73e680c1 --- /dev/null +++ b/backend/services/scheduler/core/__init__.py @@ -0,0 +1,4 @@ +""" +Core scheduler components. +""" + diff --git a/backend/services/scheduler/core/exception_handler.py b/backend/services/scheduler/core/exception_handler.py new file mode 100644 index 00000000..48349233 --- /dev/null +++ b/backend/services/scheduler/core/exception_handler.py @@ -0,0 +1,395 @@ +""" +Comprehensive Exception Handling and Logging for Task Scheduler +Provides robust error handling, logging, and monitoring for the scheduler system. +""" + +import traceback +import sys +from datetime import datetime +from typing import Dict, Any, Optional, Union +from enum import Enum +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError + +from utils.logger_utils import get_service_logger + +logger = get_service_logger("scheduler_exception_handler") + + +class SchedulerErrorType(Enum): + """Error types for scheduler system.""" + DATABASE_ERROR = "database_error" + TASK_EXECUTION_ERROR = "task_execution_error" + TASK_LOADER_ERROR = "task_loader_error" + SCHEDULER_CONFIG_ERROR = "scheduler_config_error" + RETRY_ERROR = "retry_error" + CONCURRENCY_ERROR = "concurrency_error" + TIMEOUT_ERROR = "timeout_error" + + +class SchedulerErrorSeverity(Enum): + """Severity levels for scheduler errors.""" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class SchedulerException(Exception): + """Base exception for scheduler system errors.""" + + def __init__( + self, + message: str, + error_type: SchedulerErrorType, + severity: SchedulerErrorSeverity = SchedulerErrorSeverity.MEDIUM, + user_id: Optional[int] = None, + task_id: Optional[int] = None, + task_type: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + original_error: Optional[Exception] = None + ): + self.message = message + self.error_type = error_type + self.severity = severity + self.user_id = user_id + self.task_id = task_id + self.task_type = task_type + self.context = context or {} + self.original_error = original_error + self.timestamp = datetime.utcnow() + + # Capture stack trace if original error provided + self.stack_trace = None + if self.original_error: + try: + exc_type, exc_value, exc_traceback = sys.exc_info() + if exc_traceback: + self.stack_trace = ''.join(traceback.format_exception( + exc_type, exc_value, exc_traceback + )) + else: + self.stack_trace = traceback.format_exception( + type(self.original_error), + self.original_error, + self.original_error.__traceback__ + ) + except Exception: + self.stack_trace = str(self.original_error) + + super().__init__(message) + + def to_dict(self) -> Dict[str, Any]: + """Convert exception to dictionary for logging/storage.""" + return { + "message": self.message, + "error_type": self.error_type.value, + "severity": self.severity.value, + "user_id": self.user_id, + "task_id": self.task_id, + "task_type": self.task_type, + "context": self.context, + "timestamp": self.timestamp.isoformat() if isinstance(self.timestamp, datetime) else self.timestamp, + "original_error": str(self.original_error) if self.original_error else None, + "stack_trace": self.stack_trace + } + + def __str__(self): + return f"[{self.error_type.value}] {self.message}" + + +class DatabaseError(SchedulerException): + """Exception raised for database-related errors.""" + + def __init__( + self, + message: str, + user_id: Optional[int] = None, + task_id: Optional[int] = None, + context: Dict[str, Any] = None, + original_error: Exception = None + ): + super().__init__( + message=message, + error_type=SchedulerErrorType.DATABASE_ERROR, + severity=SchedulerErrorSeverity.CRITICAL, + user_id=user_id, + task_id=task_id, + context=context or {}, + original_error=original_error + ) + + +class TaskExecutionError(SchedulerException): + """Exception raised for task execution failures.""" + + def __init__( + self, + message: str, + user_id: Optional[int] = None, + task_id: Optional[int] = None, + task_type: Optional[str] = None, + retry_count: int = 0, + execution_time_ms: Optional[int] = None, + context: Dict[str, Any] = None, + original_error: Exception = None + ): + context = context or {} + context.update({ + "retry_count": retry_count, + "execution_time_ms": execution_time_ms + }) + + super().__init__( + message=message, + error_type=SchedulerErrorType.TASK_EXECUTION_ERROR, + severity=SchedulerErrorSeverity.HIGH, + user_id=user_id, + task_id=task_id, + task_type=task_type, + context=context, + original_error=original_error + ) + + +class TaskLoaderError(SchedulerException): + """Exception raised for task loading failures.""" + + def __init__( + self, + message: str, + task_type: Optional[str] = None, + user_id: Optional[int] = None, + context: Dict[str, Any] = None, + original_error: Exception = None + ): + super().__init__( + message=message, + error_type=SchedulerErrorType.TASK_LOADER_ERROR, + severity=SchedulerErrorSeverity.HIGH, + user_id=user_id, + task_type=task_type, + context=context or {}, + original_error=original_error + ) + + +class SchedulerConfigError(SchedulerException): + """Exception raised for scheduler configuration errors.""" + + def __init__( + self, + message: str, + context: Dict[str, Any] = None, + original_error: Exception = None + ): + super().__init__( + message=message, + error_type=SchedulerErrorType.SCHEDULER_CONFIG_ERROR, + severity=SchedulerErrorSeverity.CRITICAL, + context=context or {}, + original_error=original_error + ) + + +class SchedulerExceptionHandler: + """Comprehensive exception handler for the scheduler system.""" + + def __init__(self, db: Session = None): + self.db = db + self.logger = logger + + def handle_exception( + self, + error: Union[Exception, SchedulerException], + context: Dict[str, Any] = None, + log_level: str = "error" + ) -> Dict[str, Any]: + """Handle and log scheduler exceptions.""" + + context = context or {} + + # Convert regular exceptions to SchedulerException + if not isinstance(error, SchedulerException): + error = SchedulerException( + message=str(error), + error_type=self._classify_error(error), + severity=self._determine_severity(error), + context=context, + original_error=error + ) + + # Log the error + error_data = error.to_dict() + error_data.update(context) + + log_message = f"Scheduler Error: {error.message}" + + if log_level == "critical" or error.severity == SchedulerErrorSeverity.CRITICAL: + self.logger.critical(log_message, extra={"error_data": error_data}) + elif log_level == "error" or error.severity == SchedulerErrorSeverity.HIGH: + self.logger.error(log_message, extra={"error_data": error_data}) + elif log_level == "warning" or error.severity == SchedulerErrorSeverity.MEDIUM: + self.logger.warning(log_message, extra={"error_data": error_data}) + else: + self.logger.info(log_message, extra={"error_data": error_data}) + + # Store critical errors in database for alerting + if error.severity in [SchedulerErrorSeverity.HIGH, SchedulerErrorSeverity.CRITICAL]: + self._store_error_alert(error) + + # Return formatted error response + return self._format_error_response(error) + + def _classify_error(self, error: Exception) -> SchedulerErrorType: + """Classify an exception into a scheduler error type.""" + + error_str = str(error).lower() + error_type_name = type(error).__name__.lower() + + # Database errors + if isinstance(error, (SQLAlchemyError, OperationalError, IntegrityError)): + return SchedulerErrorType.DATABASE_ERROR + if "database" in error_str or "sql" in error_type_name or "connection" in error_str: + return SchedulerErrorType.DATABASE_ERROR + + # Timeout errors + if "timeout" in error_str or "timed out" in error_str: + return SchedulerErrorType.TIMEOUT_ERROR + + # Concurrency errors + if "concurrent" in error_str or "race" in error_str or "lock" in error_str: + return SchedulerErrorType.CONCURRENCY_ERROR + + # Task execution errors + if "task" in error_str and "execut" in error_str: + return SchedulerErrorType.TASK_EXECUTION_ERROR + + # Task loader errors + if "load" in error_str and "task" in error_str: + return SchedulerErrorType.TASK_LOADER_ERROR + + # Retry errors + if "retry" in error_str: + return SchedulerErrorType.RETRY_ERROR + + # Config errors + if "config" in error_str or "scheduler" in error_str and "init" in error_str: + return SchedulerErrorType.SCHEDULER_CONFIG_ERROR + + # Default to task execution error for unknown errors + return SchedulerErrorType.TASK_EXECUTION_ERROR + + def _determine_severity(self, error: Exception) -> SchedulerErrorSeverity: + """Determine the severity of an error.""" + + error_str = str(error).lower() + error_type = type(error) + + # Critical errors + if isinstance(error, (SQLAlchemyError, OperationalError, ConnectionError)): + return SchedulerErrorSeverity.CRITICAL + if "database" in error_str or "connection" in error_str: + return SchedulerErrorSeverity.CRITICAL + + # High severity errors + if "timeout" in error_str or "concurrent" in error_str: + return SchedulerErrorSeverity.HIGH + if isinstance(error, (KeyError, AttributeError)) and "config" in error_str: + return SchedulerErrorSeverity.HIGH + + # Medium severity errors + if "task" in error_str or "execution" in error_str: + return SchedulerErrorSeverity.MEDIUM + + # Default to low + return SchedulerErrorSeverity.LOW + + def _store_error_alert(self, error: SchedulerException): + """Store critical errors in database for alerting.""" + + if not self.db: + return + + try: + # Import here to avoid circular dependencies + from models.monitoring_models import TaskExecutionLog + + # Store as failed execution log if we have task_id (even without user_id for system errors) + if error.task_id: + try: + execution_log = TaskExecutionLog( + task_id=error.task_id, + user_id=error.user_id, # Can be None for system-level errors + execution_date=error.timestamp, + status='failed', + error_message=error.message, + result_data={ + "error_type": error.error_type.value, + "severity": error.severity.value, + "context": error.context, + "stack_trace": error.stack_trace, + "task_type": error.task_type + } + ) + self.db.add(execution_log) + self.db.commit() + self.logger.info(f"Stored error alert in execution log for task {error.task_id}") + except Exception as e: + self.logger.error(f"Failed to store error in execution log: {e}") + self.db.rollback() + # Note: For errors without task_id, we rely on structured logging only + # Future: Could create a separate scheduler_error_logs table for system-level errors + + except Exception as e: + self.logger.error(f"Failed to store error alert: {e}") + + def _format_error_response(self, error: SchedulerException) -> Dict[str, Any]: + """Format error for API response or logging.""" + + response = { + "success": False, + "error": { + "type": error.error_type.value, + "message": error.message, + "severity": error.severity.value, + "timestamp": error.timestamp.isoformat() if isinstance(error.timestamp, datetime) else str(error.timestamp), + "user_id": error.user_id, + "task_id": error.task_id, + "task_type": error.task_type + } + } + + # Add context for debugging (non-sensitive info only) + if error.context: + safe_context = { + k: v for k, v in error.context.items() + if k not in ["password", "token", "key", "secret", "credential"] + } + response["error"]["context"] = safe_context + + # Add user-friendly message based on error type + user_messages = { + SchedulerErrorType.DATABASE_ERROR: + "A database error occurred while processing the task. Please try again later.", + SchedulerErrorType.TASK_EXECUTION_ERROR: + "The task failed to execute. Please check the task configuration and try again.", + SchedulerErrorType.TASK_LOADER_ERROR: + "Failed to load tasks. The scheduler may be experiencing issues.", + SchedulerErrorType.SCHEDULER_CONFIG_ERROR: + "The scheduler configuration is invalid. Contact support.", + SchedulerErrorType.RETRY_ERROR: + "Task retry failed. The task will be rescheduled.", + SchedulerErrorType.CONCURRENCY_ERROR: + "A concurrency issue occurred. The task will be retried.", + SchedulerErrorType.TIMEOUT_ERROR: + "The task execution timed out. The task will be retried." + } + + response["error"]["user_message"] = user_messages.get( + error.error_type, + "An error occurred while processing the task." + ) + + return response + diff --git a/backend/services/scheduler/core/executor_interface.py b/backend/services/scheduler/core/executor_interface.py new file mode 100644 index 00000000..b3742ce0 --- /dev/null +++ b/backend/services/scheduler/core/executor_interface.py @@ -0,0 +1,75 @@ +""" +Task Executor Interface +Abstract base class for all task executors. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +from dataclasses import dataclass +from datetime import datetime +from sqlalchemy.orm import Session + + +@dataclass +class TaskExecutionResult: + """Result of task execution.""" + success: bool + error_message: Optional[str] = None + result_data: Optional[Dict[str, Any]] = None + execution_time_ms: Optional[int] = None + retryable: bool = True + retry_delay: int = 300 # seconds + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'success': self.success, + 'error_message': self.error_message, + 'result_data': self.result_data, + 'execution_time_ms': self.execution_time_ms, + 'retryable': self.retryable, + 'retry_delay': self.retry_delay + } + + +class TaskExecutor(ABC): + """ + Abstract base class for task executors. + + Each task type must implement this interface to be schedulable. + """ + + @abstractmethod + async def execute_task(self, task: Any, db: Session) -> TaskExecutionResult: + """ + Execute a task. + + Args: + task: Task instance from database + db: Database session + + Returns: + TaskExecutionResult with execution details + """ + pass + + @abstractmethod + def calculate_next_execution( + self, + task: Any, + frequency: str, + last_execution: Optional[datetime] = None + ) -> datetime: + """ + Calculate next execution time based on frequency. + + Args: + task: Task instance + frequency: Task frequency (e.g., 'Daily', 'Weekly') + last_execution: Last execution datetime + + Returns: + Next execution datetime + """ + pass + diff --git a/backend/services/scheduler/core/scheduler.py b/backend/services/scheduler/core/scheduler.py new file mode 100644 index 00000000..35e8197b --- /dev/null +++ b/backend/services/scheduler/core/scheduler.py @@ -0,0 +1,628 @@ +""" +Core Task Scheduler Service +Pluggable task scheduler that can work with any task model. +""" + +import asyncio +import logging +from typing import Dict, Any, Optional, List, Callable +from datetime import datetime +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger +from apscheduler.triggers.interval import IntervalTrigger +from sqlalchemy.orm import Session + +from .executor_interface import TaskExecutor, TaskExecutionResult +from .task_registry import TaskRegistry +from .exception_handler import ( + SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError, + TaskLoaderError, SchedulerConfigError +) +from services.database import get_db_session +from utils.logger_utils import get_service_logger + +logger = get_service_logger("task_scheduler") + + +class TaskScheduler: + """ + Pluggable task scheduler that can work with any task model. + + Features: + - Async task execution + - Plugin-based executor system + - Database-backed task persistence + - Configurable check intervals + - Automatic retry logic + """ + + def __init__( + self, + check_interval_minutes: int = 15, + max_concurrent_executions: int = 10, + enable_retries: bool = True, + max_retries: int = 3 + ): + """ + Initialize the task scheduler. + + Args: + check_interval_minutes: How often to check for due tasks + max_concurrent_executions: Maximum concurrent task executions + enable_retries: Whether to retry failed tasks + max_retries: Maximum retry attempts + """ + self.check_interval_minutes = check_interval_minutes + self.max_concurrent_executions = max_concurrent_executions + self.enable_retries = enable_retries + self.max_retries = max_retries + + # Initialize APScheduler + self.scheduler = AsyncIOScheduler( + timezone='UTC', + job_defaults={ + 'coalesce': True, + 'max_instances': 1, + 'misfire_grace_time': 300 # 5 minutes grace period + } + ) + + # Task executor registry + self.registry = TaskRegistry() + + # Track running executions + self.active_executions: Dict[str, asyncio.Task] = {} + + # Exception handler for robust error handling + self.exception_handler = SchedulerExceptionHandler() + + # Intelligent scheduling configuration + self.min_check_interval_minutes = 15 # Check every 15min when active strategies exist + self.max_check_interval_minutes = 60 # Check every 60min when no active strategies + self.current_check_interval_minutes = check_interval_minutes # Current interval + + # Statistics + self.stats = { + 'total_checks': 0, + 'tasks_found': 0, + 'tasks_executed': 0, + 'tasks_failed': 0, + 'tasks_skipped': 0, + 'last_check': None, + 'per_user_stats': {}, # Track metrics per user for user isolation + 'active_strategies_count': 0, # Track active strategies with tasks + 'last_interval_adjustment': None # Track when interval was last adjusted + } + + self._running = False + + def _get_trigger_for_interval(self, interval_minutes: int): + """ + Get the appropriate trigger for the given interval. + + For intervals >= 60 minutes, use IntervalTrigger. + For intervals < 60 minutes, use CronTrigger. + + Args: + interval_minutes: Interval in minutes + + Returns: + Appropriate APScheduler trigger + """ + if interval_minutes >= 60: + # Use IntervalTrigger for intervals >= 60 minutes + return IntervalTrigger(minutes=interval_minutes) + else: + # Use CronTrigger for intervals < 60 minutes (valid range: 0-59) + return CronTrigger(minute=f'*/{interval_minutes}') + + def register_executor( + self, + task_type: str, + executor: TaskExecutor, + task_loader: Callable[[Session], List[Any]] + ): + """ + Register a task executor for a specific task type. + + Args: + task_type: Unique identifier for task type (e.g., 'monitoring_task') + executor: TaskExecutor instance that handles execution + task_loader: Function that loads due tasks from database + """ + self.registry.register(task_type, executor, task_loader) + logger.info(f"Registered executor for task type: {task_type}") + + async def start(self): + """Start the scheduler with intelligent interval adjustment.""" + if self._running: + logger.warning("Scheduler is already running") + return + + try: + # Determine initial check interval based on active strategies + initial_interval = await self._determine_optimal_interval() + self.current_check_interval_minutes = initial_interval + + # Add periodic job to check for due tasks + self.scheduler.add_job( + self._check_and_execute_due_tasks, + trigger=self._get_trigger_for_interval(initial_interval), + id='check_due_tasks', + replace_existing=True + ) + + self.scheduler.start() + self._running = True + + logger.info( + f"Task scheduler started | " + f"check_interval={initial_interval}min | " + f"registered_types={self.registry.get_registered_types()}" + ) + + except Exception as e: + logger.error(f"Failed to start scheduler: {e}") + raise + + async def stop(self): + """Stop the scheduler gracefully.""" + if not self._running: + return + + try: + # Cancel all active executions + for task_id, execution_task in self.active_executions.items(): + execution_task.cancel() + + # Wait for active executions to complete (with timeout) + if self.active_executions: + await asyncio.wait( + self.active_executions.values(), + timeout=30 + ) + + # Shutdown scheduler + self.scheduler.shutdown(wait=True) + self._running = False + + logger.info("Task scheduler stopped gracefully") + + except Exception as e: + logger.error(f"Error stopping scheduler: {e}") + raise + + async def _check_and_execute_due_tasks(self): + """ + Main scheduler loop: check for due tasks and execute them. + This runs periodically with intelligent interval adjustment based on active strategies. + """ + self.stats['total_checks'] += 1 + self.stats['last_check'] = datetime.utcnow().isoformat() + + logger.debug("Checking for due tasks...") + + db = None + try: + db = get_db_session() + if db is None: + logger.error("Failed to get database session") + return + + # Check for active strategies and adjust interval intelligently + await self._adjust_check_interval_if_needed(db) + + # Check each registered task type + for task_type in self.registry.get_registered_types(): + await self._process_task_type(task_type, db) + + except Exception as e: + error = DatabaseError( + message=f"Error checking for due tasks: {str(e)}", + original_error=e + ) + self.exception_handler.handle_exception(error) + finally: + if db: + db.close() + + async def _determine_optimal_interval(self) -> int: + """ + Determine optimal check interval based on active strategies. + + Returns: + Optimal check interval in minutes + """ + db = None + try: + db = get_db_session() + if db: + from services.active_strategy_service import ActiveStrategyService + active_strategy_service = ActiveStrategyService(db_session=db) + active_count = active_strategy_service.count_active_strategies_with_tasks() + self.stats['active_strategies_count'] = active_count + + if active_count > 0: + logger.info(f"Found {active_count} active strategies with tasks - using {self.min_check_interval_minutes}min interval") + return self.min_check_interval_minutes + else: + logger.info(f"No active strategies with tasks - using {self.max_check_interval_minutes}min interval") + return self.max_check_interval_minutes + except Exception as e: + logger.warning(f"Error determining optimal interval: {e}, using default {self.min_check_interval_minutes}min") + finally: + if db: + db.close() + + # Default to shorter interval on error (safer) + return self.min_check_interval_minutes + + async def _adjust_check_interval_if_needed(self, db: Session): + """ + Intelligently adjust check interval based on active strategies. + + If there are active strategies with tasks, check more frequently. + If there are no active strategies, check less frequently. + + Args: + db: Database session + """ + try: + from services.active_strategy_service import ActiveStrategyService + + active_strategy_service = ActiveStrategyService(db_session=db) + active_count = active_strategy_service.count_active_strategies_with_tasks() + self.stats['active_strategies_count'] = active_count + + # Determine optimal interval + if active_count > 0: + optimal_interval = self.min_check_interval_minutes + else: + optimal_interval = self.max_check_interval_minutes + + # Only reschedule if interval needs to change + if optimal_interval != self.current_check_interval_minutes: + logger.info( + f"Adjusting scheduler interval: {self.current_check_interval_minutes}min → {optimal_interval}min | " + f"active_strategies={active_count}" + ) + + # Reschedule the job with new interval + self.scheduler.modify_job( + 'check_due_tasks', + trigger=self._get_trigger_for_interval(optimal_interval) + ) + + self.current_check_interval_minutes = optimal_interval + self.stats['last_interval_adjustment'] = datetime.utcnow().isoformat() + + logger.info(f"Scheduler interval adjusted to {optimal_interval}min") + + except Exception as e: + logger.warning(f"Error adjusting check interval: {e}") + + async def trigger_interval_adjustment(self): + """ + Trigger immediate interval adjustment check. + + This should be called when a strategy is activated or deactivated + to immediately adjust the scheduler interval based on current active strategies. + """ + if not self._running: + logger.debug("Scheduler not running, skipping interval adjustment") + return + + try: + db = get_db_session() + if db: + await self._adjust_check_interval_if_needed(db) + else: + logger.warning("Could not get database session for interval adjustment") + except Exception as e: + logger.warning(f"Error triggering interval adjustment: {e}") + + async def _process_task_type(self, task_type: str, db: Session): + """Process due tasks for a specific task type.""" + try: + # Get task loader for this type + try: + task_loader = self.registry.get_task_loader(task_type) + except Exception as e: + error = TaskLoaderError( + message=f"Failed to get task loader for type {task_type}: {str(e)}", + task_type=task_type, + original_error=e + ) + self.exception_handler.handle_exception(error) + return + + # Load due tasks (with error handling) + try: + due_tasks = task_loader(db) + except Exception as e: + error = TaskLoaderError( + message=f"Failed to load due tasks for type {task_type}: {str(e)}", + task_type=task_type, + original_error=e + ) + self.exception_handler.handle_exception(error) + return + + if not due_tasks: + return + + self.stats['tasks_found'] += len(due_tasks) + logger.info(f"Found {len(due_tasks)} due tasks for type: {task_type}") + + # Execute tasks (with concurrency limit) + execution_tasks = [] + for task in due_tasks: + if len(self.active_executions) >= self.max_concurrent_executions: + logger.warning( + f"Max concurrent executions reached ({self.max_concurrent_executions}), " + f"skipping {len(due_tasks) - len(execution_tasks)} tasks" + ) + break + + # Execute task asynchronously + # Note: Each task gets its own database session to prevent concurrent access issues + execution_task = asyncio.create_task( + self._execute_task_async(task_type, task) + ) + + task_id = f"{task_type}_{getattr(task, 'id', id(task))}" + self.active_executions[task_id] = execution_task + + execution_tasks.append(execution_task) + + # Wait for executions to complete (with timeout per task) + if execution_tasks: + await asyncio.wait(execution_tasks, timeout=300) + + except Exception as e: + error = TaskLoaderError( + message=f"Error processing task type {task_type}: {str(e)}", + task_type=task_type, + original_error=e + ) + self.exception_handler.handle_exception(error) + + async def _execute_task_async(self, task_type: str, task: Any): + """ + Execute a single task asynchronously with user isolation. + + Each task gets its own database session to prevent concurrent access issues, + as SQLAlchemy sessions are not async-safe or concurrent-safe. + + User context is extracted and tracked for user isolation. + + Args: + task_type: Type of task + task: Task instance from database (detached from original session) + """ + task_id = f"{task_type}_{getattr(task, 'id', id(task))}" + db = None + user_id = None + + try: + # Extract user context if available (for user isolation tracking) + try: + if hasattr(task, 'strategy') and task.strategy: + user_id = getattr(task.strategy, 'user_id', None) + elif hasattr(task, 'strategy_id') and task.strategy_id: + # Will query user_id after we have db session + pass + except Exception as e: + logger.debug(f"Could not extract user_id before execution for task {task_id}: {e}") + + logger.info(f"Executing task: {task_id} | user_id: {user_id}") + + # Create a new database session for this async task + # SQLAlchemy sessions are not async-safe and cannot be shared across concurrent tasks + db = get_db_session() + if db is None: + error = DatabaseError( + message=f"Failed to get database session for task {task_id}", + user_id=user_id, + task_id=getattr(task, 'id', None), + task_type=task_type + ) + self.exception_handler.handle_exception(error, log_level="error") + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + return + + # Set database session for exception handler + self.exception_handler.db = db + + # Merge the detached task object into this session + # The task object was loaded in a different session and is now detached + from sqlalchemy.orm import object_session + if object_session(task) is None: + # Task is detached, need to merge it into this session + task = db.merge(task) + + # Extract user_id after merge if not already available + if user_id is None and hasattr(task, 'strategy'): + try: + if task.strategy: + user_id = getattr(task.strategy, 'user_id', None) + elif hasattr(task, 'strategy_id'): + # Query strategy if relationship not loaded + from models.enhanced_strategy_models import EnhancedContentStrategy + strategy = db.query(EnhancedContentStrategy).filter( + EnhancedContentStrategy.id == task.strategy_id + ).first() + if strategy: + user_id = strategy.user_id + except Exception as e: + logger.debug(f"Could not extract user_id after merge for task {task_id}: {e}") + + # Get executor for this task type + try: + executor = self.registry.get_executor(task_type) + except Exception as e: + from .exception_handler import SchedulerConfigError + error = SchedulerConfigError( + message=f"Failed to get executor for task type {task_type}: {str(e)}", + user_id=user_id, + context={ + "task_id": getattr(task, 'id', None), + "task_type": task_type + }, + original_error=e + ) + self.exception_handler.handle_exception(error) + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + return + + # Execute task with its own session (with error handling) + try: + result = await executor.execute_task(task, db) + + # Handle result and update statistics + if result.success: + self.stats['tasks_executed'] += 1 + self._update_user_stats(user_id, success=True) + logger.info(f"Task executed successfully: {task_id} | user_id: {user_id}") + else: + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + + # Create structured error for failed execution + error = TaskExecutionError( + message=result.error_message or "Task execution failed", + user_id=user_id, + task_id=getattr(task, 'id', None), + task_type=task_type, + execution_time_ms=result.execution_time_ms, + context={"result_data": result.result_data} + ) + self.exception_handler.handle_exception(error, log_level="warning") + + # Retry logic if enabled + if self.enable_retries and result.retryable: + await self._schedule_retry(task, result.retry_delay) + + except SchedulerException as e: + # Re-raise scheduler exceptions (they're already handled) + raise + except Exception as e: + # Wrap unexpected exceptions + error = TaskExecutionError( + message=f"Unexpected error during task execution: {str(e)}", + user_id=user_id, + task_id=getattr(task, 'id', None), + task_type=task_type, + original_error=e + ) + self.exception_handler.handle_exception(error) + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + + except SchedulerException as e: + # Handle scheduler exceptions + self.exception_handler.handle_exception(e) + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + except Exception as e: + # Handle any other unexpected errors + error = TaskExecutionError( + message=f"Unexpected error in task execution wrapper: {str(e)}", + user_id=user_id, + task_id=getattr(task, 'id', None), + task_type=task_type, + original_error=e + ) + self.exception_handler.handle_exception(error) + self.stats['tasks_failed'] += 1 + self._update_user_stats(user_id, success=False) + finally: + # Clean up database session + if db: + try: + db.close() + except Exception as e: + logger.error(f"Error closing database session for task {task_id}: {e}") + + # Remove from active executions + if task_id in self.active_executions: + del self.active_executions[task_id] + + def _update_user_stats(self, user_id: Optional[int], success: bool): + """ + Update per-user statistics for user isolation tracking. + + Args: + user_id: User ID (None if user context not available) + success: Whether task execution was successful + """ + if user_id is None: + return + + if user_id not in self.stats['per_user_stats']: + self.stats['per_user_stats'][user_id] = { + 'executed': 0, + 'failed': 0, + 'success_rate': 0.0 + } + + user_stats = self.stats['per_user_stats'][user_id] + if success: + user_stats['executed'] += 1 + else: + user_stats['failed'] += 1 + + # Calculate success rate + total = user_stats['executed'] + user_stats['failed'] + if total > 0: + user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0 + + async def _schedule_retry(self, task: Any, delay_seconds: int): + """Schedule a retry for a failed task.""" + # This would update the task's next_execution time + # For now, just log - could be enhanced to update next_execution + logger.debug(f"Scheduling retry for task in {delay_seconds}s") + + def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]: + """ + Get scheduler statistics with optional user filtering. + + Args: + user_id: Optional user ID to filter statistics for specific user + + Returns: + Dictionary with scheduler statistics + """ + base_stats = { + **{k: v for k, v in self.stats.items() if k not in ['per_user_stats']}, + 'active_executions': len(self.active_executions), + 'registered_types': self.registry.get_registered_types(), + 'running': self._running, + 'check_interval_minutes': self.current_check_interval_minutes, + 'min_check_interval_minutes': self.min_check_interval_minutes, + 'max_check_interval_minutes': self.max_check_interval_minutes, + 'intelligent_scheduling': True + } + + # Include per-user stats (all users or filtered) + if user_id is not None: + if user_id in self.stats['per_user_stats']: + base_stats['user_stats'] = self.stats['per_user_stats'][user_id] + else: + base_stats['user_stats'] = { + 'executed': 0, + 'failed': 0, + 'success_rate': 0.0 + } + else: + # Include all per-user stats (for admin/debugging) + base_stats['per_user_stats'] = self.stats['per_user_stats'] + + return base_stats + + def is_running(self) -> bool: + """Check if scheduler is running.""" + return self._running + diff --git a/backend/services/scheduler/core/task_registry.py b/backend/services/scheduler/core/task_registry.py new file mode 100644 index 00000000..61abb8b9 --- /dev/null +++ b/backend/services/scheduler/core/task_registry.py @@ -0,0 +1,59 @@ +""" +Task Registry +Manages registration of task executors and loaders. +""" + +import logging +from typing import Dict, Callable, List, Any +from sqlalchemy.orm import Session + +from .executor_interface import TaskExecutor + +logger = logging.getLogger(__name__) + + +class TaskRegistry: + """Registry for task executors and loaders.""" + + def __init__(self): + self.executors: Dict[str, TaskExecutor] = {} + self.task_loaders: Dict[str, Callable[[Session], List[Any]]] = {} + + def register( + self, + task_type: str, + executor: TaskExecutor, + task_loader: Callable[[Session], List[Any]] + ): + """ + Register a task executor and loader. + + Args: + task_type: Unique identifier for task type + executor: TaskExecutor instance + task_loader: Function that loads due tasks from database + """ + if task_type in self.executors: + logger.warning(f"Overwriting existing executor for task type: {task_type}") + + self.executors[task_type] = executor + self.task_loaders[task_type] = task_loader + + logger.info(f"Registered task type: {task_type}") + + def get_executor(self, task_type: str) -> TaskExecutor: + """Get executor for task type.""" + if task_type not in self.executors: + raise ValueError(f"No executor registered for task type: {task_type}") + return self.executors[task_type] + + def get_task_loader(self, task_type: str) -> Callable[[Session], List[Any]]: + """Get task loader for task type.""" + if task_type not in self.task_loaders: + raise ValueError(f"No task loader registered for task type: {task_type}") + return self.task_loaders[task_type] + + def get_registered_types(self) -> List[str]: + """Get list of registered task types.""" + return list(self.executors.keys()) + diff --git a/backend/services/scheduler/executors/__init__.py b/backend/services/scheduler/executors/__init__.py new file mode 100644 index 00000000..b29ace75 --- /dev/null +++ b/backend/services/scheduler/executors/__init__.py @@ -0,0 +1,4 @@ +""" +Task executor implementations. +""" + diff --git a/backend/services/scheduler/executors/monitoring_task_executor.py b/backend/services/scheduler/executors/monitoring_task_executor.py new file mode 100644 index 00000000..493f990e --- /dev/null +++ b/backend/services/scheduler/executors/monitoring_task_executor.py @@ -0,0 +1,266 @@ +""" +Monitoring Task Executor +Handles execution of content strategy monitoring tasks. +""" + +import logging +import time +from datetime import datetime +from typing import Dict, Any, Optional +from sqlalchemy.orm import Session + +from ..core.executor_interface import TaskExecutor, TaskExecutionResult +from ..core.exception_handler import TaskExecutionError, DatabaseError, SchedulerExceptionHandler +from ..utils.frequency_calculator import calculate_next_execution +from models.monitoring_models import MonitoringTask, TaskExecutionLog +from models.enhanced_strategy_models import EnhancedContentStrategy +from utils.logger_utils import get_service_logger + +logger = get_service_logger("monitoring_task_executor") + + +class MonitoringTaskExecutor(TaskExecutor): + """ + Executor for content strategy monitoring tasks. + + Handles: + - ALwrity tasks (automated execution) + - Human tasks (notifications/queuing) + """ + + def __init__(self): + self.logger = logger + self.exception_handler = SchedulerExceptionHandler() + + async def execute_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult: + """ + Execute a monitoring task with user isolation. + + Args: + task: MonitoringTask instance (with strategy relationship loaded) + db: Database session + + Returns: + TaskExecutionResult + """ + start_time = time.time() + + # Extract user_id from strategy relationship for user isolation + user_id = None + try: + if task.strategy and hasattr(task.strategy, 'user_id'): + user_id = task.strategy.user_id + elif task.strategy_id: + # Fallback: query strategy if relationship not loaded + strategy = db.query(EnhancedContentStrategy).filter( + EnhancedContentStrategy.id == task.strategy_id + ).first() + if strategy: + user_id = strategy.user_id + except Exception as e: + self.logger.warning(f"Could not extract user_id for task {task.id}: {e}") + + try: + self.logger.info( + f"Executing monitoring task: {task.id} | " + f"user_id: {user_id} | " + f"assignee: {task.assignee} | " + f"frequency: {task.frequency}" + ) + + # Create execution log with user_id for user isolation tracking + execution_log = TaskExecutionLog( + task_id=task.id, + user_id=user_id, + execution_date=datetime.utcnow(), + status='running' + ) + db.add(execution_log) + db.flush() + + # Execute based on assignee + if task.assignee == 'ALwrity': + result = await self._execute_alwrity_task(task, db) + else: + result = await self._execute_human_task(task, db) + + # Update execution log + execution_time_ms = int((time.time() - start_time) * 1000) + execution_log.status = 'success' if result.success else 'failed' + execution_log.result_data = result.result_data + execution_log.error_message = result.error_message + execution_log.execution_time_ms = execution_time_ms + + # Update task + task.last_executed = datetime.utcnow() + task.next_execution = self.calculate_next_execution( + task, + task.frequency, + task.last_executed + ) + + if result.success: + task.status = 'completed' + else: + task.status = 'failed' + + db.commit() + + return result + + except Exception as e: + execution_time_ms = int((time.time() - start_time) * 1000) + + # Set database session for exception handler + self.exception_handler.db = db + + # Create structured error + error = TaskExecutionError( + message=f"Error executing monitoring task {task.id}: {str(e)}", + user_id=user_id, + task_id=task.id, + task_type="monitoring_task", + execution_time_ms=execution_time_ms, + context={ + "assignee": task.assignee, + "frequency": task.frequency, + "component": task.component_name + }, + original_error=e + ) + + # Handle exception with structured logging + self.exception_handler.handle_exception(error) + + # Update execution log with error (include user_id for isolation) + try: + execution_log = TaskExecutionLog( + task_id=task.id, + user_id=user_id, + execution_date=datetime.utcnow(), + status='failed', + error_message=str(e), + execution_time_ms=execution_time_ms, + result_data={ + "error_type": error.error_type.value, + "severity": error.severity.value, + "context": error.context + } + ) + db.add(execution_log) + + task.status = 'failed' + task.last_executed = datetime.utcnow() + + db.commit() + except Exception as commit_error: + db_error = DatabaseError( + message=f"Error saving execution log: {str(commit_error)}", + user_id=user_id, + task_id=task.id, + original_error=commit_error + ) + self.exception_handler.handle_exception(db_error) + db.rollback() + + return TaskExecutionResult( + success=False, + error_message=str(e), + execution_time_ms=execution_time_ms, + retryable=True, + retry_delay=300 + ) + + async def _execute_alwrity_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult: + """ + Execute an ALwrity (automated) monitoring task. + + This is where the actual monitoring logic would go. + For now, we'll implement a placeholder that can be extended. + """ + try: + self.logger.info(f"Executing ALwrity task: {task.task_title}") + + # TODO: Implement actual monitoring logic based on: + # - task.metric + # - task.measurement_method + # - task.success_criteria + # - task.alert_threshold + + # Placeholder: Simulate task execution + result_data = { + 'metric_value': 0, + 'status': 'measured', + 'message': f"Task {task.task_title} executed successfully", + 'timestamp': datetime.utcnow().isoformat() + } + + return TaskExecutionResult( + success=True, + result_data=result_data + ) + + except Exception as e: + self.logger.error(f"Error in ALwrity task execution: {e}") + return TaskExecutionResult( + success=False, + error_message=str(e), + retryable=True + ) + + async def _execute_human_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult: + """ + Execute a Human monitoring task (notification/queuing). + + For human tasks, we don't execute the task directly, + but rather queue it for human review or send notifications. + """ + try: + self.logger.info(f"Queuing human task: {task.task_title}") + + # TODO: Implement notification/queuing system: + # - Send email notification + # - Add to user's task queue + # - Create in-app notification + + result_data = { + 'status': 'queued', + 'message': f"Task {task.task_title} queued for human review", + 'timestamp': datetime.utcnow().isoformat() + } + + return TaskExecutionResult( + success=True, + result_data=result_data + ) + + except Exception as e: + self.logger.error(f"Error queuing human task: {e}") + return TaskExecutionResult( + success=False, + error_message=str(e), + retryable=True + ) + + def calculate_next_execution( + self, + task: MonitoringTask, + frequency: str, + last_execution: Optional[datetime] = None + ) -> datetime: + """ + Calculate next execution time based on frequency. + + Args: + task: MonitoringTask instance + frequency: Frequency string (Daily, Weekly, Monthly, Quarterly) + last_execution: Last execution datetime (defaults to now) + + Returns: + Next execution datetime + """ + return calculate_next_execution( + frequency=frequency, + base_time=last_execution or datetime.utcnow() + ) + diff --git a/backend/services/scheduler/utils/__init__.py b/backend/services/scheduler/utils/__init__.py new file mode 100644 index 00000000..056d50fb --- /dev/null +++ b/backend/services/scheduler/utils/__init__.py @@ -0,0 +1,4 @@ +""" +Scheduler utilities. +""" + diff --git a/backend/services/scheduler/utils/frequency_calculator.py b/backend/services/scheduler/utils/frequency_calculator.py new file mode 100644 index 00000000..19885a3d --- /dev/null +++ b/backend/services/scheduler/utils/frequency_calculator.py @@ -0,0 +1,33 @@ +""" +Frequency Calculator Utility +Calculates next execution time based on frequency string. +""" + +from datetime import datetime, timedelta +from typing import Optional + + +def calculate_next_execution(frequency: str, base_time: Optional[datetime] = None) -> datetime: + """ + Calculate next execution time based on frequency. + + Args: + frequency: Frequency string ('Daily', 'Weekly', 'Monthly', 'Quarterly') + base_time: Base time to calculate from (defaults to now if None) + + Returns: + Next execution datetime + """ + if base_time is None: + base_time = datetime.utcnow() + + frequency_map = { + 'Daily': timedelta(days=1), + 'Weekly': timedelta(weeks=1), + 'Monthly': timedelta(days=30), + 'Quarterly': timedelta(days=90) + } + + delta = frequency_map.get(frequency, timedelta(days=1)) + return base_time + delta + diff --git a/backend/services/scheduler/utils/task_loader.py b/backend/services/scheduler/utils/task_loader.py new file mode 100644 index 00000000..ce97227d --- /dev/null +++ b/backend/services/scheduler/utils/task_loader.py @@ -0,0 +1,60 @@ +""" +Task Loader Utilities +Functions to load due tasks from database. +""" + +from datetime import datetime +from typing import List, Optional +from sqlalchemy.orm import Session, joinedload +from sqlalchemy import and_, or_ + +from models.monitoring_models import MonitoringTask +from models.enhanced_strategy_models import EnhancedContentStrategy + + +def load_due_monitoring_tasks( + db: Session, + user_id: Optional[int] = None +) -> List[MonitoringTask]: + """ + Load all monitoring tasks that are due for execution. + + Criteria: + - status == 'active' + - next_execution <= now (or is None for first execution) + - Optional: user_id filter for specific user (for future admin features) + + Note: Strategy relationship is eagerly loaded to ensure user_id is accessible + during task execution for user isolation. + + Args: + db: Database session + user_id: Optional user ID to filter tasks (if None, loads all users' tasks) + + Returns: + List of due MonitoringTask instances with strategy relationship loaded + """ + now = datetime.utcnow() + + # Join with strategy to ensure relationship is loaded and support user filtering + query = db.query(MonitoringTask).join( + EnhancedContentStrategy, + MonitoringTask.strategy_id == EnhancedContentStrategy.id + ).options( + joinedload(MonitoringTask.strategy) # Eagerly load strategy relationship + ).filter( + and_( + MonitoringTask.status == 'active', + or_( + MonitoringTask.next_execution <= now, + MonitoringTask.next_execution.is_(None) + ) + ) + ) + + # Apply user filter if provided + if user_id is not None: + query = query.filter(EnhancedContentStrategy.user_id == user_id) + + return query.all() + diff --git a/backend/services/subscription/preflight_validator.py b/backend/services/subscription/preflight_validator.py new file mode 100644 index 00000000..0526cc3c --- /dev/null +++ b/backend/services/subscription/preflight_validator.py @@ -0,0 +1,189 @@ +""" +Pre-flight Validation Utility for Multi-Operation Workflows + +Provides transparent validation for operations that involve multiple API calls. +Services can use this to validate entire workflows before making any external API calls. +""" + +from typing import Dict, Any, List, Optional, Tuple +from fastapi import HTTPException +from loguru import logger + +from services.subscription.pricing_service import PricingService +from models.subscription_models import APIProvider + + +def validate_research_operations( + pricing_service: PricingService, + user_id: str, + gpt_provider: str = "google" +) -> None: + """ + Validate all operations for a research workflow before making ANY API calls. + + This prevents wasteful external API calls (like Google Grounding) if subsequent + LLM calls would fail due to token or call limits. + + Args: + pricing_service: PricingService instance + user_id: User ID for subscription checking + gpt_provider: GPT provider from env var (defaults to "google") + + Returns: + (can_proceed, error_message, error_details) + If can_proceed is False, raises HTTPException with 429 status + """ + try: + # Determine actual provider for LLM calls based on GPT_PROVIDER env var + gpt_provider_lower = gpt_provider.lower() + if gpt_provider_lower == "huggingface": + llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace + llm_provider_name = "huggingface" + else: + llm_provider_enum = APIProvider.GEMINI + llm_provider_name = "gemini" + + # Estimate tokens for each operation in research workflow + # Google Grounding call: ~2000 tokens (input + output) + # Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) + # Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) + # Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles) + + operations_to_validate = [ + { + 'provider': APIProvider.GEMINI, # Google Grounding uses Gemini + 'tokens_requested': 2000, + 'actual_provider_name': 'gemini', + 'operation_type': 'google_grounding' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'keyword_analysis' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'competitor_analysis' + }, + { + 'provider': llm_provider_enum, + 'tokens_requested': 1000, + 'actual_provider_name': llm_provider_name, + 'operation_type': 'content_angle_generation' + } + ] + + logger.info(f"[Pre-flight Validator] 🚀 Starting Research Workflow Validation") + logger.info(f" ├─ User: {user_id}") + logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})") + logger.info(f" └─ Operations to validate: {len(operations_to_validate)}") + + can_proceed, message, error_details = pricing_service.check_comprehensive_limits( + user_id=user_id, + operations=operations_to_validate + ) + + if not can_proceed: + usage_info = error_details.get('usage_info', {}) if error_details else {} + provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name + operation_type = usage_info.get('operation_type', 'unknown') + + logger.error(f"[Pre-flight Validator] ❌ RESEARCH WORKFLOW BLOCKED") + logger.error(f" ├─ User: {user_id}") + logger.error(f" ├─ Blocked at: {operation_type}") + logger.error(f" ├─ Provider: {provider}") + logger.error(f" └─ Reason: {message}") + + # Raise HTTPException immediately - frontend gets immediate response, no API calls made + raise HTTPException( + status_code=429, + detail={ + 'error': message, + 'message': message, + 'provider': provider, + 'usage_info': usage_info if usage_info else error_details + } + ) + + logger.info(f"[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED") + logger.info(f" ├─ User: {user_id}") + logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls") + # Validation passed - no return needed (function raises HTTPException if validation fails) + + except HTTPException: + raise + except Exception as e: + logger.error(f"[Pre-flight Validator] Error validating research operations: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail={ + 'error': f"Failed to validate operations: {str(e)}", + 'message': f"Failed to validate operations: {str(e)}" + } + ) + + +def validate_image_generation_operations( + pricing_service: PricingService, + user_id: str +) -> None: + """ + Validate image generation operation before making API calls. + + Args: + pricing_service: PricingService instance + user_id: User ID for subscription checking + + Returns: + (can_proceed, error_message, error_details) + If can_proceed is False, raises HTTPException with 429 status + """ + try: + operations_to_validate = [ + { + 'provider': APIProvider.STABILITY, + 'tokens_requested': 0, + 'actual_provider_name': 'stability', + 'operation_type': 'image_generation' + } + ] + + can_proceed, message, error_details = pricing_service.check_comprehensive_limits( + user_id=user_id, + operations=operations_to_validate + ) + + if not can_proceed: + logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}") + + usage_info = error_details.get('usage_info', {}) if error_details else {} + provider = usage_info.get('provider', 'stability') if usage_info else 'stability' + + raise HTTPException( + status_code=429, + detail={ + 'error': message, + 'message': message, + 'provider': provider, + 'usage_info': usage_info if usage_info else error_details + } + ) + + logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}") + # Validation passed - no return needed (function raises HTTPException if validation fails) + + except HTTPException: + raise + except Exception as e: + logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail={ + 'error': f"Failed to validate image generation: {str(e)}", + 'message': f"Failed to validate image generation: {str(e)}" + } + ) + diff --git a/backend/services/subscription/pricing_service.py b/backend/services/subscription/pricing_service.py index 40a2ab09..629759d5 100644 --- a/backend/services/subscription/pricing_service.py +++ b/backend/services/subscription/pricing_service.py @@ -3,10 +3,11 @@ Pricing Service for API Usage Tracking Manages API pricing, cost calculation, and subscription limits. """ -from typing import Dict, Any, Optional, List, Tuple +from typing import Dict, Any, Optional, List, Tuple, Union from decimal import Decimal, ROUND_HALF_UP from datetime import datetime, timedelta from sqlalchemy.orm import Session +from sqlalchemy import text from loguru import logger from models.subscription_models import ( @@ -17,13 +18,17 @@ from models.subscription_models import ( class PricingService: """Service for managing API pricing and cost calculations.""" + # Class-level cache shared across all instances (critical for cache invalidation on subscription renewal) + # key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime } + _limits_cache: Dict[str, Dict[str, Any]] = {} + def __init__(self, db: Session): self.db = db self._pricing_cache = {} self._plans_cache = {} - # Lightweight in-process cache for limit checks - # key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime } - self._limits_cache: Dict[str, Dict[str, Any]] = {} + # Cache for schema feature detection (ai_text_generation_calls_limit column) + self._ai_text_gen_col_checked: bool = False + self._ai_text_gen_col_available: bool = False # ------------------- Billing period helpers ------------------- def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime: @@ -68,6 +73,15 @@ class PricingService: self._ensure_subscription_current(subscription) # Continue to use YYYY-MM for summaries return datetime.now().strftime("%Y-%m") + + @classmethod + def clear_user_cache(cls, user_id: str) -> int: + """Clear all cached limit checks for a specific user. Returns number of entries cleared.""" + keys_to_remove = [key for key in cls._limits_cache.keys() if key.startswith(f"{user_id}:")] + for key in keys_to_remove: + del cls._limits_cache[key] + logger.info(f"Cleared {len(keys_to_remove)} cache entries for user {user_id}") + return len(keys_to_remove) def initialize_default_pricing(self): """Initialize default pricing for all API providers.""" @@ -292,7 +306,8 @@ class PricingService: "tier": SubscriptionTier.BASIC, "price_monthly": 29.0, "price_yearly": 290.0, - "gemini_calls_limit": 1000, + "ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers + "gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement) "openai_calls_limit": 500, "anthropic_calls_limit": 200, "mistral_calls_limit": 500, @@ -300,11 +315,11 @@ class PricingService: "serper_calls_limit": 200, "metaphor_calls_limit": 100, "firecrawl_calls_limit": 100, - "stability_calls_limit": 50, - "gemini_tokens_limit": 1000000, - "openai_tokens_limit": 500000, - "anthropic_tokens_limit": 200000, - "mistral_tokens_limit": 500000, + "stability_calls_limit": 5, + "gemini_tokens_limit": 2000, + "openai_tokens_limit": 2000, + "anthropic_tokens_limit": 2000, + "mistral_tokens_limit": 2000, "monthly_cost_limit": 50.0, "features": ["full_content_generation", "advanced_research", "basic_analytics"], "description": "Great for individuals and small teams" @@ -426,21 +441,60 @@ class PricingService: self._ensure_subscription_current(subscription) return self._plan_to_limits_dict(subscription.plan) + def _ensure_ai_text_gen_column_detection(self) -> None: + """Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result.""" + if self._ai_text_gen_col_checked: + return + try: + # Try to query the column - if it exists, this will work + self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0')) + self._ai_text_gen_col_available = True + except Exception: + self._ai_text_gen_col_available = False + finally: + self._ai_text_gen_col_checked = True + def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]: """Convert subscription plan to limits dictionary.""" + # Detect if unified AI text generation limit column exists + self._ensure_ai_text_gen_column_detection() + + # Use unified AI text generation limit if column exists and is set + ai_text_gen_limit = None + if self._ai_text_gen_col_available: + try: + ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None) + # If 0, treat as not set (unlimited for Enterprise or use fallback) + if ai_text_gen_limit == 0: + ai_text_gen_limit = None + except (AttributeError, Exception): + # Column exists but access failed - use fallback + ai_text_gen_limit = None + return { 'plan_name': plan.name, 'tier': plan.tier.value, 'limits': { + # Unified AI text generation limit (applies to all LLM providers) + # If not set, fall back to first non-zero legacy limit for backwards compatibility + 'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else ( + plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else + plan.openai_calls_limit if plan.openai_calls_limit > 0 else + plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else + plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0 + ), + # Legacy per-provider limits (for backwards compatibility and analytics) 'gemini_calls': plan.gemini_calls_limit, 'openai_calls': plan.openai_calls_limit, 'anthropic_calls': plan.anthropic_calls_limit, 'mistral_calls': plan.mistral_calls_limit, + # Other API limits 'tavily_calls': plan.tavily_calls_limit, 'serper_calls': plan.serper_calls_limit, 'metaphor_calls': plan.metaphor_calls_limit, 'firecrawl_calls': plan.firecrawl_calls_limit, 'stability_calls': plan.stability_calls_limit, + # Token limits 'gemini_tokens': plan.gemini_tokens_limit, 'openai_tokens': plan.openai_tokens_limit, 'anthropic_tokens': plan.anthropic_tokens_limit, @@ -451,101 +505,293 @@ class PricingService: } def check_usage_limits(self, user_id: str, provider: APIProvider, - tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]: - """Check if user can make an API call within their limits.""" - # Short TTL cache to reduce DB reads under sustained traffic - cache_key = f"{user_id}:{provider.value}" - now = datetime.utcnow() - cached = self._limits_cache.get(cache_key) - if cached and cached.get('expires_at') and cached['expires_at'] > now: - return tuple(cached['result']) # type: ignore - - # Get user limits - limits = self.get_user_limits(user_id) - if not limits: - return False, "No subscription plan found", {} + tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]: + """Check if user can make an API call within their limits. - # Get current usage for this billing period - current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") - usage = self.db.query(UsageSummary).filter( - UsageSummary.user_id == user_id, - UsageSummary.billing_period == current_period - ).first() - - if not usage: - # First usage this period, create summary - usage = UsageSummary( - user_id=user_id, - billing_period=current_period - ) - self.db.add(usage) - self.db.commit() - - # Check call limits - provider_name = provider.value - current_calls = getattr(usage, f"{provider_name}_calls", 0) - call_limit = limits['limits'].get(f"{provider_name}_calls", 0) - - if call_limit > 0 and current_calls >= call_limit: - result = (False, f"API call limit reached for {provider_name}", { - 'current_calls': current_calls, - 'limit': call_limit, - 'usage_percentage': 100.0 - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - - # Check token limits for LLM providers - if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: - current_tokens = getattr(usage, f"{provider_name}_tokens", 0) - token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) + Args: + user_id: User ID + provider: APIProvider enum (may be MISTRAL for HuggingFace) + tokens_requested: Estimated tokens for the request + actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL) + """ + try: + # Use actual_provider_name if provided, otherwise use enum value + # This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors + display_provider_name = actual_provider_name or provider.value - if token_limit > 0 and (current_tokens + tokens_requested) > token_limit: - result = (False, f"Token limit would be exceeded for {provider_name}", { - 'current_tokens': current_tokens, - 'requested_tokens': tokens_requested, - 'limit': token_limit, - 'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100 + logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}") + + # Short TTL cache to reduce DB reads under sustained traffic + cache_key = f"{user_id}:{provider.value}" + now = datetime.utcnow() + cached = self._limits_cache.get(cache_key) + if cached and cached.get('expires_at') and cached['expires_at'] > now: + logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}") + return tuple(cached['result']) # type: ignore + + # Get user subscription first to check expiration + subscription = self.db.query(UserSubscription).filter( + UserSubscription.user_id == user_id, + UserSubscription.is_active == True + ).first() + + if subscription: + logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}") + else: + logger.debug(f"[Subscription Check] No active subscription found for user {user_id}") + + # Check subscription expiration (STRICT: deny if expired) + if subscription: + if subscription.current_period_end < now: + logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}") + # Subscription expired - check if auto_renew is enabled + if not getattr(subscription, 'auto_renew', False): + # Expired and no auto-renew - deny access + logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access") + result = (False, "Subscription expired. Please renew your subscription to continue using the service.", { + 'expired': True, + 'period_end': subscription.current_period_end.isoformat() + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + # Try to auto-renew + if not self._ensure_subscription_current(subscription): + # Auto-renew failed - deny access + result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", { + 'expired': True, + 'auto_renew_failed': True + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + + # Get user limits with error handling (STRICT: fail on errors) + try: + limits = self.get_user_limits(user_id) + if limits: + logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}") + else: + logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier") + except Exception as e: + logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True) + # STRICT: Fail closed - deny request if we can't check limits + return False, f"Failed to retrieve subscription limits: {str(e)}", {} + + if not limits: + # No subscription found - check for free tier + free_plan = self.db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.FREE, + SubscriptionPlan.is_active == True + ).first() + if free_plan: + logger.info(f"[Subscription Check] Assigning free tier to user {user_id}") + limits = self._plan_to_limits_dict(free_plan) + else: + # No subscription and no free tier - deny access + logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access") + return False, "No subscription plan found. Please subscribe to a plan.", {} + + # Get current usage for this billing period with error handling + try: + current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if not usage: + # First usage this period, create summary + try: + usage = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + self.db.add(usage) + self.db.commit() + except Exception as create_error: + logger.error(f"Error creating usage summary: {create_error}") + self.db.rollback() + # STRICT: Fail closed on DB error + return False, f"Failed to create usage summary: {str(create_error)}", {} + except Exception as e: + logger.error(f"Error getting usage summary for {user_id}: {e}") + self.db.rollback() + # STRICT: Fail closed on DB error + return False, f"Failed to retrieve usage summary: {str(e)}", {} + + # Check call limits with error handling + # NOTE: call_limit = 0 means UNLIMITED (Enterprise plans) + try: + # Use display_provider_name for error messages, but provider.value for DB queries + provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens") + + # For LLM text generation providers, check against unified total_calls limit + llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + is_llm_provider = provider_name in llm_providers + + if is_llm_provider: + # Use unified AI text generation limit (total_calls across all LLM providers) + ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0 + + # If unified limit not set, fall back to provider-specific limit for backwards compatibility + if ai_text_gen_limit == 0: + ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 + + # Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral) + current_total_llm_calls = ( + (usage.gemini_calls or 0) + + (usage.openai_calls or 0) + + (usage.anthropic_calls or 0) + + (usage.mistral_calls or 0) + ) + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit: + logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})") + result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", { + 'current_calls': current_total_llm_calls, + 'limit': ai_text_gen_limit, + 'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0, + 'provider': display_provider_name, # Use display name for consistency + 'usage_info': { + 'provider': display_provider_name, # Use display name for user-facing info + 'current_calls': current_total_llm_calls, + 'limit': ai_text_gen_limit, + 'type': 'ai_text_generation', + 'breakdown': { + 'gemini': usage.gemini_calls or 0, + 'openai': usage.openai_calls or 0, + 'anthropic': usage.anthropic_calls or 0, + 'mistral': usage.mistral_calls or 0 # DB field name (not display name) + } + } + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})") + else: + # For non-LLM providers, check provider-specific limit + current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0 + call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0 + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if call_limit > 0 and current_calls >= call_limit: + logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}") + result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", { + 'current_calls': current_calls, + 'limit': call_limit, + 'usage_percentage': 100.0, + 'provider': display_provider_name # Use display name for consistency + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + else: + logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}") + except Exception as e: + logger.error(f"Error checking call limits: {e}") + # Continue to next check + + # Check token limits for LLM providers with error handling + # NOTE: token_limit = 0 means UNLIMITED (Enterprise plans) + try: + if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: + current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0 + token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0 + + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if token_limit > 0 and (current_tokens + tokens_requested) > token_limit: + result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", { + 'current_tokens': current_tokens, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100, + 'provider': display_provider_name, # Use display name in error details + 'usage_info': { + 'provider': display_provider_name, + 'current_tokens': current_tokens, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'type': 'tokens' + } + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + except Exception as e: + logger.error(f"Error checking token limits: {e}") + # Continue to next check + + # Check cost limits with error handling + # NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans) + try: + cost_limit = limits['limits'].get('monthly_cost', 0) or 0 + # Only enforce limit if limit > 0 (0 means unlimited for Enterprise) + if cost_limit > 0 and usage.total_cost >= cost_limit: + result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", { + 'current_cost': usage.total_cost, + 'limit': cost_limit, + 'usage_percentage': 100.0 + }) + self._limits_cache[cache_key] = { + 'result': result, + 'expires_at': now + timedelta(seconds=30) + } + return result + except Exception as e: + logger.error(f"Error checking cost limits: {e}") + # Continue to success case + + # Calculate usage percentages for warnings + try: + # Determine which call variables to use based on provider type + if is_llm_provider: + # Use unified LLM call tracking + current_call_count = current_total_llm_calls + call_limit_value = ai_text_gen_limit + else: + # Use provider-specific call tracking + current_call_count = current_calls + call_limit_value = call_limit + + call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0 + cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0 + result = (True, "Within limits", { + 'current_calls': current_call_count, + 'call_limit': call_limit_value, + 'call_usage_percentage': call_usage_pct, + 'current_cost': usage.total_cost, + 'cost_limit': cost_limit, + 'cost_usage_percentage': cost_usage_pct }) self._limits_cache[cache_key] = { 'result': result, 'expires_at': now + timedelta(seconds=30) } return result + except Exception as e: + logger.error(f"Error calculating usage percentages: {e}") + # Return basic success + return True, "Within limits", {} - # Check cost limits - cost_limit = limits['limits'].get('monthly_cost', 0) - if cost_limit > 0 and usage.total_cost >= cost_limit: - result = (False, "Monthly cost limit reached", { - 'current_cost': usage.total_cost, - 'limit': cost_limit, - 'usage_percentage': 100.0 - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result - - # Calculate usage percentages for warnings - call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0 - cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0 - result = (True, "Within limits", { - 'current_calls': current_calls, - 'call_limit': call_limit, - 'call_usage_percentage': call_usage_pct, - 'current_cost': usage.total_cost, - 'cost_limit': cost_limit, - 'cost_usage_percentage': cost_usage_pct - }) - self._limits_cache[cache_key] = { - 'result': result, - 'expires_at': now + timedelta(seconds=30) - } - return result + except Exception as e: + logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}") + # STRICT: Fail closed - deny requests if subscription system fails + return False, f"Subscription check error: {str(e)}", {} def estimate_tokens(self, text: str, provider: APIProvider) -> int: """Estimate token count for text based on provider.""" @@ -581,6 +827,236 @@ class PricingService: if not pricing: return None + def check_comprehensive_limits( + self, + user_id: str, + operations: List[Dict[str, Any]] + ) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]: + """ + Comprehensive pre-flight validation that checks ALL limits before making ANY API calls. + + This prevents wasteful API calls by validating that ALL subsequent operations will succeed + before making the first external API call. + + Args: + user_id: User ID + operations: List of operations to validate, each with: + - 'provider': APIProvider enum + - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM) + - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL) + - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation") + + Returns: + (can_proceed, error_message, error_details) + If can_proceed is False, error_message explains which limit would be exceeded + """ + try: + logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}") + logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls") + + # Get current usage and limits once + current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + usage = self.db.query(UsageSummary).filter( + UsageSummary.user_id == user_id, + UsageSummary.billing_period == current_period + ).first() + + if not usage: + # First usage this period, create summary + try: + usage = UsageSummary( + user_id=user_id, + billing_period=current_period + ) + self.db.add(usage) + self.db.commit() + except Exception as create_error: + logger.error(f"Error creating usage summary: {create_error}") + self.db.rollback() + return False, f"Failed to create usage summary: {str(create_error)}", {} + + # Get user limits + limits_dict = self.get_user_limits(user_id) + if not limits_dict: + # No subscription found - check for free tier + free_plan = self.db.query(SubscriptionPlan).filter( + SubscriptionPlan.tier == SubscriptionTier.FREE, + SubscriptionPlan.is_active == True + ).first() + if free_plan: + limits_dict = self._plan_to_limits_dict(free_plan) + else: + return False, "No subscription plan found. Please subscribe to a plan.", {} + + limits = limits_dict.get('limits', {}) + + # Track cumulative usage across all operations + total_llm_calls = ( + (usage.gemini_calls or 0) + + (usage.openai_calls or 0) + + (usage.anthropic_calls or 0) + + (usage.mistral_calls or 0) + ) + total_llm_tokens = {} + total_images = usage.stability_calls or 0 + + # Log current usage summary + logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:") + logger.info(f" └─ Total LLM Calls: {total_llm_calls}") + logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}") + logger.info(f" └─ Image Calls: {total_images}") + + # Validate each operation + for op_idx, operation in enumerate(operations): + provider = operation.get('provider') + provider_name = provider.value if hasattr(provider, 'value') else str(provider) + tokens_requested = operation.get('tokens_requested', 0) + actual_provider_name = operation.get('actual_provider_name') + operation_type = operation.get('operation_type', 'unknown') + + display_provider_name = actual_provider_name or provider_name + + logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}") + logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})") + logger.info(f" └─ Estimated Tokens: {tokens_requested}") + + # Check if this is an LLM provider + llm_providers = ['gemini', 'openai', 'anthropic', 'mistral'] + is_llm_provider = provider_name in llm_providers + + # Check unified AI text generation limit for LLM providers + if is_llm_provider: + ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0 + if ai_text_gen_limit == 0: + # Fallback to provider-specific limit + ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0 + + # Count this operation as an LLM call + projected_total_llm_calls = total_llm_calls + 1 + + if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit: + error_info = { + 'current_calls': total_llm_calls, + 'limit': ai_text_gen_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", { + 'error_type': 'call_limit', + 'usage_info': error_info + } + + # Check token limits for this provider + # Use cumulative projected tokens from previous operations, or current from DB if first operation + provider_tokens_key = f"{provider_name}_tokens" + if provider_tokens_key in total_llm_tokens: + # Use cumulative projected tokens from previous operations + current_provider_tokens = total_llm_tokens[provider_tokens_key] + logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}") + else: + # First operation for this provider - get current from database + current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0 + total_llm_tokens[provider_tokens_key] = current_provider_tokens + logger.info(f" └─ Current tokens from DB: {current_provider_tokens}") + + token_limit = limits.get(provider_tokens_key, 0) or 0 + + if token_limit > 0 and tokens_requested > 0: + projected_tokens = current_provider_tokens + tokens_requested + logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)") + + if projected_tokens > token_limit: + usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0 + error_info = { + 'current_tokens': current_provider_tokens, + 'requested_tokens': tokens_requested, + 'limit': token_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + error_msg = ( + f"Token limit exceeded for {display_provider_name} " + f"({operation_type}). " + f"Current: {current_provider_tokens}/{token_limit}, " + f"Requested: {tokens_requested}, " + f"Would exceed by: {projected_tokens - token_limit} tokens " + f"({usage_percentage:.1f}% of limit)" + ) + logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}") + return False, error_msg, { + 'error_type': 'token_limit', + 'usage_info': error_info + } + else: + logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}") + + # Update cumulative counts for next operation + total_llm_calls = projected_total_llm_calls + total_llm_tokens[provider_tokens_key] += tokens_requested + logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}") + + # Check image generation limits + elif provider == APIProvider.STABILITY: + image_limit = limits.get('stability_calls', 0) or 0 + projected_images = total_images + 1 + + if image_limit > 0 and projected_images > image_limit: + error_info = { + 'current_images': total_images, + 'limit': image_limit, + 'provider': 'stability', + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", { + 'error_type': 'image_limit', + 'usage_info': error_info + } + + total_images = projected_images + + # Check other provider-specific limits + else: + provider_calls_key = f"{provider_name}_calls" + current_provider_calls = getattr(usage, provider_calls_key, 0) or 0 + call_limit = limits.get(provider_calls_key, 0) or 0 + + if call_limit > 0: + projected_calls = current_provider_calls + 1 + if projected_calls > call_limit: + error_info = { + 'current_calls': current_provider_calls, + 'limit': call_limit, + 'provider': display_provider_name, + 'operation_type': operation_type, + 'operation_index': op_idx + } + return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", { + 'error_type': 'call_limit', + 'usage_info': error_info + } + + # All checks passed + logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully") + logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls") + return True, None, None + + except Exception as e: + logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True) + return False, f"Failed to validate limits: {str(e)}", {} + + def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]: + """Get pricing configuration for a specific provider and model.""" + pricing = self.db.query(APIProviderPricing).filter( + APIProviderPricing.provider == provider, + APIProviderPricing.model_name == model_name + ).first() + + if not pricing: + return None + return { 'provider': pricing.provider.value, 'model_name': pricing.model_name, diff --git a/backend/services/subscription/usage_tracking_service.py b/backend/services/subscription/usage_tracking_service.py index de8e64f3..94efd731 100644 --- a/backend/services/subscription/usage_tracking_service.py +++ b/backend/services/subscription/usage_tracking_service.py @@ -502,7 +502,7 @@ class UsageTrackingService: return result async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]: - """Reset usage status for the current billing period (after plan change).""" + """Reset usage status and counters for the current billing period (after plan renewal/change).""" try: billing_period = datetime.now().strftime("%Y-%m") summary = self.db.query(UsageSummary).filter( @@ -514,11 +514,52 @@ class UsageTrackingService: # Nothing to reset return {"reset": False, "reason": "no_summary"} - # Clear LIMIT_REACHED so the user can resume; keep counters intact + # CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan + # Clear LIMIT_REACHED status summary.usage_status = UsageStatus.ACTIVE + + # Reset all LLM provider call counters + summary.gemini_calls = 0 + summary.openai_calls = 0 + summary.anthropic_calls = 0 + summary.mistral_calls = 0 + + # Reset all LLM provider token counters + summary.gemini_tokens = 0 + summary.openai_tokens = 0 + summary.anthropic_tokens = 0 + summary.mistral_tokens = 0 + + # Reset search/research provider counters + summary.tavily_calls = 0 + summary.serper_calls = 0 + summary.metaphor_calls = 0 + summary.firecrawl_calls = 0 + + # Reset image generation counters + summary.stability_calls = 0 + + # Reset cost counters + summary.gemini_cost = 0.0 + summary.openai_cost = 0.0 + summary.anthropic_cost = 0.0 + summary.mistral_cost = 0.0 + summary.tavily_cost = 0.0 + summary.serper_cost = 0.0 + summary.metaphor_cost = 0.0 + summary.firecrawl_cost = 0.0 + summary.stability_cost = 0.0 + + # Reset totals + summary.total_calls = 0 + summary.total_tokens = 0 + summary.total_cost = 0.0 + summary.updated_at = datetime.utcnow() self.db.commit() - return {"reset": True} + + logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal") + return {"reset": True, "counters_reset": True} except Exception as e: self.db.rollback() logger.error(f"Error resetting usage status: {e}") diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index f31761b8..53746948 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -58,19 +58,25 @@ const InitialRouteHandler: React.FC = () => { error: null, }); - // Check subscription on mount + // Check subscription on mount (non-blocking - don't wait for it to route) useEffect(() => { - checkSubscription().catch((err) => { - console.error('Error checking subscription:', err); - - // Check if it's a connection error - handle it locally - if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) { - setConnectionError({ - hasError: true, - error: err, - }); - } - }); + // Delay subscription check slightly to allow auth token getter to be installed first + const timeoutId = setTimeout(() => { + checkSubscription().catch((err) => { + console.error('Error checking subscription (non-blocking):', err); + + // Check if it's a connection error - handle it locally + if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) { + setConnectionError({ + hasError: true, + error: err, + }); + } + // Don't block routing on subscription check errors - allow graceful degradation + }); + }, 100); // Small delay to ensure TokenInstaller has run + + return () => clearTimeout(timeoutId); }, []); // Remove checkSubscription dependency to prevent loop // Initialize onboarding only after subscription is confirmed @@ -125,9 +131,10 @@ const InitialRouteHandler: React.FC = () => { ); } - // Loading state - ensure we wait for onboarding init after subscription is confirmed - const waitingForOnboardingInit = !!subscription && subscription.active && !subscriptionLoading && (loading || !data); - if (subscriptionLoading || loading || waitingForOnboardingInit) { + // Loading state - only wait for onboarding init, not subscription check + // Subscription check is non-blocking and happens in background + const waitingForOnboardingInit = loading || !data; + if (loading || waitingForOnboardingInit) { return ( { ); } - if (!subscription) { - return null; // Should not happen, but just in case + // Decision tree for SIGNED-IN users: + // Priority: Subscription → Onboarding → Dashboard (as per user flow: Landing → Subscription → Onboarding → Dashboard) + + // 1. If subscription is still loading, show loading state + if (subscriptionLoading) { + return ( + + + + Checking subscription... + + + ); } - // Decision tree for SIGNED-IN users: - // Priority: Subscription → Onboarding → Dashboard - - // Check if user is new (no subscription record at all) + // 2. No subscription data yet - handle gracefully + // If onboarding is complete, allow access to dashboard (user already went through flow) + // If onboarding not complete, check if subscription check is still loading or failed + if (!subscription) { + if (isOnboardingComplete) { + console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)'); + return ; + } + + // Onboarding not complete and no subscription data + // If subscription check is still loading, show loading state + if (subscriptionLoading) { + return ( + + + + Checking subscription... + + + ); + } + + // Subscription check completed but returned null/undefined + // This likely means no subscription - redirect to pricing + console.log('InitialRouteHandler: No subscription data after check → Pricing page'); + return ; + } + + // 3. Check subscription status first const isNewUser = !subscription || subscription.plan === 'none'; - // 1. No active subscription? → Must subscribe first (even if onboarding is complete) + // No active subscription → Must subscribe first if (isNewUser || !subscription.active) { console.log('InitialRouteHandler: No active subscription → Pricing page'); return ; } - // 2. Has active subscription, check onboarding status + // 4. Has active subscription, check onboarding status if (!isOnboardingComplete) { console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding'); return ; } - // 3. Has subscription AND completed onboarding → Dashboard + // 5. Has subscription AND completed onboarding → Dashboard console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard'); return ; }; diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index bd130ce8..0bde5b6d 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -7,6 +7,24 @@ export const setGlobalSubscriptionErrorHandler = (handler: (error: any) => boole globalSubscriptionErrorHandler = handler; }; +// Export a function to trigger subscription error handler from outside axios interceptors +export const triggerSubscriptionError = (error: any) => { + const status = error?.response?.status; + console.log('triggerSubscriptionError: Received error', { + hasHandler: !!globalSubscriptionErrorHandler, + status, + dataKeys: error?.response?.data ? Object.keys(error.response.data) : null + }); + + if (globalSubscriptionErrorHandler) { + console.log('triggerSubscriptionError: Calling global subscription error handler'); + return globalSubscriptionErrorHandler(error); + } + + console.warn('triggerSubscriptionError: No global subscription error handler registered'); + return false; +}; + // Optional token getter installed from within the app after Clerk is available let authTokenGetter: (() => Promise) | null = null; @@ -64,13 +82,27 @@ apiClient.interceptors.request.use( async (config) => { console.log(`Making ${config.method?.toUpperCase()} request to ${config.url}`); try { - const token = authTokenGetter ? await authTokenGetter() : null; + if (!authTokenGetter) { + console.warn(`[apiClient] ⚠️ authTokenGetter not set for ${config.url} - request may fail authentication`); + console.warn(`[apiClient] This usually means TokenInstaller hasn't run yet. Request will likely fail with 401.`); + } else { + try { + const token = await authTokenGetter(); if (token) { config.headers = config.headers || {}; (config.headers as any)['Authorization'] = `Bearer ${token}`; + console.log(`[apiClient] ✅ Added auth token to request: ${config.url}`); + } else { + console.warn(`[apiClient] ⚠️ authTokenGetter returned null for ${config.url} - user may not be signed in`); + console.warn(`[apiClient] User ID from localStorage: ${localStorage.getItem('user_id') || 'none'}`); + } + } catch (tokenError) { + console.error(`[apiClient] ❌ Error getting auth token for ${config.url}:`, tokenError); + } } } catch (e) { - // non-fatal + console.error(`[apiClient] ❌ Unexpected error in request interceptor for ${config.url}:`, e); + // non-fatal - let the request proceed, backend will return 401 if needed } return config; }, @@ -138,13 +170,17 @@ apiClient.interceptors.response.use( console.error('Token refresh failed:', retryError); } - // If retry failed and not in onboarding, redirect - const isOnboardingRoute = window.location.pathname.includes('/onboarding') || - window.location.pathname === '/'; - if (!isOnboardingRoute) { + // If retry failed, don't redirect during app initialization (root route) + // Only redirect if we're on a protected route and definitely authenticated + const isOnboardingRoute = window.location.pathname.includes('/onboarding'); + const isRootRoute = window.location.pathname === '/'; + + // Don't redirect from root route during app initialization - allow InitialRouteHandler to work + if (!isRootRoute && !isOnboardingRoute) { + // Only redirect if we're definitely not just initializing try { window.location.assign('/'); } catch {} } else { - console.warn('401 Unauthorized - token refresh failed'); + console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)'); } } @@ -204,12 +240,14 @@ aiApiClient.interceptors.response.use( console.error('Token refresh failed:', retryError); } - const isOnboardingRoute = window.location.pathname.includes('/onboarding') || - window.location.pathname === '/'; - if (!isOnboardingRoute) { + const isOnboardingRoute = window.location.pathname.includes('/onboarding'); + const isRootRoute = window.location.pathname === '/'; + + // Don't redirect from root route during app initialization + if (!isRootRoute && !isOnboardingRoute) { try { window.location.assign('/'); } catch {} } else { - console.warn('401 Unauthorized - token refresh failed'); + console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)'); } } @@ -254,13 +292,15 @@ longRunningApiClient.interceptors.response.use( }, (error) => { if (error?.response?.status === 401) { - // Only redirect on 401 if we're not in onboarding flow - const isOnboardingRoute = window.location.pathname.includes('/onboarding') || - window.location.pathname === '/'; - if (!isOnboardingRoute) { + // Only redirect on 401 if we're not in onboarding flow or root route + const isOnboardingRoute = window.location.pathname.includes('/onboarding'); + const isRootRoute = window.location.pathname === '/'; + + // Don't redirect from root route during app initialization + if (!isRootRoute && !isOnboardingRoute) { try { window.location.assign('/'); } catch {} } else { - console.warn('401 Unauthorized during onboarding - token may need refresh'); + console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)'); } } // Check if it's a subscription-related error and handle it globally @@ -304,13 +344,15 @@ pollingApiClient.interceptors.response.use( }, (error) => { if (error?.response?.status === 401) { - // Only redirect on 401 if we're not in onboarding flow - const isOnboardingRoute = window.location.pathname.includes('/onboarding') || - window.location.pathname === '/'; - if (!isOnboardingRoute) { + // Only redirect on 401 if we're not in onboarding flow or root route + const isOnboardingRoute = window.location.pathname.includes('/onboarding'); + const isRootRoute = window.location.pathname === '/'; + + // Don't redirect from root route during app initialization + if (!isRootRoute && !isOnboardingRoute) { try { window.location.assign('/'); } catch {} } else { - console.warn('401 Unauthorized during onboarding - token may need refresh'); + console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)'); } } // Check if it's a subscription-related error and handle it globally diff --git a/frontend/src/components/BlogWriter/BlogWriter.tsx b/frontend/src/components/BlogWriter/BlogWriter.tsx index 7b0af2e1..bf68d89d 100644 --- a/frontend/src/components/BlogWriter/BlogWriter.tsx +++ b/frontend/src/components/BlogWriter/BlogWriter.tsx @@ -66,6 +66,7 @@ export const BlogWriter: React.FC = () => { contentConfirmed, flowAnalysisCompleted, flowAnalysisResults, + sectionImages, setOutline, setTitleOptions, setSelectedTitle, @@ -78,6 +79,7 @@ export const BlogWriter: React.FC = () => { setContentConfirmed, setFlowAnalysisCompleted, setFlowAnalysisResults, + setSectionImages, handleResearchComplete, handleOutlineComplete, handleOutlineError, @@ -670,6 +672,8 @@ export const BlogWriter: React.FC = () => { flowAnalysisResults={flowAnalysisResults} outlineGenRef={outlineGenRef} blogWriterApi={blogWriterApi} + sectionImages={sectionImages} + setSectionImages={setSectionImages} contentConfirmed={contentConfirmed} seoAnalysis={seoAnalysis} seoMetadata={seoMetadata} diff --git a/frontend/src/components/BlogWriter/BlogWriterUtils/PhaseContent.tsx b/frontend/src/components/BlogWriter/BlogWriterUtils/PhaseContent.tsx index 4247d788..71b3c126 100644 --- a/frontend/src/components/BlogWriter/BlogWriterUtils/PhaseContent.tsx +++ b/frontend/src/components/BlogWriter/BlogWriterUtils/PhaseContent.tsx @@ -31,6 +31,8 @@ interface PhaseContentProps { seoMetadata: any; onTitleSelect: any; onCustomTitle: any; + sectionImages?: Record; + setSectionImages?: (images: Record | ((prev: Record) => Record)) => void; } export const PhaseContent: React.FC = ({ @@ -58,7 +60,9 @@ export const PhaseContent: React.FC = ({ seoAnalysis, seoMetadata, onTitleSelect, - onCustomTitle + onCustomTitle, + sectionImages, + setSectionImages }) => { return (
@@ -100,6 +104,8 @@ export const PhaseContent: React.FC = ({ optimizationResults={optimizationResults} researchCoverage={researchCoverage} onRefine={(op: any, id: any, payload: any) => blogWriterApi.refineOutline({ outline, operation: op, section_id: id, payload }).then((res: any) => setOutline(res.outline))} + sectionImages={sectionImages} + setSectionImages={setSectionImages} /> ) : ( @@ -126,6 +132,7 @@ export const PhaseContent: React.FC = ({ onSave={handleContentSave} continuityRefresh={continuityRefresh || undefined} flowAnalysisResults={flowAnalysisResults} + sectionImages={sectionImages} /> ) : (
@@ -151,6 +158,7 @@ export const PhaseContent: React.FC = ({ onSave={handleContentSave} continuityRefresh={continuityRefresh || undefined} flowAnalysisResults={flowAnalysisResults} + sectionImages={sectionImages} /> ) : (
diff --git a/frontend/src/components/BlogWriter/BlogWriterUtils/WixConnectModal.tsx b/frontend/src/components/BlogWriter/BlogWriterUtils/WixConnectModal.tsx new file mode 100644 index 00000000..adf7c225 --- /dev/null +++ b/frontend/src/components/BlogWriter/BlogWriterUtils/WixConnectModal.tsx @@ -0,0 +1,168 @@ +import React, { useState, useEffect } from 'react'; +import { + Dialog, + DialogTitle, + DialogContent, + DialogActions, + Button, + Typography, + Box, + CircularProgress, + Alert +} from '@mui/material'; +import { usePlatformConnections } from '../../../components/OnboardingWizard/common/usePlatformConnections'; + +interface WixConnectModalProps { + isOpen: boolean; + onClose: () => void; + onConnectionSuccess?: () => void; +} + +export const WixConnectModal: React.FC = ({ + isOpen, + onClose, + onConnectionSuccess +}) => { + const { handleConnect, isLoading } = usePlatformConnections(); + const [error, setError] = useState(null); + const [isConnecting, setIsConnecting] = useState(false); + + // Handle OAuth success via postMessage (same pattern as onboarding) + useEffect(() => { + if (!isOpen) return; + + const handler = (event: MessageEvent) => { + const trusted = [window.location.origin, 'https://littery-sonny-unscrutinisingly.ngrok-free.dev']; + if (!trusted.includes(event.origin)) return; + if (!event.data || typeof event.data !== 'object') return; + + if (event.data.type === 'WIX_OAUTH_SUCCESS') { + console.log('Wix OAuth success in modal'); + setIsConnecting(false); + setError(null); + // Close modal and notify parent + if (onConnectionSuccess) { + onConnectionSuccess(); + } + onClose(); + } + + if (event.data.type === 'WIX_OAUTH_ERROR') { + console.error('Wix OAuth error in modal:', event.data.error); + setIsConnecting(false); + setError(event.data.error || 'Wix connection failed. Please try again.'); + } + }; + + window.addEventListener('message', handler); + return () => window.removeEventListener('message', handler); + }, [isOpen, onClose, onConnectionSuccess]); + + // Also check for URL param (fallback for same-tab redirect) + useEffect(() => { + if (!isOpen) return; + + const params = new URLSearchParams(window.location.search); + if (params.get('wix_connected') === 'true') { + console.log('Wix connected via URL param in modal'); + setIsConnecting(false); + setError(null); + if (onConnectionSuccess) { + onConnectionSuccess(); + } + onClose(); + // Clean URL + const clean = window.location.pathname + window.location.hash; + window.history.replaceState({}, document.title, clean || '/'); + } + }, [isOpen, onClose, onConnectionSuccess]); + + const handleConnectClick = async () => { + try { + setIsConnecting(true); + setError(null); + await handleConnect('wix'); + // OAuth will redirect, so we don't need to do anything else here + // The postMessage handler or URL param handler will close the modal + } catch (err: any) { + console.error('Error connecting to Wix:', err); + setIsConnecting(false); + setError(err?.message || 'Failed to start Wix connection. Please try again.'); + } + }; + + return ( + + + + Connect Your Wix Account + + + + + + + Connect your Wix account to publish blog posts directly to your website. + + + {error && ( + + {error} + + )} + + {isConnecting && ( + + + + Opening Wix authorization page... + + + )} + + + + What happens next: + + +
    +
  1. You'll be redirected to Wix to authorize ALwrity
  2. +
  3. Grant permissions for blog creation and publishing
  4. +
  5. You'll be redirected back to ALwrity
  6. +
  7. Your blog post will be published automatically
  8. +
+
+
+
+
+ + + + + +
+ ); +}; + +export default WixConnectModal; + diff --git a/frontend/src/components/BlogWriter/EnhancedOutlineEditor.tsx b/frontend/src/components/BlogWriter/EnhancedOutlineEditor.tsx index 5ff66b4b..aa20e22e 100644 --- a/frontend/src/components/BlogWriter/EnhancedOutlineEditor.tsx +++ b/frontend/src/components/BlogWriter/EnhancedOutlineEditor.tsx @@ -12,6 +12,8 @@ interface Props { groundingInsights?: GroundingInsights | null; optimizationResults?: OptimizationResults | null; researchCoverage?: ResearchCoverage | null; + sectionImages?: Record; + setSectionImages?: (images: Record | ((prev: Record) => Record)) => void; } const EnhancedOutlineEditor: React.FC = ({ @@ -21,14 +23,15 @@ const EnhancedOutlineEditor: React.FC = ({ sourceMappingStats, groundingInsights, optimizationResults, - researchCoverage + researchCoverage, + sectionImages = {}, + setSectionImages }) => { const [editingSection, setEditingSection] = useState(null); const [expandedSections, setExpandedSections] = useState>(new Set()); const [hoveredSection, setHoveredSection] = useState(null); const [showAddSection, setShowAddSection] = useState(false); const [imageModalState, setImageModalState] = useState<{ open: boolean; sectionId?: string }>(() => ({ open: false })); - const [sectionImages, setSectionImages] = useState>({}); const [newSectionData, setNewSectionData] = useState({ heading: '', subheadings: '', @@ -117,8 +120,8 @@ const EnhancedOutlineEditor: React.FC = ({ }; })()} onImageGenerated={(imageBase64, sectionId) => { - if (sectionId) { - setSectionImages(prev => ({ ...prev, [sectionId]: imageBase64 })); + if (sectionId && setSectionImages) { + setSectionImages((prev: Record) => ({ ...prev, [sectionId]: imageBase64 })); } }} /> diff --git a/frontend/src/components/BlogWriter/Publisher.tsx b/frontend/src/components/BlogWriter/Publisher.tsx index d2b0b59c..f54a541b 100644 --- a/frontend/src/components/BlogWriter/Publisher.tsx +++ b/frontend/src/components/BlogWriter/Publisher.tsx @@ -1,7 +1,10 @@ import React, { useState, useEffect } from 'react'; import { useCopilotAction } from '@copilotkit/react-core'; -import { blogWriterApi, BlogSEOMetadataResponse } from '../../services/blogWriterApi'; +import { BlogSEOMetadataResponse } from '../../services/blogWriterApi'; import { apiClient } from '../../api/client'; +import { wordpressAPI, WordPressSite, WordPressPublishRequest } from '../../api/wordpress'; +import { validateAndRefreshWixTokens } from '../../utils/wixTokenUtils'; +import WixConnectModal from './BlogWriterUtils/WixConnectModal'; interface PublisherProps { buildFullMarkdown: () => string; @@ -26,10 +29,15 @@ export const Publisher: React.FC = ({ }) => { const [wixConnectionStatus, setWixConnectionStatus] = useState(null); const [checkingWixStatus, setCheckingWixStatus] = useState(false); + const [wordpressSites, setWordpressSites] = useState([]); + const [checkingWordPressStatus, setCheckingWordPressStatus] = useState(false); + const [showWixConnectModal, setShowWixConnectModal] = useState(false); + const [pendingWixPublish, setPendingWixPublish] = useState<(() => Promise) | null>(null); - // Check Wix connection status on component mount + // Check platform connection statuses on component mount useEffect(() => { checkWixConnectionStatus(); + checkWordPressConnectionStatus(); }, []); const checkWixConnectionStatus = async () => { @@ -48,6 +56,137 @@ export const Publisher: React.FC = ({ setCheckingWixStatus(false); } }; + + const checkWordPressConnectionStatus = async () => { + setCheckingWordPressStatus(true); + try { + const status = await wordpressAPI.getStatus(); + setWordpressSites(status.sites || []); + } catch (error) { + console.error('Failed to check WordPress connection status:', error); + setWordpressSites([]); + } finally { + setCheckingWordPressStatus(false); + } + }; + + // Helper function to publish to Wix + const publishToWix = async (md: string, metadata: BlogSEOMetadataResponse | null, accessToken?: string): Promise => { + // Get access token if not provided + if (!accessToken) { + const tokenResult = await validateAndRefreshWixTokens(); + if (!tokenResult.accessToken) { + return { + success: false, + message: 'Wix tokens not available. Please connect your Wix account.', + action_required: 'connect_wix' + }; + } + accessToken = tokenResult.accessToken; + } + + // Extract title from SEO metadata or markdown + const title = metadata?.seo_title || (() => { + const titleMatch = md.match(/^#\s+(.+)$/m); + return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity'; + })(); + + // Extract cover image URL, skip if base64 (Wix needs HTTP URL) + let coverImageUrl: string | undefined = undefined; + if (metadata?.open_graph?.image) { + const imageUrl = metadata.open_graph.image; + // Skip base64 images - Wix import_image needs HTTP/HTTPS URL + if (typeof imageUrl === 'string' && (imageUrl.startsWith('http://') || imageUrl.startsWith('https://'))) { + coverImageUrl = imageUrl; + } else { + console.warn('Skipping cover image - Wix requires HTTP/HTTPS URL, received:', imageUrl?.substring(0, 50)); + } + } + + try { + // Publish using same endpoint as WixTestPage + // Note: Wix requires category/tag IDs (UUIDs), not names + // For now, skip categories/tags until we implement ID lookup/creation + const response = await apiClient.post('/api/wix/test/publish/real', { + title: title, + content: md, // Use markdown, backend converts it + cover_image_url: coverImageUrl, + // TODO: Lookup/create category IDs from metadata?.blog_categories + // TODO: Lookup/create tag IDs from metadata?.blog_tags + category_ids: undefined, + tag_ids: undefined, + publish: true, + access_token: accessToken, + member_id: undefined // Let backend derive from token + }); + + if (response.data.success) { + return { + success: true, + url: response.data.url, + post_id: response.data.post_id, + message: 'Blog post published successfully to Wix!' + }; + } else { + return { + success: false, + message: response.data.error || 'Failed to publish to Wix' + }; + } + } catch (error: any) { + // If auth error, token may be invalid - try refreshing or reconnect + if (error.response?.status === 401 || error.response?.status === 403) { + // Try to refresh one more time + const tokenResult = await validateAndRefreshWixTokens(); + if (tokenResult.needsReconnect) { + const publishFunction = async () => { + return await publishToWix(md, metadata); + }; + setPendingWixPublish(() => publishFunction); + setShowWixConnectModal(true); + return { + success: false, + message: 'Wix tokens expired. Please reconnect your Wix account.', + action_required: 'reconnect_wix' + }; + } + // If refresh worked, retry once + if (tokenResult.accessToken) { + return await publishToWix(md, metadata, tokenResult.accessToken); + } + } + + return { + success: false, + message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}` + }; + } + }; + + // Handle Wix connection success - retry publish + const handleWixConnectionSuccess = async () => { + if (pendingWixPublish) { + const publishFn = pendingWixPublish; + setPendingWixPublish(null); + // Small delay to ensure tokens are saved in sessionStorage + setTimeout(async () => { + try { + // Retry the publish - this will be executed and return result + // Note: The result won't show in CopilotKit UI since we're outside the action handler + // But the publish will succeed and user will see their blog on Wix + const result = await publishFn(); + console.log('Wix publish after connection:', result); + // Optionally show a success notification + if (result.success) { + // Publish succeeded - user's blog is now on Wix + console.log('Blog published to Wix successfully after connection'); + } + } catch (error) { + console.error('Error retrying publish after connection:', error); + } + }, 500); + } + }; // Enhanced publish action with Wix support useCopilotActionTyped({ name: 'publishToPlatform', @@ -61,58 +200,101 @@ export const Publisher: React.FC = ({ const html = convertMarkdownToHTML(md); if (platform === 'wix') { - // Check Wix connection status first - if (!wixConnectionStatus?.connected) { - return { - success: false, - message: 'Wix account not connected. Please connect your Wix account first using the Wix Test Page.', + // Proactively validate and refresh tokens + const tokenResult = await validateAndRefreshWixTokens(); + + if (tokenResult.needsReconnect || !tokenResult.accessToken) { + // Store the publish function to retry after connection + const publishFunction = async () => { + return await publishToWix(md, seoMetadata); + }; + setPendingWixPublish(() => publishFunction); + setShowWixConnectModal(true); + return { + success: false, + message: 'Wix account not connected. Please connect your Wix account to publish.', action_required: 'connect_wix' }; } - - if (!wixConnectionStatus?.has_permissions) { + + // We have a valid access token, proceed with publishing + return await publishToWix(md, seoMetadata, tokenResult.accessToken); + } else if (platform === 'wordpress') { + // WordPress publishing + if (!seoMetadata) { return { success: false, - message: 'Insufficient Wix permissions. Please reconnect your Wix account.', - action_required: 'reconnect_wix' + message: 'Generate SEO metadata first. Use the "Next: Generate SEO Metadata" suggestion to create metadata before publishing.' }; } - - // Extract title from markdown (first heading or use default) - const titleMatch = md.match(/^#\s+(.+)$/m); - const title = titleMatch ? titleMatch[1] : 'Blog Post from ALwrity'; - + + // Check if user has connected WordPress sites + if (wordpressSites.length === 0) { + return { + success: false, + message: 'No WordPress sites connected. Please connect a WordPress site first. Go to Settings > Integrations to add your WordPress site.', + action_required: 'connect_wordpress' + }; + } + + // Find first active site, or use first site if none are active + const activeSite = wordpressSites.find(site => site.is_active) || wordpressSites[0]; + if (!activeSite) { + return { + success: false, + message: 'No active WordPress sites found. Please activate a WordPress site connection.', + action_required: 'activate_wordpress' + }; + } + + // Extract title from SEO metadata or markdown + const title = seoMetadata.seo_title || (() => { + const titleMatch = md.match(/^#\s+(.+)$/m); + return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity'; + })(); + + // Extract excerpt from SEO metadata + const excerpt = seoMetadata.meta_description || ''; + + // Build WordPress publish request + const publishRequest: WordPressPublishRequest = { + site_id: activeSite.id, + title: title, + content: html, + excerpt: excerpt, + status: 'publish', + meta_description: seoMetadata.meta_description || excerpt, + tags: seoMetadata.blog_tags || [], + categories: seoMetadata.blog_categories || [] + }; + try { - const response = await apiClient.post('/api/wix/publish', { - title: title, - content: md, - publish: true - }); + const result = await wordpressAPI.publishContent(publishRequest); - if (response.data.success) { - return { - success: true, - url: response.data.url, - post_id: response.data.post_id, - message: 'Blog post published successfully to Wix!' + if (result.success) { + return { + success: true, + url: result.post_url || `${activeSite.site_url}/?p=${result.post_id}`, + post_id: result.post_id, + message: `Blog post published successfully to WordPress site "${activeSite.site_name}"!` }; } else { - return { - success: false, - message: response.data.error || 'Failed to publish to Wix' + return { + success: false, + message: result.error || 'Failed to publish to WordPress' }; } } catch (error: any) { - return { - success: false, - message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}` + return { + success: false, + message: `Failed to publish to WordPress: ${error.response?.data?.detail || error.message || 'Unknown error'}` }; } } else { - // WordPress or other platforms - if (!seoMetadata) return { success: false, message: 'Generate SEO metadata first' }; - const res = await blogWriterApi.publish({ platform, html, metadata: seoMetadata, schedule_time }); - return { success: true, url: res.url }; + return { + success: false, + message: `Unsupported platform: ${platform}. Supported platforms are 'wix' and 'wordpress'.` + }; } }, render: ({ status, result }: any) => { @@ -153,6 +335,13 @@ export const Publisher: React.FC = ({
)} + {(result?.action_required === 'connect_wordpress' || result?.action_required === 'activate_wordpress') && ( + + )}
); } @@ -161,7 +350,18 @@ export const Publisher: React.FC = ({ } }); - return null; // This component only provides the copilot action + return ( + <> + { + setShowWixConnectModal(false); + setPendingWixPublish(null); + }} + onConnectionSuccess={handleWixConnectionSuccess} + /> + + ); }; export default Publisher; diff --git a/frontend/src/components/BlogWriter/SuggestionsGenerator.tsx b/frontend/src/components/BlogWriter/SuggestionsGenerator.tsx index a99ca3f7..873e3a62 100644 --- a/frontend/src/components/BlogWriter/SuggestionsGenerator.tsx +++ b/frontend/src/components/BlogWriter/SuggestionsGenerator.tsx @@ -145,11 +145,7 @@ export const useSuggestions = ({ priority: 'high' }); items.push({ - title: 'Content Analysis', - message: 'Analyze the flow and quality of my blog content to get improvement suggestions' - }); - items.push({ - title: 'Content Analysis', + title: '📊 Content Analysis', message: 'Analyze the flow and quality of my blog content to get improvement suggestions' }); } else if (seoAnalysis && !seoRecommendationsApplied) { @@ -160,7 +156,7 @@ export const useSuggestions = ({ priority: 'high' }); items.push({ - title: 'Content Analysis', + title: '📊 Content Analysis', message: 'Run analyzeContentQuality to review narrative flow and get final improvement suggestions before publishing.' }); items.push({ @@ -175,33 +171,21 @@ export const useSuggestions = ({ message: 'SEO recommendations are applied. Execute generateSEOMetadata immediately so we can prepare titles, descriptions, and schema without further prompts.', priority: 'high' }); - } else { items.push({ - title: 'Next: Publish', - message: 'The blog is SEO-optimized. Use publishToPlatform with your preferred destination (wix|wordpress) right away—no additional confirmation needed.', - priority: 'high' + title: '📊 Content Analysis', + message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.' }); - } - - items.push({ - title: 'Content Analysis', - message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.' - }); - items.push({ - title: 'Publish', - message: seoMetadata - ? 'Publish my blog to your preferred platform using publishToPlatform.' - : 'Generate SEO metadata first, then publish your blog.' - }); - - if (seoMetadata) { + } else { + // SEO metadata is ready - show publishing options items.push({ title: '🚀 Publish to Wix', - message: 'Publish my blog to Wix using publishToPlatform with platform "wix".' + message: 'Publish my blog to Wix using publishToPlatform with platform "wix".', + priority: 'high' }); items.push({ title: '🌐 Publish to WordPress', - message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".' + message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".', + priority: 'high' }); } } diff --git a/frontend/src/components/BlogWriter/WYSIWYG/BlogEditor.tsx b/frontend/src/components/BlogWriter/WYSIWYG/BlogEditor.tsx index df178d05..71bfa2d1 100644 --- a/frontend/src/components/BlogWriter/WYSIWYG/BlogEditor.tsx +++ b/frontend/src/components/BlogWriter/WYSIWYG/BlogEditor.tsx @@ -30,6 +30,7 @@ interface BlogEditorProps { onSave?: (content: any) => void; continuityRefresh?: number; flowAnalysisResults?: any; + sectionImages?: Record; } const BlogEditor: React.FC = ({ @@ -43,7 +44,8 @@ const BlogEditor: React.FC = ({ onContentUpdate, onSave, continuityRefresh, - flowAnalysisResults + flowAnalysisResults, + sectionImages = {} }) => { const [blogTitle, setBlogTitle] = useState(initialTitle || 'Your Amazing Blog Title'); const [sections, setSections] = useState([]); @@ -143,17 +145,25 @@ const BlogEditor: React.FC = ({
- {sections.map((section) => ( - - ))} + {sections.map((section, index) => { + // Robust image mapping: prefer outline index id (order is consistent across phases) + const imageIdByIndex = outline[index]?.id; + const outlineSection = outline.find(s => (s.id === section.id) || (s.heading === section.title)); + const imageId = imageIdByIndex || outlineSection?.id || section.id; + const sectionImage = sectionImages?.[imageId] || null; + return ( + + ); + })}
diff --git a/frontend/src/components/BlogWriter/WYSIWYG/BlogSection.tsx b/frontend/src/components/BlogWriter/WYSIWYG/BlogSection.tsx index e0fe8b51..a1c8b1e7 100644 --- a/frontend/src/components/BlogWriter/WYSIWYG/BlogSection.tsx +++ b/frontend/src/components/BlogWriter/WYSIWYG/BlogSection.tsx @@ -40,6 +40,7 @@ interface BlogSectionProps { toggleSectionExpansion: (sectionId: any) => void; refreshToken?: number; flowAnalysisResults?: any; + sectionImage?: string; } const BlogSection: React.FC = ({ @@ -53,7 +54,8 @@ const BlogSection: React.FC = ({ expandedSections, toggleSectionExpansion, refreshToken, - flowAnalysisResults + flowAnalysisResults, + sectionImage }) => { const [isEditing, setIsEditing] = useState(false); const [sectionTitle, setSectionTitle] = useState(title); @@ -181,6 +183,31 @@ const BlogSection: React.FC = ({ )} + + {/* Section Image Display */} + {sectionImage && ( +
+
+ {`Cover +
+
+ )}
= ({ className const fetchDetailedStats = async () => { try { const response = await apiClient.get('/api/content-planning/monitoring/api-stats'); - const result = response.data; - if (result.status === 'success') { - setDetailedStats(result.data); - if (result.data?.cache_performance) { - setCachePerf(result.data.cache_performance); + const result = response?.data; + + // Validate response structure + if (!result || result.status !== 'success' || !result.data) { + console.warn('Invalid response structure from api-stats endpoint:', result); + setChartData([]); + return; + } + + const data = result.data; + setDetailedStats(data); + + if (data?.cache_performance) { + setCachePerf(data.cache_performance); } - // Generate chart data - const chartData = result.data.top_endpoints.slice(0, 5).map((endpoint: any, index: number) => ({ - name: endpoint.endpoint.split(' ')[1].split('/').pop() || 'API', - requests: endpoint.count, - avgTime: endpoint.avg_time, - errors: endpoint.errors, - hitRate: endpoint.cache_hit_rate + // Generate chart data - safely handle missing top_endpoints + if (data?.top_endpoints && Array.isArray(data.top_endpoints) && data.top_endpoints.length > 0) { + try { + const chartData = data.top_endpoints.slice(0, 5).map((endpoint: any) => ({ + name: endpoint?.endpoint?.split(' ')[1]?.split('/').pop() || 'API', + requests: endpoint?.count || 0, + avgTime: endpoint?.avg_time || 0, + errors: endpoint?.errors || 0, + hitRate: endpoint?.cache_hit_rate || 0 })); setChartData(chartData); + } catch (mapError) { + console.error('Error mapping chart data:', mapError); + setChartData([]); + } + } else { + // If top_endpoints is missing or not an array, set empty chart data + setChartData([]); } } catch (err) { console.error('Error fetching detailed stats:', err); + setChartData([]); } }; @@ -353,7 +372,7 @@ const SystemStatusIndicator: React.FC = ({ className )} {/* Recent Errors Section */} - {detailedStats?.recent_errors && detailedStats.recent_errors.length > 0 && ( + {detailedStats?.recent_errors && Array.isArray(detailedStats.recent_errors) && detailedStats.recent_errors.length > 0 && ( = ({ className > Close + + + + diff --git a/frontend/src/components/ImageGen/useImageGeneration.ts b/frontend/src/components/ImageGen/useImageGeneration.ts index ac0ace7f..e9302ccb 100644 --- a/frontend/src/components/ImageGen/useImageGeneration.ts +++ b/frontend/src/components/ImageGen/useImageGeneration.ts @@ -56,18 +56,10 @@ export interface PromptSuggestion { } export async function fetchPromptSuggestions(payload: any): Promise { - const res = await fetch('/api/images/suggest-prompts', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - credentials: 'include', - body: JSON.stringify(payload) - }); - if (!res.ok) { - const text = await res.text(); - throw new Error(text || 'Failed to fetch prompt suggestions'); - } - const data = await res.json(); - return data.suggestions || []; + // Use apiClient directly (same pattern as SEO analysis in SEOAnalysisModal.tsx) + // The apiClient interceptor will handle auth token injection automatically + const response = await apiClient.post('/api/images/suggest-prompts', payload); + return response.data.suggestions || []; } diff --git a/frontend/src/components/Pricing/PricingPage.tsx b/frontend/src/components/Pricing/PricingPage.tsx index 302314f0..895157c3 100644 --- a/frontend/src/components/Pricing/PricingPage.tsx +++ b/frontend/src/components/Pricing/PricingPage.tsx @@ -28,6 +28,7 @@ import { Modal, Fade, Backdrop, + Snackbar, } from '@mui/material'; import { Check as CheckIcon, @@ -35,6 +36,7 @@ import { Star as StarIcon, WorkspacePremium as PremiumIcon, Info as InfoIcon, + Warning, Psychology, Search, FactCheck, @@ -83,6 +85,7 @@ const PricingPage: React.FC = () => { const [subscribing, setSubscribing] = useState(false); const [paymentModalOpen, setPaymentModalOpen] = useState(false); const [showSignInPrompt, setShowSignInPrompt] = useState(false); + const [successSnackbar, setSuccessSnackbar] = useState({ open: false, message: '', countdown: 3 }); const [knowMoreModal, setKnowMoreModal] = useState<{ open: boolean; title: string; content: React.ReactNode }>({ open: false, title: '', @@ -172,27 +175,70 @@ const PricingPage: React.FC = () => { setSubscribing(true); const userId = localStorage.getItem('user_id') || 'anonymous'; - await apiClient.post(`/api/subscription/subscribe/${userId}`, { + const response = await apiClient.post(`/api/subscription/subscribe/${userId}`, { plan_id: selectedPlan, billing_cycle: yearlyBilling ? 'yearly' : 'monthly' }); - // Refresh subscription status + console.log('Subscription renewed successfully:', response.data); + + // Refresh subscription status immediately window.dispatchEvent(new CustomEvent('subscription-updated')); + + // Also trigger user authenticated event to refresh subscription context + window.dispatchEvent(new CustomEvent('user-authenticated')); setPaymentModalOpen(false); - // After subscription, check if onboarding is complete - // If not complete, redirect to onboarding; otherwise to dashboard - const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true'; - if (onboardingComplete) { - navigate('/dashboard'); - } else { - navigate('/onboarding'); - } + // Get plan name for success message + const planName = plans.find(p => p.id === selectedPlan)?.name || 'subscription'; + + // Show success message with countdown + setSuccessSnackbar({ + open: true, + message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in 3 seconds...`, + countdown: 3 + }); + + // Countdown timer + let countdown = 3; + const countdownInterval = setInterval(() => { + countdown -= 1; + if (countdown > 0) { + setSuccessSnackbar(prev => ({ + ...prev, + message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in ${countdown} second${countdown !== 1 ? 's' : ''}...`, + countdown + })); + } else { + clearInterval(countdownInterval); + } + }, 1000); + + // Auto-redirect after 3 seconds + setTimeout(() => { + clearInterval(countdownInterval); + + // After subscription, check if onboarding is complete + // If not complete, redirect to onboarding; otherwise to dashboard + const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true'; + if (onboardingComplete) { + // Try to go back to where the user was (e.g., blog writer) + // If no history, go to dashboard + const referrer = sessionStorage.getItem('subscription_referrer'); + if (referrer && referrer !== '/pricing') { + navigate(referrer); + } else { + navigate('/dashboard'); + } + } else { + navigate('/onboarding'); + } + }, 3000); } catch (err) { console.error('Error subscribing:', err); setError('Failed to process subscription'); + setSuccessSnackbar({ open: false, message: '', countdown: 0 }); } finally { setSubscribing(false); } @@ -900,32 +946,71 @@ const PricingPage: React.FC = () => { top: '50%', left: '50%', transform: 'translate(-50%, -50%)', - width: 400, + width: 450, bgcolor: 'background.paper', border: '2px solid #000', boxShadow: 24, p: 4, borderRadius: 2, }}> - + + Alpha Testing Subscription - - Thank you for participating in our alpha testing! For the Basic plan, we're crediting $29 to your account. - - - In production, this would integrate with Stripe/Paddle for real payment processing. + + {/* Alpha Testing Notice */} + + + ⚠️ Alpha Testing Mode - No Payment Required + + + Payment integration is coming soon. For now, subscriptions are activated without charge. + + + + + Thank you for participating in our alpha testing! We're crediting the Basic plan ($29 value) to your account. + + {/* TODO: Payment Integration Notice */} + + + Coming in Production: + + + • Secure Stripe/PayPal payment processing
+ • Automatic renewal management
+ • Payment verification & receipts
+ • Upgrade/downgrade options +
+
+ + {/* Note: Current behavior allows renewal without payment verification */} + {/* This is intentional for alpha testing but will be secured in production */} + - @@ -981,6 +1066,37 @@ const PricingPage: React.FC = () => { + + {/* Success Snackbar */} + setSuccessSnackbar({ open: false, message: '', countdown: 0 })} + anchorOrigin={{ vertical: 'top', horizontal: 'center' }} + sx={{ + top: { xs: 16, sm: 24 }, + '& .MuiSnackbarContent-root': { + minWidth: { xs: '90vw', sm: '500px' } + } + }} + > + setSuccessSnackbar({ open: false, message: '', countdown: 0 })} + sx={{ + width: '100%', + fontSize: '1rem', + alignItems: 'center', + boxShadow: '0 8px 24px rgba(76, 175, 80, 0.4)', + '& .MuiAlert-icon': { + fontSize: '2rem' + } + }} + > + {successSnackbar.message} + + ); }; diff --git a/frontend/src/components/SubscriptionExpiredModal.tsx b/frontend/src/components/SubscriptionExpiredModal.tsx index baf49031..620b7064 100644 --- a/frontend/src/components/SubscriptionExpiredModal.tsx +++ b/frontend/src/components/SubscriptionExpiredModal.tsx @@ -39,6 +39,25 @@ const SubscriptionExpiredModal: React.FC = ({ subscriptionData, errorData }) => { + // Debug logging to verify modal state + React.useEffect(() => { + if (open) { + console.log('SubscriptionExpiredModal: Modal opened', { + open, + errorData, + hasUsageInfo: !!errorData?.usage_info + }); + } + }, [open, errorData]); + + const handleDialogClose = (_event: object, reason?: string) => { + if (reason === 'backdropClick') { + console.log('SubscriptionExpiredModal: Ignoring backdrop click close'); + return; + } + onClose(); + }; + const handleRenewClick = () => { onRenewSubscription(); onClose(); @@ -47,16 +66,21 @@ const SubscriptionExpiredModal: React.FC = ({ return ( @@ -93,56 +117,156 @@ const SubscriptionExpiredModal: React.FC = ({ borderRadius: 2 }} > - + {/* Main error message */} + {errorData?.message || (errorData?.usage_info ? 'You\'ve reached your monthly usage limit for this plan. Upgrade your plan to get higher limits.' : 'To continue using Alwrity and access all features, you need to renew your subscription.' )} + {/* Detailed usage information */} {errorData?.usage_info && ( - - + + + Usage Information: - {errorData.usage_info.call_usage_percentage && ( - - You've used {errorData.usage_info.call_usage_percentage.toFixed(1)}% of your monthly limit - + + {/* Provider and operation type */} + + {errorData.provider && ( + + + Provider: + + + {errorData.provider} + + + )} + + {errorData.usage_info.operation_type && ( + + + Operation: + + + {errorData.usage_info.operation_type.replace(/_/g, ' ')} + + + )} + + + {/* Token usage details (if available) */} + {(errorData.usage_info.current_tokens !== undefined || errorData.usage_info.current_calls !== undefined) && ( + + {errorData.usage_info.current_tokens !== undefined && ( + <> + + Token Usage: + + + + {errorData.usage_info.current_tokens?.toLocaleString() || 0} + + + / {errorData.usage_info.limit?.toLocaleString() || 0} + + + ({((errorData.usage_info.current_tokens / errorData.usage_info.limit) * 100).toFixed(1)}% used) + + + + {errorData.usage_info.requested_tokens && ( + + Requested: {errorData.usage_info.requested_tokens.toLocaleString()} tokens + {errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens > errorData.usage_info.limit && ( + + (Would exceed by: {((errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens) - errorData.usage_info.limit).toLocaleString()} tokens) + + )} + + )} + + )} + + {errorData.usage_info.current_calls !== undefined && ( + <> + + API Call Usage: + + + + {errorData.usage_info.current_calls?.toLocaleString() || 0} + + + / {errorData.usage_info.call_limit?.toLocaleString() || 0} + + + ({((errorData.usage_info.current_calls / errorData.usage_info.call_limit) * 100).toFixed(1)}% used) + + + + )} + )} - {errorData.provider && ( - - Provider: {errorData.provider} - + + {/* Error type badge */} + {errorData.usage_info.error_type && ( + + + + {errorData.usage_info.error_type.replace(/_/g, ' ')} + + + )} )} + {/* Current plan information */} {subscriptionData && ( {subscriptionData.plan && ( - - Current Plan: {subscriptionData.plan} + + Current Plan: - - )} - {subscriptionData.tier && subscriptionData.tier !== subscriptionData.plan && ( - - - Tier: {subscriptionData.tier} + + {subscriptionData.plan} )} diff --git a/frontend/src/components/shared/DashboardHeader.tsx b/frontend/src/components/shared/DashboardHeader.tsx index c45beaa1..d3cc7e04 100644 --- a/frontend/src/components/shared/DashboardHeader.tsx +++ b/frontend/src/components/shared/DashboardHeader.tsx @@ -105,12 +105,13 @@ const DashboardHeader: React.FC = ({ /* Enhanced Start Button with Phase 1 Improvements */ - + {isFirstVisit ? '🚀 Start Journey' : 'Start'} + + = ({ chil // New: Grace window after plan changes to avoid noisy UX const [graceUntil, setGraceUntil] = useState(0); const [planSignature, setPlanSignature] = useState(""); + // Flag to track if current modal is a usage limit modal (should never be auto-closed) + const [isUsageLimitModal, setIsUsageLimitModal] = useState(false); const checkSubscription = useCallback(async () => { // Throttle subscription checks to prevent excessive API calls @@ -86,6 +88,10 @@ export const SubscriptionProvider: React.FC = ({ chil return; } + // Wait a moment to ensure auth token getter is installed + // This prevents 401 errors during app initialization + await new Promise(resolve => setTimeout(resolve, 200)); + console.log('SubscriptionContext: Checking subscription for user:', userId); const response = await apiClient.get(`/api/subscription/status/${userId}`); const subscriptionData = response.data.data; @@ -101,29 +107,42 @@ export const SubscriptionProvider: React.FC = ({ chil setPlanSignature(newSignature); setGraceUntil(Date.now() + 5 * 60 * 1000); // Close any existing modal as plan just changed - if (showModal) { + // BUT: Don't close usage limit modals - they're important even after plan changes + if (showModal && !isUsageLimitModal) { + console.log('SubscriptionContext: Plan changed, closing non-usage-limit modal'); setShowModal(false); setModalErrorData(null); + } else if (showModal && isUsageLimitModal) { + console.log('SubscriptionContext: Plan changed but usage limit modal is open, keeping it open'); } } } catch (_e) {} // If we have a valid subscription and the modal is open, close it + // BUT: NEVER close usage limit modals - user needs to see they hit a limit even with active subscription if (subscriptionData && subscriptionData.active && showModal) { - console.log('SubscriptionContext: Valid subscription detected, closing modal'); - setShowModal(false); - setModalErrorData(null); - setLastModalShowTime(0); // Reset the cooldown timer - } - - // Also check if this is a usage limit error that should be suppressed - if (subscriptionData && subscriptionData.active && modalErrorData) { - const now = Date.now(); - const timeSinceLastModal = now - lastModalShowTime; - - // If it's been less than 10 minutes since modal was shown for usage limits, keep it closed - if (timeSinceLastModal < 600000 && modalErrorData.usage_info) { - console.log('SubscriptionContext: Recent usage limit modal, keeping it closed'); + // Check if this is a usage limit modal (using flag or checking error data) + const hasUsageInfo = modalErrorData?.usage_info || + (modalErrorData?.current_tokens !== undefined) || + (modalErrorData?.current_calls !== undefined) || + (modalErrorData?.limit !== undefined) || + (modalErrorData?.requested_tokens !== undefined); + + const isUsageLimit = isUsageLimitModal || hasUsageInfo; + + if (isUsageLimit) { + console.log('SubscriptionContext: Usage limit modal detected - KEEPING OPEN (never auto-close usage limit modals)', { + isUsageLimitModal, + hasUsageInfo, + modalErrorDataKeys: modalErrorData ? Object.keys(modalErrorData) : [] + }); + // Do NOT close - usage limit modals should stay open until user dismisses them + } else { + console.log('SubscriptionContext: Non-usage-limit modal detected, closing since subscription is active'); + setShowModal(false); + setModalErrorData(null); + setIsUsageLimitModal(false); + setLastModalShowTime(0); // Reset the cooldown timer } } @@ -156,7 +175,7 @@ export const SubscriptionProvider: React.FC = ({ chil setLastModalShowTime(now); } } - } catch (err) { + } catch (err: any) { console.error('Error checking subscription:', err); // Check if it's a connection error that should be handled at the app level @@ -165,6 +184,16 @@ export const SubscriptionProvider: React.FC = ({ chil throw err; } + // Handle 401 errors gracefully during initialization - don't block routing + // 401 might happen if auth token getter isn't ready yet + if (err?.response?.status === 401) { + console.warn('Subscription check failed with 401 - auth may not be ready yet, will retry later'); + setError(null); // Don't set error for 401 during init + setLoading(false); + // Don't throw - allow routing to proceed, subscription check will retry later + return; + } + setError(err instanceof Error ? err.message : 'Failed to check subscription'); // Don't default to free tier on error - preserve existing subscription or leave null @@ -173,21 +202,30 @@ export const SubscriptionProvider: React.FC = ({ chil } finally { setLoading(false); } - }, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil]); + }, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil, isUsageLimitModal]); const refreshSubscription = useCallback(async () => { await checkSubscription(); }, [checkSubscription]); const showExpiredModal = useCallback(() => { + setIsUsageLimitModal(false); setShowModal(true); }, []); const hideExpiredModal = useCallback(() => { + console.log('SubscriptionExpiredModal: User manually closed modal'); setShowModal(false); + setIsUsageLimitModal(false); // Reset flag when user closes modal + setModalErrorData(null); }, []); const handleRenewSubscription = useCallback(() => { + // Save current location so we can return after renewal + const currentPath = window.location.pathname; + sessionStorage.setItem('subscription_referrer', currentPath); + + console.log('SubscriptionContext: Navigating to pricing page, saved referrer:', currentPath); window.location.href = '/pricing'; }, []); @@ -203,42 +241,131 @@ export const SubscriptionProvider: React.FC = ({ chil const now = Date.now(); - // If we have subscription data and it's active, always suppress modal for usage limits - if (subscription && subscription.active) { - console.log('SubscriptionContext: Active subscription; suppressing usage-limit modal'); - return true; // Do not show modal for active plan usage limits + // Check if this is a usage limit error (status 429) vs subscription expired (402) + let errorData = error.response?.data || {}; + + // DEBUG: Log the raw error data structure + console.log('SubscriptionContext: Raw error data', { + type: typeof errorData, + isArray: Array.isArray(errorData), + data: errorData, + stringified: JSON.stringify(errorData) + }); + + // If errorData is an array, extract the first element (common FastAPI response format) + if (Array.isArray(errorData)) { + console.log('SubscriptionContext: errorData is array, extracting first element'); + errorData = errorData[0] || {}; } - - // If we don't have subscription data yet, defer the decision - if (!subscription) { - console.log('SubscriptionContext: No subscription data yet, deferring modal decision'); - setDeferredError(error); - return true; // Handle the error but don't show modal yet - } - - // If subscription is not active, show modal immediately - if (!subscription.active) { - console.log('SubscriptionContext: Inactive subscription, showing modal immediately'); - const errorData = error.response?.data || {}; - setModalErrorData({ - provider: errorData.provider, - usage_info: errorData.usage_info, - message: errorData.message || errorData.error + + // Check for usage_info in various possible locations + const usageInfo = errorData.usage_info || + (errorData.current_calls !== undefined ? errorData : null) || + null; + + // Usage limit error: 429 status with usage info OR 429 status without explicit expiration + const isUsageLimitError = status === 429 && (usageInfo || errorData.provider || errorData.message); + const isSubscriptionExpired = status === 402 || (status === 429 && !isUsageLimitError); + + console.log('SubscriptionContext: Error analysis', { + status, + isUsageLimitError, + isSubscriptionExpired, + hasUsageInfo: !!usageInfo, + errorDataType: typeof errorData, + errorDataKeys: typeof errorData === 'object' && !Array.isArray(errorData) ? Object.keys(errorData) : 'not-an-object', + errorData: errorData + }); + + // For usage limit errors (429 with usage_info), always show modal - even for active subscriptions + // Ignore grace window and cooldown for usage limit errors (user needs to know immediately) + if (isUsageLimitError) { + const modalData = { + provider: errorData.provider || usageInfo?.provider || 'unknown', + usage_info: usageInfo || errorData, + message: errorData.message || errorData.error || 'You have reached your usage limit.' + }; + + console.log('SubscriptionContext: Usage limit exceeded, showing modal (ignoring grace window/cooldown)', { + modalData, + errorData: Object.keys(errorData), + usageInfo: usageInfo ? Object.keys(usageInfo) : null }); + + // Set flag to mark this as a usage limit modal (should never be auto-closed) + setIsUsageLimitModal(true); + setModalErrorData(modalData); setShowModal(true); setLastModalShowTime(now); + + console.log('SubscriptionContext: Modal state updated - showModal should be true, isUsageLimitModal = true'); return true; } + + // For subscription expired errors, handle based on subscription status + if (isSubscriptionExpired) { + // If we have subscription data and it's active, this shouldn't happen but suppress anyway + if (subscription && subscription.active) { + console.log('SubscriptionContext: Active subscription but got expired error, suppressing modal'); + return true; + } + + // If we don't have subscription data yet, defer the decision + if (!subscription) { + console.log('SubscriptionContext: No subscription data yet, deferring modal decision'); + setDeferredError(error); + return true; // Handle the error but don't show modal yet + } + + // If subscription is not active, show modal immediately + if (!subscription.active) { + console.log('SubscriptionContext: Inactive subscription, showing modal immediately'); + setIsUsageLimitModal(false); + setModalErrorData({ + provider: errorData.provider, + usage_info: errorData.usage_info, + message: errorData.message || errorData.error + }); + setShowModal(true); + setLastModalShowTime(now); + return true; + } + } } return false; // Not a subscription error }, [subscription]); // Register the global error handler with the API client + // Use a ref to ensure the latest handler is always used + const handlerRef = useRef(globalSubscriptionErrorHandler); + useEffect(() => { + handlerRef.current = globalSubscriptionErrorHandler; + }, [globalSubscriptionErrorHandler]); + useEffect(() => { console.log('SubscriptionContext: Registering global subscription error handler'); - setGlobalSubscriptionErrorHandler(globalSubscriptionErrorHandler); - }, [globalSubscriptionErrorHandler]); + setGlobalSubscriptionErrorHandler((error: any) => { + // Always use the latest handler from ref + return handlerRef.current(error); + }); + + // Cleanup: Don't remove the handler on unmount - it should persist + // This ensures errors can still be caught even during component transitions + }, []); // Empty deps - only register once, but handler ref updates automatically + + useEffect(() => { + const eventHandler = (event: Event) => { + const customEvent = event as CustomEvent; + console.log('SubscriptionContext: Received subscription-error event fallback', customEvent.detail); + handlerRef.current(customEvent.detail); + }; + + window.addEventListener('subscription-error', eventHandler as EventListener); + return () => { + window.removeEventListener('subscription-error', eventHandler as EventListener); + }; + }, []); useEffect(() => { // Check subscription on mount diff --git a/frontend/src/hooks/useBlogWriterState.ts b/frontend/src/hooks/useBlogWriterState.ts index 808c709d..73b06356 100644 --- a/frontend/src/hooks/useBlogWriterState.ts +++ b/frontend/src/hooks/useBlogWriterState.ts @@ -33,6 +33,9 @@ export const useBlogWriterState = () => { // Content confirmation state const [contentConfirmed, setContentConfirmed] = useState(false); + // Section images state - persists images generated in outline phase to content phase + const [sectionImages, setSectionImages] = useState>({}); + // Cache recovery - restore most recent research on page load useEffect(() => { const cachedEntries = researchCache.getAllCachedEntries(); @@ -211,6 +214,7 @@ export const useBlogWriterState = () => { contentConfirmed, flowAnalysisCompleted, flowAnalysisResults, + sectionImages, // Setters setResearch, @@ -233,6 +237,7 @@ export const useBlogWriterState = () => { setContentConfirmed, setFlowAnalysisCompleted, setFlowAnalysisResults, + setSectionImages, // Handlers handleResearchComplete, diff --git a/frontend/src/hooks/usePolling.ts b/frontend/src/hooks/usePolling.ts index 524713f8..c45b5663 100644 --- a/frontend/src/hooks/usePolling.ts +++ b/frontend/src/hooks/usePolling.ts @@ -1,5 +1,6 @@ import { useState, useEffect, useCallback, useRef } from 'react'; import { blogWriterApi, TaskStatusResponse } from '../services/blogWriterApi'; +import { triggerSubscriptionError } from '../api/client'; export interface UsePollingOptions { interval?: number; // Polling interval in milliseconds @@ -108,6 +109,43 @@ export function usePolling( console.log('❌ Task failed - stopping polling immediately'); setError(status.error || 'Task failed'); onError?.(status.error || 'Task failed'); + + // Check if this is a subscription error and trigger modal + if (status.error_status === 429 || status.error_status === 402) { + console.log('usePolling: Detected subscription error in task status', { + error_status: status.error_status, + error_data: status.error_data, + error: status.error + }); + + // Create a mock error object with the subscription error data + const errorData = status.error_data || {}; + + // Ensure usage_info is properly nested - it might be at the top level or nested + const usageInfo = errorData.usage_info || + (errorData.current_calls !== undefined ? errorData : null) || + errorData; + + const mockError = { + response: { + status: status.error_status, + data: { + error: errorData.error || status.error || 'Subscription limit exceeded', + message: errorData.message || errorData.error || status.error || 'You have reached your usage limit.', + provider: errorData.provider || usageInfo?.provider || 'unknown', + usage_info: usageInfo + } + } + }; + + console.log('usePolling: Triggering subscription error handler with:', mockError); + const handled = triggerSubscriptionError(mockError); + + if (!handled) { + console.warn('usePolling: Subscription error handler did not handle the error'); + } + } + stopPolling(); return; // Exit early to prevent further processing } @@ -117,6 +155,38 @@ export function usePolling( const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred'; console.error('Polling error:', errorMessage); + // Check if this is an axios error with subscription limit status + // This is a fallback in case the interceptor doesn't catch it + const axiosError = err as any; + if (axiosError?.response?.status === 429 || axiosError?.response?.status === 402) { + console.log('usePolling: Detected subscription error in axios error response', { + status: axiosError.response.status, + data: axiosError.response.data + }); + + // Trigger subscription error handler (modal will show) + const handled = triggerSubscriptionError(axiosError); + console.log('usePolling: triggerSubscriptionError returned', handled); + + if (handled) { + console.log('usePolling: Subscription error handled, stopping polling'); + const errorMsg = axiosError.response?.data?.message || + axiosError.response?.data?.error || + 'Subscription limit exceeded'; + setError(errorMsg); + onError?.(errorMsg); + stopPolling(); + return; // Exit early - don't continue processing + } else { + console.warn('usePolling: Subscription error not handled by global handler, dispatching fallback event'); + try { + window.dispatchEvent(new CustomEvent('subscription-error', { detail: axiosError })); + } catch (eventError) { + console.error('usePolling: Failed to dispatch subscription-error event', eventError); + } + } + } + // Stop polling for task failures and rate limiting if (errorMessage.includes('404') || errorMessage.includes('Task not found')) { setError('Task not found - it may have expired or been cleaned up'); diff --git a/frontend/src/services/blogWriterApi.ts b/frontend/src/services/blogWriterApi.ts index e56d6fbc..6d7860b3 100644 --- a/frontend/src/services/blogWriterApi.ts +++ b/frontend/src/services/blogWriterApi.ts @@ -219,9 +219,22 @@ export interface BlogSEOMetadataResponse { success: boolean; title_options: string[]; meta_descriptions: string[]; + seo_title?: string; + meta_description?: string; + url_slug?: string; + blog_tags: string[]; + blog_categories: string[]; + social_hashtags: string[]; open_graph: Record; twitter_card: Record; - schema: Record; + json_ld_schema?: Record; + schema?: Record; // Legacy field name + canonical_url?: string; + reading_time?: number; + focus_keyword?: string; + generated_at?: string; + optimization_score?: number; + error?: string; } export interface BlogPublishResponse { @@ -241,6 +254,26 @@ export interface TaskStatusResponse { }>; result?: BlogResearchResponse; error?: string; + // Subscription error details (set by backend when subscription limit is exceeded) + error_status?: number; // HTTP status code (429 for usage limit, 402 for subscription expired) + error_data?: { + error?: string; + message?: string; + provider?: string; + usage_info?: { + provider?: string; + current_calls?: number; + limit?: number; + type?: string; + breakdown?: { + gemini?: number; + openai?: number; + anthropic?: number; + mistral?: number; + }; + }; + [key: string]: any; // Allow additional fields + }; } export const blogWriterApi = { diff --git a/frontend/src/utils/wixTokenUtils.ts b/frontend/src/utils/wixTokenUtils.ts new file mode 100644 index 00000000..902a0dfe --- /dev/null +++ b/frontend/src/utils/wixTokenUtils.ts @@ -0,0 +1,198 @@ +/** + * Wix Token Utilities + * Functions for validating and refreshing Wix OAuth tokens + */ + +import { apiClient } from '../api/client'; + +interface WixTokens { + accessToken?: { + value: string; + expiresAt?: string; + }; + refreshToken?: { + value: string; + }; + access_token?: string; + refresh_token?: string; + expires_in?: number; +} + +interface TokenValidationResult { + valid: boolean; + accessToken: string | null; + needsRefresh: boolean; + needsReconnect: boolean; +} + +/** + * Get Wix tokens from sessionStorage + */ +export function getWixTokens(): WixTokens | null { + try { + const tokensRaw = sessionStorage.getItem('wix_tokens'); + if (!tokensRaw) return null; + return JSON.parse(tokensRaw); + } catch (error) { + console.error('Error parsing Wix tokens:', error); + return null; + } +} + +/** + * Extract access token from token structure + */ +export function extractAccessToken(tokens: WixTokens | null): string | null { + if (!tokens) return null; + return tokens.accessToken?.value || tokens.access_token || null; +} + +/** + * Extract refresh token from token structure + */ +export function extractRefreshToken(tokens: WixTokens | null): string | null { + if (!tokens) return null; + return tokens.refreshToken?.value || tokens.refresh_token || null; +} + +/** + * Refresh Wix access token using refresh token + */ +export async function refreshWixToken(refreshToken: string): Promise { + try { + const response = await apiClient.post('/api/wix/refresh-token', { + refresh_token: refreshToken + }); + + if (response.data.success) { + // Create new token structure matching Wix SDK format + const newTokens: WixTokens = { + accessToken: { + value: response.data.access_token + }, + refreshToken: { + value: response.data.refresh_token || refreshToken // Keep old refresh token if new one not provided + }, + access_token: response.data.access_token, + refresh_token: response.data.refresh_token || refreshToken + }; + + // Update sessionStorage + try { + sessionStorage.setItem('wix_tokens', JSON.stringify(newTokens)); + sessionStorage.setItem('wix_connected', 'true'); + } catch (e) { + console.error('Error saving refreshed tokens:', e); + } + + return newTokens; + } + + return null; + } catch (error: any) { + console.error('Error refreshing Wix token:', error); + return null; + } +} + +/** + * Check if token is expired based on expiresAt timestamp + */ +function isTokenExpired(tokens: WixTokens): boolean { + if (tokens.accessToken?.expiresAt) { + try { + const expiresAt = new Date(tokens.accessToken.expiresAt); + return expiresAt < new Date(); + } catch (e) { + // If we can't parse, assume not expired (will validate during publish) + return false; + } + } + // If no expiration info, we can't tell - assume valid for now + // Real validation happens during actual API call + return false; +} + +/** + * Validate and refresh Wix tokens proactively + * Returns access token if valid, or null if needs reconnection + * + * Strategy: + * 1. Check if tokens exist + * 2. Check if token is expired (if expiration info available) + * 3. If expired, attempt refresh + * 4. If refresh fails or no refresh token, needs reconnection + * 5. Real validation happens during actual publish (we catch 401/403 errors) + */ +export async function validateAndRefreshWixTokens(): Promise { + const tokens = getWixTokens(); + + if (!tokens) { + return { + valid: false, + accessToken: null, + needsRefresh: false, + needsReconnect: true + }; + } + + const accessToken = extractAccessToken(tokens); + const refreshToken = extractRefreshToken(tokens); + + if (!accessToken) { + return { + valid: false, + accessToken: null, + needsRefresh: false, + needsReconnect: true + }; + } + + // Check if token is expired (if we have expiration info) + const expired = isTokenExpired(tokens); + + if (!expired) { + // Token appears valid (not expired or no expiration info) + // We'll do real validation during publish + return { + valid: true, + accessToken: accessToken, + needsRefresh: false, + needsReconnect: false + }; + } + + // Token is expired, try to refresh + if (!refreshToken) { + return { + valid: false, + accessToken: null, + needsRefresh: false, + needsReconnect: true + }; + } + + // Attempt to refresh token + const refreshedTokens = await refreshWixToken(refreshToken); + + if (refreshedTokens) { + const newAccessToken = extractAccessToken(refreshedTokens); + if (newAccessToken) { + return { + valid: true, + accessToken: newAccessToken, + needsRefresh: true, + needsReconnect: false + }; + } + } + + // Refresh failed, needs reconnection + return { + valid: false, + accessToken: null, + needsRefresh: false, + needsReconnect: true + }; +} +