Subscription dashboard improvements, AI text generation limit, and other fixes.

This commit is contained in:
ajaysi
2025-11-01 18:01:14 +05:30
parent cdb41aec1b
commit de4328175d
64 changed files with 5809 additions and 444 deletions

View File

@@ -5,10 +5,11 @@ Main router for blog writing operations including research, outline generation,
content creation, SEO analysis, and publishing. content creation, SEO analysis, and publishing.
""" """
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Depends
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
from middleware.auth_middleware import get_current_user
from models.blog_models import ( from models.blog_models import (
BlogResearchRequest, BlogResearchRequest,
@@ -64,10 +65,21 @@ class SEOApplyRecommendationsRequest(BaseModel):
@router.post("/seo/apply-recommendations") @router.post("/seo/apply-recommendations")
async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]: async def apply_seo_recommendations(
request: SEOApplyRecommendationsRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Apply actionable SEO recommendations and return updated content.""" """Apply actionable SEO recommendations and return updated content."""
try: try:
result = await recommendation_applier.apply_recommendations(request.dict()) # Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
result = await recommendation_applier.apply_recommendations(request.dict(), user_id=user_id)
if not result.get("success"): if not result.get("success"):
raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations")) raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations"))
return result return result
@@ -87,13 +99,24 @@ async def health() -> Dict[str, Any]:
# Research Endpoints # Research Endpoints
@router.post("/research/start") @router.post("/research/start")
async def start_research(request: BlogResearchRequest) -> Dict[str, Any]: async def start_research(
request: BlogResearchRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Start a research operation and return a task ID for polling.""" """Start a research operation and return a task ID for polling."""
try: try:
# TODO: Get user_id from authentication context # Extract Clerk user ID (required)
user_id = "anonymous" # This should come from auth middleware if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
task_id = await task_manager.start_research_task(request, user_id) task_id = await task_manager.start_research_task(request, user_id)
return {"task_id": task_id, "status": "started"} return {"task_id": task_id, "status": "started"}
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to start research: {e}") logger.error(f"Failed to start research: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -107,6 +130,50 @@ async def get_research_status(task_id: str) -> Dict[str, Any]:
if status is None: if status is None:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
error_status = status.get('error_status', 429)
if not isinstance(error_data, dict):
logger.warning(f"Research task {task_id} error_data not dict: {error_data}")
error_data = {'error': str(error_data)}
# Determine provider and usage info
stored_error_message = status.get('error', error_data.get('error'))
provider = error_data.get('provider', 'unknown')
usage_info = error_data.get('usage_info')
if not usage_info:
usage_info = {
'provider': provider,
'message': stored_error_message,
'error_type': error_data.get('error_type', 'unknown')
}
# Include any known fields from error_data
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
if key in error_data:
usage_info[key] = error_data[key]
# Build error message for detail
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
# Log the subscription error with all context
logger.warning(f"Research task {task_id} failed with subscription error {error_status}: {error_msg}")
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
from fastapi.responses import JSONResponse
return JSONResponse(
status_code=error_status,
content={
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
'message': error_msg,
'provider': provider,
'usage_info': usage_info
}
)
logger.info(f"Research status request for {task_id}: {status['status']} with {len(status.get('progress_messages', []))} progress messages") logger.info(f"Research status request for {task_id}: {status['status']} with {len(status.get('progress_messages', []))} progress messages")
return status return status
except HTTPException: except HTTPException:
@@ -310,20 +377,46 @@ async def hallucination_check(request: HallucinationCheckRequest) -> Hallucinati
# SEO Endpoints # SEO Endpoints
@router.post("/seo/analyze", response_model=BlogSEOAnalyzeResponse) @router.post("/seo/analyze", response_model=BlogSEOAnalyzeResponse)
async def seo_analyze(request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse: async def seo_analyze(
request: BlogSEOAnalyzeRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> BlogSEOAnalyzeResponse:
"""Analyze content for SEO optimization opportunities.""" """Analyze content for SEO optimization opportunities."""
try: try:
return await service.seo_analyze(request) # Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
return await service.seo_analyze(request, user_id=user_id)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to perform SEO analysis: {e}") logger.error(f"Failed to perform SEO analysis: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/seo/metadata", response_model=BlogSEOMetadataResponse) @router.post("/seo/metadata", response_model=BlogSEOMetadataResponse)
async def seo_metadata(request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse: async def seo_metadata(
request: BlogSEOMetadataRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> BlogSEOMetadataResponse:
"""Generate SEO metadata for the blog post.""" """Generate SEO metadata for the blog post."""
try: try:
return await service.seo_metadata(request) # Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
return await service.seo_metadata(request, user_id=user_id)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to generate SEO metadata: {e}") logger.error(f"Failed to generate SEO metadata: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -10,6 +10,7 @@ import asyncio
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
from fastapi import HTTPException
from loguru import logger from loguru import logger
from models.blog_models import ( from models.blog_models import (
@@ -85,6 +86,10 @@ class TaskManager:
response["result"] = task["result"] response["result"] = task["result"]
elif task["status"] == "failed": elif task["status"] == "failed":
response["error"] = task["error"] response["error"] = task["error"]
if "error_status" in task:
response["error_status"] = task["error_status"]
if "error_data" in task:
response["error_data"] = task["error_data"]
return response return response
@@ -109,14 +114,17 @@ class TaskManager:
logger.info(f"Progress update for task {task_id}: {message}") logger.info(f"Progress update for task {task_id}: {message}")
async def start_research_task(self, request: BlogResearchRequest, user_id: str = "anonymous") -> str: async def start_research_task(self, request: BlogResearchRequest, user_id: str) -> str:
"""Start a research operation and return a task ID.""" """Start a research operation and return a task ID."""
if self.use_database: if self.use_database:
return await self.db_manager.start_research_task(request, user_id) return await self.db_manager.start_research_task(request, user_id)
else: else:
task_id = self.create_task("research") task_id = self.create_task("research")
# Store user_id in task for subscription checks
if task_id in self.task_storage:
self.task_storage[task_id]["user_id"] = user_id
# Start the research operation in the background # Start the research operation in the background
asyncio.create_task(self._run_research_task(task_id, request)) asyncio.create_task(self._run_research_task(task_id, request, user_id))
return task_id return task_id
def start_outline_task(self, request: BlogOutlineRequest) -> str: def start_outline_task(self, request: BlogOutlineRequest) -> str:
@@ -144,7 +152,7 @@ class TaskManager:
asyncio.create_task(self._run_medium_generation_task(task_id, request)) asyncio.create_task(self._run_medium_generation_task(task_id, request))
return task_id return task_id
async def _run_research_task(self, task_id: str, request: BlogResearchRequest): async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str):
"""Background task to run research and update status with progress messages.""" """Background task to run research and update status with progress messages."""
try: try:
# Update status to running # Update status to running
@@ -157,8 +165,8 @@ class TaskManager:
# Check cache first # Check cache first
await self.update_progress(task_id, "📋 Checking cache for existing research...") await self.update_progress(task_id, "📋 Checking cache for existing research...")
# Run the actual research with progress updates # Run the actual research with progress updates (pass user_id for subscription checks)
result = await self.service.research_with_progress(request, task_id) result = await self.service.research_with_progress(request, task_id, user_id)
# Check if research failed gracefully # Check if research failed gracefully
if not result.success: if not result.success:
@@ -171,6 +179,16 @@ class TaskManager:
self.task_storage[task_id]["status"] = "completed" self.task_storage[task_id]["status"] = "completed"
self.task_storage[task_id]["result"] = result.dict() self.task_storage[task_id]["result"] = result.dict()
except HTTPException as http_error:
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
error_detail = http_error.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = http_error.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
except Exception as e: except Exception as e:
await self.update_progress(task_id, f"❌ Research failed with error: {str(e)}") await self.update_progress(task_id, f"❌ Research failed with error: {str(e)}")
# Update status to failed # Update status to failed

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, HTTPException, Depends, Query, Body from fastapi import APIRouter, HTTPException, Depends, Query, Body
from typing import Dict, Any from typing import Dict, Any, Optional
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -64,6 +64,15 @@ async def activate_strategy_with_monitoring(
if not monitoring_success: if not monitoring_success:
logger.warning(f"Failed to save monitoring data for strategy {strategy_id}") logger.warning(f"Failed to save monitoring data for strategy {strategy_id}")
# Trigger scheduler interval adjustment (scheduler will check more frequently now)
try:
from services.scheduler import get_scheduler
scheduler = get_scheduler()
await scheduler.trigger_interval_adjustment()
logger.info(f"Triggered scheduler interval adjustment after strategy {strategy_id} activation")
except Exception as e:
logger.warning(f"Could not trigger scheduler interval adjustment: {e}")
logger.info(f"Successfully activated strategy {strategy_id} with monitoring") logger.info(f"Successfully activated strategy {strategy_id} with monitoring")
return { return {
"success": True, "success": True,
@@ -396,6 +405,150 @@ async def get_monitoring_tasks(
logger.error(f"Error retrieving monitoring tasks: {str(e)}") logger.error(f"Error retrieving monitoring tasks: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/user/{user_id}/monitoring-tasks")
async def get_user_monitoring_tasks(
user_id: int,
db: Session = Depends(get_db),
status: Optional[str] = Query(None, description="Filter by task status"),
limit: int = Query(50, description="Maximum number of tasks to return"),
offset: int = Query(0, description="Number of tasks to skip")
):
"""
Get all monitoring tasks for a specific user with their execution status.
Uses the scheduler's task loader to get tasks filtered by user_id for proper user isolation.
"""
try:
logger.info(f"Getting monitoring tasks for user {user_id}")
# Use scheduler task loader for user-specific tasks
from services.scheduler.utils.task_loader import load_due_monitoring_tasks
# Load all tasks for user (not just due tasks - we want all user tasks)
# Join with strategy to filter by user
tasks_query = db.query(MonitoringTask).join(
EnhancedContentStrategy,
MonitoringTask.strategy_id == EnhancedContentStrategy.id
).filter(
EnhancedContentStrategy.user_id == user_id
)
# Apply status filter if provided
if status:
tasks_query = tasks_query.filter(MonitoringTask.status == status)
# Get tasks with pagination
tasks = tasks_query.order_by(desc(MonitoringTask.created_at)).offset(offset).limit(limit).all()
tasks_data = []
for task in tasks:
# Get latest execution log
latest_log = db.query(TaskExecutionLog).filter(
TaskExecutionLog.task_id == task.id
).order_by(desc(TaskExecutionLog.execution_date)).first()
# Get strategy info
strategy = db.query(EnhancedContentStrategy).filter(
EnhancedContentStrategy.id == task.strategy_id
).first()
task_data = {
"id": task.id,
"strategy_id": task.strategy_id,
"strategy_name": strategy.name if strategy else None,
"title": task.task_title,
"description": task.task_description,
"assignee": task.assignee,
"frequency": task.frequency,
"metric": task.metric,
"measurementMethod": task.measurement_method,
"successCriteria": task.success_criteria,
"alertThreshold": task.alert_threshold,
"status": task.status,
"lastExecuted": latest_log.execution_date.isoformat() if latest_log else None,
"nextExecution": task.next_execution.isoformat() if task.next_execution else None,
"executionCount": db.query(TaskExecutionLog).filter(
TaskExecutionLog.task_id == task.id
).count(),
"created_at": task.created_at.isoformat() if task.created_at else None
}
tasks_data.append(task_data)
# Get total count for pagination
total_count = db.query(MonitoringTask).join(
EnhancedContentStrategy,
MonitoringTask.strategy_id == EnhancedContentStrategy.id
).filter(
EnhancedContentStrategy.user_id == user_id
)
if status:
total_count = total_count.filter(MonitoringTask.status == status)
total_count = total_count.count()
return {
"success": True,
"data": tasks_data,
"pagination": {
"total": total_count,
"limit": limit,
"offset": offset,
"has_more": (offset + len(tasks_data)) < total_count
},
"message": f"Retrieved {len(tasks_data)} monitoring tasks for user {user_id}"
}
except Exception as e:
logger.error(f"Error retrieving user monitoring tasks: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to retrieve monitoring tasks: {str(e)}")
@router.get("/user/{user_id}/execution-logs")
async def get_user_execution_logs(
user_id: int,
db: Session = Depends(get_db),
status: Optional[str] = Query(None, description="Filter by execution status"),
limit: int = Query(50, description="Maximum number of logs to return"),
offset: int = Query(0, description="Number of logs to skip")
):
"""
Get execution logs for a specific user.
Provides user isolation by filtering execution logs by user_id.
"""
try:
logger.info(f"Getting execution logs for user {user_id}")
monitoring_service = MonitoringDataService(db)
logs_data = monitoring_service.get_user_execution_logs(
user_id=user_id,
limit=limit,
offset=offset,
status_filter=status
)
# Get total count for pagination
count_query = db.query(TaskExecutionLog).filter(
TaskExecutionLog.user_id == user_id
)
if status:
count_query = count_query.filter(TaskExecutionLog.status == status)
total_count = count_query.count()
return {
"success": True,
"data": logs_data,
"pagination": {
"total": total_count,
"limit": limit,
"offset": offset,
"has_more": (offset + len(logs_data)) < total_count
},
"message": f"Retrieved {len(logs_data)} execution logs for user {user_id}"
}
except Exception as e:
logger.error(f"Error retrieving execution logs for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to retrieve execution logs: {str(e)}")
@router.get("/{strategy_id}/data-freshness") @router.get("/{strategy_id}/data-freshness")
async def get_data_freshness( async def get_data_freshness(
strategy_id: int, strategy_id: int,

View File

@@ -3,13 +3,18 @@ from __future__ import annotations
import base64 import base64
import os import os
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_text_generation import llm_text_gen from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import APIProvider, UsageSummary
router = APIRouter(prefix="/api/images", tags=["images"]) router = APIRouter(prefix="/api/images", tags=["images"])
@@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel):
@router.post("/generate", response_model=ImageGenerateResponse) @router.post("/generate", response_model=ImageGenerateResponse)
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse: def generate(
req: ImageGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageGenerateResponse:
"""Generate image with subscription checking."""
try: try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validation is now handled inside generate_image function
last_error: Optional[Exception] = None last_error: Optional[Exception] = None
result = None
for attempt in range(2): # simple single retry for attempt in range(2): # simple single retry
try: try:
result = generate_image( result = generate_image(
@@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
"steps": req.steps, "steps": req.steps,
"seed": req.seed, "seed": req.seed,
}, },
user_id=user_id, # Pass user_id for validation inside generate_image
) )
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8") image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# TRACK USAGE after successful image generation
if result:
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "stability_calls", 0) or 0
# Update provider-specific counters (stability for image generation)
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
new_calls = current_calls_before + 1
setattr(summary, "stability_calls", new_calls)
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: stability
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageGenerateResponse( return ImageGenerateResponse(
image_base64=image_b64, image_base64=image_b64,
width=result.width, width=result.width,
@@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel):
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse) @router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse: def suggest_prompts(
req: ImagePromptSuggestRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImagePromptSuggestResponse:
try: try:
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower() provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
section = req.section or {} section = req.section or {}
@@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons
If including on-image text, return it in overlay_text (short: <= 8 words). If including on-image text, return it in overlay_text (short: <= 8 words).
""" """
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema) # Get user_id for llm_text_gen subscription check (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id_for_llm = str(current_user.get('id', ''))
if not user_id_for_llm:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
data = raw if isinstance(raw, dict) else {} data = raw if isinstance(raw, dict) else {}
suggestions = data.get("suggestions") or [] suggestions = data.get("suggestions") or []
# basic fallback if provider returns string # basic fallback if provider returns string

View File

@@ -94,6 +94,7 @@ async def get_subscription_plans(
"description": plan.description, "description": plan.description,
"features": plan.features or [], "features": plan.features or [],
"limits": { "limits": {
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": plan.gemini_calls_limit, "gemini_calls": plan.gemini_calls_limit,
"openai_calls": plan.openai_calls_limit, "openai_calls": plan.openai_calls_limit,
"anthropic_calls": plan.anthropic_calls_limit, "anthropic_calls": plan.anthropic_calls_limit,
@@ -162,6 +163,7 @@ async def get_user_subscription(
}, },
"status": "free", "status": "free",
"limits": { "limits": {
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": free_plan.gemini_calls_limit, "gemini_calls": free_plan.gemini_calls_limit,
"openai_calls": free_plan.openai_calls_limit, "openai_calls": free_plan.openai_calls_limit,
"anthropic_calls": free_plan.anthropic_calls_limit, "anthropic_calls": free_plan.anthropic_calls_limit,
@@ -200,6 +202,7 @@ async def get_user_subscription(
"is_free": False "is_free": False
}, },
"limits": { "limits": {
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": subscription.plan.gemini_calls_limit, "gemini_calls": subscription.plan.gemini_calls_limit,
"openai_calls": subscription.plan.openai_calls_limit, "openai_calls": subscription.plan.openai_calls_limit,
"anthropic_calls": subscription.plan.anthropic_calls_limit, "anthropic_calls": subscription.plan.anthropic_calls_limit,
@@ -252,6 +255,7 @@ async def get_subscription_status(
"tier": "free", "tier": "free",
"can_use_api": True, "can_use_api": True,
"limits": { "limits": {
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": free_plan.gemini_calls_limit, "gemini_calls": free_plan.gemini_calls_limit,
"openai_calls": free_plan.openai_calls_limit, "openai_calls": free_plan.openai_calls_limit,
"anthropic_calls": free_plan.anthropic_calls_limit, "anthropic_calls": free_plan.anthropic_calls_limit,
@@ -309,6 +313,7 @@ async def get_subscription_status(
"tier": subscription.plan.tier.value, "tier": subscription.plan.tier.value,
"can_use_api": True, "can_use_api": True,
"limits": { "limits": {
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": subscription.plan.gemini_calls_limit, "gemini_calls": subscription.plan.gemini_calls_limit,
"openai_calls": subscription.plan.openai_calls_limit, "openai_calls": subscription.plan.openai_calls_limit,
"anthropic_calls": subscription.plan.anthropic_calls_limit, "anthropic_calls": subscription.plan.anthropic_calls_limit,
@@ -331,9 +336,14 @@ async def get_subscription_status(
async def subscribe_to_plan( async def subscribe_to_plan(
user_id: str, user_id: str,
subscription_data: dict, subscription_data: dict,
db: Session = Depends(get_db) db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Create or update a user's subscription.""" """Create or update a user's subscription (renewal)."""
# Verify user can only subscribe/renew their own subscription
if current_user.get('id') != user_id:
raise HTTPException(status_code=403, detail="Access denied")
try: try:
plan_id = subscription_data.get('plan_id') plan_id = subscription_data.get('plan_id')
@@ -388,12 +398,75 @@ async def subscribe_to_plan(
db.commit() db.commit()
# Get current usage BEFORE reset for logging
current_period = datetime.utcnow().strftime("%Y-%m")
usage_before = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# Log renewal request details
logger.info("=" * 80)
logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request")
logger.info(f" ├─ User: {user_id}")
logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})")
logger.info(f" ├─ Billing Cycle: {billing_cycle}")
logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}")
if usage_before:
logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):")
logger.info(f" ├─ Gemini: {usage_before.gemini_tokens or 0} tokens / {usage_before.gemini_calls or 0} calls")
logger.info(f" ├─ Mistral/HF: {usage_before.mistral_tokens or 0} tokens / {usage_before.mistral_calls or 0} calls")
logger.info(f" ├─ OpenAI: {usage_before.openai_tokens or 0} tokens / {usage_before.openai_calls or 0} calls")
logger.info(f" ├─ Stability (Images): {usage_before.stability_calls or 0} calls")
logger.info(f" ├─ Total Tokens: {usage_before.total_tokens or 0}")
logger.info(f" ├─ Total Calls: {usage_before.total_calls or 0}")
logger.info(f" └─ Usage Status: {usage_before.usage_status.value}")
else:
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
# Clear subscription limits cache to force refresh on next check
try:
from services.subscription import PricingService
# Clear cache for this specific user (class-level cache shared across all instances)
cleared_count = PricingService.clear_user_cache(user_id)
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
except Exception as cache_err:
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
# Reset usage status for current billing period so new plan takes effect immediately # Reset usage status for current billing period so new plan takes effect immediately
reset_result = None
try: try:
usage_service = UsageTrackingService(db) usage_service = UsageTrackingService(db)
await usage_service.reset_current_billing_period(user_id) reset_result = await usage_service.reset_current_billing_period(user_id)
# Re-query usage summary from DB after reset to get fresh data
usage_after = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if reset_result.get('reset'):
logger.info(f" ✅ Usage counters RESET successfully")
if usage_after:
logger.info(f" 📊 New Usage AFTER Reset:")
logger.info(f" ├─ Gemini: {usage_after.gemini_tokens or 0} tokens / {usage_after.gemini_calls or 0} calls")
logger.info(f" ├─ Mistral/HF: {usage_after.mistral_tokens or 0} tokens / {usage_after.mistral_calls or 0} calls")
logger.info(f" ├─ OpenAI: {usage_after.openai_tokens or 0} tokens / {usage_after.openai_calls or 0} calls")
logger.info(f" ├─ Stability (Images): {usage_after.stability_calls or 0} calls")
logger.info(f" ├─ Total Tokens: {usage_after.total_tokens or 0}")
logger.info(f" ├─ Total Calls: {usage_after.total_calls or 0}")
logger.info(f" └─ Usage Status: {usage_after.usage_status.value}")
else:
logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call")
else:
logger.warning(f" ⚠️ Reset returned: {reset_result.get('reason', 'unknown')}")
except Exception as reset_err: except Exception as reset_err:
logger.error(f"Failed to reset usage after subscribe: {reset_err}") logger.error(f"Failed to reset usage after subscribe: {reset_err}", exc_info=True)
logger.info(f" ✅ Renewal completed: User {user_id}{plan.name} ({billing_cycle})")
logger.info("=" * 80)
return { return {
"success": True, "success": True,
@@ -404,7 +477,20 @@ async def subscribe_to_plan(
"billing_cycle": billing_cycle, "billing_cycle": billing_cycle,
"current_period_start": subscription.current_period_start.isoformat(), "current_period_start": subscription.current_period_start.isoformat(),
"current_period_end": subscription.current_period_end.isoformat(), "current_period_end": subscription.current_period_end.isoformat(),
"status": subscription.status.value "status": subscription.status.value,
"limits": {
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": plan.gemini_calls_limit,
"openai_calls": plan.openai_calls_limit,
"anthropic_calls": plan.anthropic_calls_limit,
"mistral_calls": plan.mistral_calls_limit,
"tavily_calls": plan.tavily_calls_limit,
"serper_calls": plan.serper_calls_limit,
"metaphor_calls": plan.metaphor_calls_limit,
"firecrawl_calls": plan.firecrawl_calls_limit,
"stability_calls": plan.stability_calls_limit,
"monthly_cost": plan.monthly_cost_limit
}
} }
} }

View File

@@ -477,6 +477,39 @@ async def test_publish_to_wix(request: WixPublishRequest) -> Dict[str, Any]:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/refresh-token")
async def refresh_wix_token(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Refresh Wix access token using refresh token
Args:
request: Dict containing refresh_token
Returns:
New token information with access_token, refresh_token, expires_in
"""
try:
refresh_token = request.get("refresh_token")
if not refresh_token:
raise HTTPException(status_code=400, detail="Missing refresh_token")
# Refresh the token
new_tokens = wix_service.refresh_access_token(refresh_token)
return {
"success": True,
"access_token": new_tokens.get("access_token"),
"refresh_token": new_tokens.get("refresh_token"),
"expires_in": new_tokens.get("expires_in"),
"token_type": new_tokens.get("token_type", "Bearer")
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to refresh Wix token: {e}")
raise HTTPException(status_code=500, detail=f"Failed to refresh token: {str(e)}")
@router.post("/test/publish/real") @router.post("/test/publish/real")
async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]: async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
""" """

View File

@@ -298,6 +298,11 @@ async def startup_event():
try: try:
# Initialize database # Initialize database
init_database() init_database()
# Start task scheduler
from services.scheduler import get_scheduler
await get_scheduler().start()
logger.info("ALwrity backend started successfully") logger.info("ALwrity backend started successfully")
except Exception as e: except Exception as e:
logger.error(f"Error during startup: {e}") logger.error(f"Error during startup: {e}")
@@ -307,6 +312,10 @@ async def startup_event():
async def shutdown_event(): async def shutdown_event():
"""Cleanup on shutdown.""" """Cleanup on shutdown."""
try: try:
# Stop task scheduler
from services.scheduler import get_scheduler
await get_scheduler().stop()
# Close database connections # Close database connections
close_database() close_database()
logger.info("ALwrity backend shutdown successfully") logger.info("ALwrity backend shutdown successfully")

View File

@@ -0,0 +1,20 @@
-- Migration: Add user_id column to task_execution_logs for user isolation
-- Date: 2025-01-XX
-- Purpose: Enable user isolation tracking in scheduler task execution logs
-- Add user_id column (nullable for backward compatibility with existing records)
ALTER TABLE task_execution_logs
ADD COLUMN user_id INTEGER NULL;
-- Create index for efficient user filtering and queries
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_id
ON task_execution_logs(user_id);
-- Create composite index for common query patterns (user_id + status + execution_date)
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_status_date
ON task_execution_logs(user_id, status, execution_date);
-- Note: Backfilling existing records would require joining with monitoring_tasks
-- and enhanced_content_strategies tables. This can be done in a separate migration
-- or during a maintenance window. For now, existing records will have user_id = NULL.

View File

@@ -65,7 +65,7 @@ class LinkedInPostRequest(BaseModel):
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving") persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"topic": "AI in healthcare transformation", "topic": "AI in healthcare transformation",
"industry": "Healthcare", "industry": "Healthcare",
@@ -102,7 +102,7 @@ class LinkedInArticleRequest(BaseModel):
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving") persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"topic": "Digital transformation in manufacturing", "topic": "Digital transformation in manufacturing",
"industry": "Manufacturing", "industry": "Manufacturing",
@@ -135,7 +135,7 @@ class LinkedInCarouselRequest(BaseModel):
include_citations: bool = Field(default=True, description="Whether to include inline citations") include_citations: bool = Field(default=True, description="Whether to include inline citations")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"topic": "Future of remote work", "topic": "Future of remote work",
"industry": "Technology", "industry": "Technology",
@@ -167,7 +167,7 @@ class LinkedInVideoScriptRequest(BaseModel):
include_citations: bool = Field(default=True, description="Whether to include inline citations") include_citations: bool = Field(default=True, description="Whether to include inline citations")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"topic": "Cybersecurity best practices", "topic": "Cybersecurity best practices",
"industry": "Technology", "industry": "Technology",
@@ -197,7 +197,7 @@ class LinkedInCommentResponseRequest(BaseModel):
grounding_level: GroundingLevel = Field(default=GroundingLevel.BASIC, description="Level of content grounding") grounding_level: GroundingLevel = Field(default=GroundingLevel.BASIC, description="Level of content grounding")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"original_comment": "Great insights on AI implementation!", "original_comment": "Great insights on AI implementation!",
"post_context": "Post about AI transformation in healthcare", "post_context": "Post about AI transformation in healthcare",
@@ -353,7 +353,7 @@ class LinkedInPostResponse(BaseModel):
grounding_status: Optional[Dict[str, Any]] = Field(None, description="Grounding operation status") grounding_status: Optional[Dict[str, Any]] = Field(None, description="Grounding operation status")
class Config: class Config:
schema_extra = { json_schema_extra = {
"example": { "example": {
"success": True, "success": True,
"data": { "data": {

View File

@@ -48,8 +48,9 @@ class TaskExecutionLog(Base):
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
task_id = Column(Integer, ForeignKey("monitoring_tasks.id"), nullable=False) task_id = Column(Integer, ForeignKey("monitoring_tasks.id"), nullable=False)
user_id = Column(Integer, nullable=True) # User ID for user isolation (nullable for backward compatibility)
execution_date = Column(DateTime, default=datetime.utcnow) execution_date = Column(DateTime, default=datetime.utcnow)
status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped' status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped', 'running'
result_data = Column(JSON, nullable=True) result_data = Column(JSON, nullable=True)
error_message = Column(Text, nullable=True) error_message = Column(Text, nullable=True)
execution_time_ms = Column(Integer, nullable=True) execution_time_ms = Column(Integer, nullable=True)

View File

@@ -50,16 +50,22 @@ class SubscriptionPlan(Base):
price_monthly = Column(Float, nullable=False, default=0.0) price_monthly = Column(Float, nullable=False, default=0.0)
price_yearly = Column(Float, nullable=False, default=0.0) price_yearly = Column(Float, nullable=False, default=0.0)
# API Call Limits # Unified AI Text Generation Call Limit (applies to all LLM providers: gemini, openai, anthropic, mistral)
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited # Note: This column may not exist in older databases - use getattr() when accessing
openai_calls_limit = Column(Integer, default=0) ai_text_generation_calls_limit = Column(Integer, default=0, nullable=True) # 0 = unlimited, None if column doesn't exist
anthropic_calls_limit = Column(Integer, default=0)
mistral_calls_limit = Column(Integer, default=0) # Legacy per-provider limits (kept for backwards compatibility and analytics)
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited (deprecated, use ai_text_generation_calls_limit)
openai_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
anthropic_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
mistral_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
# Other API Call Limits (non-LLM)
tavily_calls_limit = Column(Integer, default=0) tavily_calls_limit = Column(Integer, default=0)
serper_calls_limit = Column(Integer, default=0) serper_calls_limit = Column(Integer, default=0)
metaphor_calls_limit = Column(Integer, default=0) metaphor_calls_limit = Column(Integer, default=0)
firecrawl_calls_limit = Column(Integer, default=0) firecrawl_calls_limit = Column(Integer, default=0)
stability_calls_limit = Column(Integer, default=0) stability_calls_limit = Column(Integer, default=0) # Image generation
# Token Limits (for LLM providers) # Token Limits (for LLM providers)
gemini_tokens_limit = Column(Integer, default=0) gemini_tokens_limit = Column(Integer, default=0)

View File

@@ -63,6 +63,9 @@ pytest-asyncio>=0.21.0
pydantic>=2.5.2,<3.0.0 pydantic>=2.5.2,<3.0.0
typing-extensions>=4.8.0 typing-extensions>=4.8.0
# Task scheduling
apscheduler>=3.10.0
# Optional dependencies (for enhanced features) # Optional dependencies (for enhanced features)
redis>=5.0.0 redis>=5.0.0
schedule>=1.2.0 schedule>=1.2.0

View File

@@ -0,0 +1,146 @@
"""
Migration Script: Add ai_text_generation_calls_limit column to subscription_plans table.
This adds the unified AI text generation limit column that applies to all LLM providers
(gemini, openai, anthropic, mistral) instead of per-provider limits.
"""
import sys
import os
from pathlib import Path
from datetime import datetime, timezone
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker
from loguru import logger
from models.subscription_models import SubscriptionPlan, SubscriptionTier
from services.database import DATABASE_URL
def add_ai_text_generation_limit_column():
"""Add ai_text_generation_calls_limit column to subscription_plans table."""
try:
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# Check if column already exists
inspector = inspect(engine)
columns = [col['name'] for col in inspector.get_columns('subscription_plans')]
if 'ai_text_generation_calls_limit' in columns:
logger.info("✅ Column 'ai_text_generation_calls_limit' already exists. Skipping migration.")
return True
logger.info("📋 Adding 'ai_text_generation_calls_limit' column to subscription_plans table...")
# Add the column (SQLite compatible)
alter_query = text("""
ALTER TABLE subscription_plans
ADD COLUMN ai_text_generation_calls_limit INTEGER DEFAULT 0
""")
db.execute(alter_query)
db.commit()
logger.info("✅ Column added successfully!")
# Update existing plans with unified limits based on their current limits
logger.info("\n🔄 Updating existing subscription plans with unified limits...")
plans = db.query(SubscriptionPlan).all()
updated_count = 0
for plan in plans:
# Use the first non-zero LLM provider limit as the unified limit
# Or use gemini_calls_limit as default
unified_limit = (
plan.ai_text_generation_calls_limit or
plan.gemini_calls_limit or
plan.openai_calls_limit or
plan.anthropic_calls_limit or
plan.mistral_calls_limit or
0
)
# For Basic plan, ensure it's set to 10 (from our recent update)
if plan.tier == SubscriptionTier.BASIC:
unified_limit = 10
if plan.ai_text_generation_calls_limit != unified_limit:
plan.ai_text_generation_calls_limit = unified_limit
plan.updated_at = datetime.now(timezone.utc)
updated_count += 1
logger.info(f" ✅ Updated {plan.name} ({plan.tier.value}): ai_text_generation_calls_limit = {unified_limit}")
else:
logger.info(f" {plan.name} ({plan.tier.value}): already set to {unified_limit}")
if updated_count > 0:
db.commit()
logger.info(f"\n✅ Updated {updated_count} subscription plan(s)")
else:
logger.info("\n No plans needed updating")
# Display summary
logger.info("\n" + "="*60)
logger.info("MIGRATION SUMMARY")
logger.info("="*60)
all_plans = db.query(SubscriptionPlan).all()
for plan in all_plans:
logger.info(f"\n{plan.name} ({plan.tier.value}):")
logger.info(f" Unified AI Text Gen Limit: {plan.ai_text_generation_calls_limit if plan.ai_text_generation_calls_limit else 'Not set'}")
logger.info(f" Legacy Limits: gemini={plan.gemini_calls_limit}, mistral={plan.mistral_calls_limit}")
logger.info("\n" + "="*60)
logger.info("✅ Migration completed successfully!")
return True
except Exception as e:
db.rollback()
logger.error(f"❌ Error during migration: {e}")
import traceback
logger.error(traceback.format_exc())
raise
finally:
db.close()
except Exception as e:
logger.error(f"❌ Failed to connect to database: {e}")
import traceback
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
logger.info("🚀 Starting ai_text_generation_calls_limit column migration...")
logger.info("="*60)
logger.info("This will add the unified AI text generation limit column")
logger.info("and update existing plans with appropriate values.")
logger.info("="*60)
try:
success = add_ai_text_generation_limit_column()
if success:
logger.info("\n✅ Script completed successfully!")
sys.exit(0)
else:
logger.error("\n❌ Script failed!")
sys.exit(1)
except KeyboardInterrupt:
logger.info("\n⚠️ Script cancelled by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Unexpected error: {e}")
sys.exit(1)

View File

@@ -0,0 +1,210 @@
"""
Standalone script to cap usage counters at new Basic plan limits.
This preserves historical usage data but caps it at the new limits so users
can continue making new calls within their limits.
"""
import sys
import os
from pathlib import Path
from datetime import datetime, timezone
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from loguru import logger
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
from services.database import DATABASE_URL
from services.subscription import PricingService
def cap_basic_plan_usage():
"""Cap usage counters at new Basic plan limits."""
try:
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# Find Basic plan
basic_plan = db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.BASIC
).first()
if not basic_plan:
logger.error("❌ Basic plan not found in database!")
return False
# New limits
new_call_limit = basic_plan.gemini_calls_limit # Should be 10
new_token_limit = basic_plan.gemini_tokens_limit # Should be 2000
new_image_limit = basic_plan.stability_calls_limit # Should be 5
logger.info(f"📋 Basic Plan Limits:")
logger.info(f" Calls: {new_call_limit}")
logger.info(f" Tokens: {new_token_limit}")
logger.info(f" Images: {new_image_limit}")
# Get all Basic plan users
user_subscriptions = db.query(UserSubscription).filter(
UserSubscription.plan_id == basic_plan.id,
UserSubscription.is_active == True
).all()
logger.info(f"\n👥 Found {len(user_subscriptions)} Basic plan user(s)")
pricing_service = PricingService(db)
capped_count = 0
for sub in user_subscriptions:
try:
# Get current billing period for this user
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
# Find usage summary for current period
usage_summary = db.query(UsageSummary).filter(
UsageSummary.user_id == sub.user_id,
UsageSummary.billing_period == current_period
).first()
if usage_summary:
# Store old values for logging
old_gemini = usage_summary.gemini_calls or 0
old_mistral = usage_summary.mistral_calls or 0
old_openai = usage_summary.openai_calls or 0
old_anthropic = usage_summary.anthropic_calls or 0
old_tokens = max(
usage_summary.gemini_tokens or 0,
usage_summary.openai_tokens or 0,
usage_summary.anthropic_tokens or 0,
usage_summary.mistral_tokens or 0
)
old_images = usage_summary.stability_calls or 0
# Check if capping is needed
needs_cap = (
old_gemini > new_call_limit or
old_mistral > new_call_limit or
old_openai > new_call_limit or
old_anthropic > new_call_limit or
old_images > new_image_limit or
old_tokens > new_token_limit
)
if needs_cap:
# Cap LLM provider counters at new limits
usage_summary.gemini_calls = min(old_gemini, new_call_limit)
usage_summary.mistral_calls = min(old_mistral, new_call_limit)
usage_summary.openai_calls = min(old_openai, new_call_limit)
usage_summary.anthropic_calls = min(old_anthropic, new_call_limit)
# Cap token counters at new limits
usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit)
usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit)
usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit)
usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit)
# Cap image counter at new limit
usage_summary.stability_calls = min(old_images, new_image_limit)
# Recalculate totals based on capped values
total_capped_calls = (
usage_summary.gemini_calls +
usage_summary.mistral_calls +
usage_summary.openai_calls +
usage_summary.anthropic_calls +
usage_summary.stability_calls
)
total_capped_tokens = (
usage_summary.gemini_tokens +
usage_summary.mistral_tokens +
usage_summary.openai_tokens +
usage_summary.anthropic_tokens
)
usage_summary.total_calls = total_capped_calls
usage_summary.total_tokens = total_capped_tokens
# Reset status to active to allow new calls
usage_summary.usage_status = UsageStatus.ACTIVE
usage_summary.updated_at = datetime.now(timezone.utc)
db.commit()
capped_count += 1
logger.info(f"\n✅ Capped usage for user {sub.user_id} (period {current_period}):")
logger.info(f" Gemini Calls: {old_gemini}{usage_summary.gemini_calls} (limit: {new_call_limit})")
logger.info(f" Mistral Calls: {old_mistral}{usage_summary.mistral_calls} (limit: {new_call_limit})")
logger.info(f" OpenAI Calls: {old_openai}{usage_summary.openai_calls} (limit: {new_call_limit})")
logger.info(f" Anthropic Calls: {old_anthropic}{usage_summary.anthropic_calls} (limit: {new_call_limit})")
logger.info(f" Tokens: {old_tokens}{max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})")
logger.info(f" Images: {old_images}{usage_summary.stability_calls} (limit: {new_image_limit})")
else:
logger.info(f" User {sub.user_id} usage is within limits - no capping needed")
else:
logger.info(f" No usage summary found for user {sub.user_id} (period {current_period})")
except Exception as cap_error:
logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}")
import traceback
logger.error(traceback.format_exc())
db.rollback()
if capped_count > 0:
logger.info(f"\n✅ Successfully capped usage for {capped_count} user(s)")
logger.info(" Historical usage preserved, but capped at new limits")
logger.info(" Users can now make new calls within their limits")
else:
logger.info("\n No usage counters needed capping")
logger.info("\n" + "="*60)
logger.info("CAPPING COMPLETE")
logger.info("="*60)
return True
except Exception as e:
db.rollback()
logger.error(f"❌ Error capping usage: {e}")
import traceback
logger.error(traceback.format_exc())
raise
finally:
db.close()
except Exception as e:
logger.error(f"❌ Failed to connect to database: {e}")
import traceback
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
logger.info("🚀 Starting Basic plan usage capping...")
logger.info("="*60)
logger.info("This will cap usage counters at new Basic plan limits")
logger.info("while preserving historical usage data.")
logger.info("="*60)
try:
success = cap_basic_plan_usage()
if success:
logger.info("\n✅ Script completed successfully!")
sys.exit(0)
else:
logger.error("\n❌ Script failed!")
sys.exit(1)
except KeyboardInterrupt:
logger.info("\n⚠️ Script cancelled by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Unexpected error: {e}")
sys.exit(1)

View File

@@ -0,0 +1,168 @@
"""
Quick script to reset usage counters for Basic plan users.
This fixes the issue where plan limits were updated but old usage data remained.
Resets all usage counters (calls, tokens, images) to 0 for the current billing period.
"""
import sys
import os
from pathlib import Path
from datetime import datetime, timezone
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from loguru import logger
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
from services.database import DATABASE_URL
from services.subscription import PricingService
def reset_basic_plan_usage():
"""Reset usage counters for all Basic plan users."""
try:
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# Find Basic plan
basic_plan = db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.BASIC
).first()
if not basic_plan:
logger.error("❌ Basic plan not found in database!")
return False
# Get all Basic plan users
user_subscriptions = db.query(UserSubscription).filter(
UserSubscription.plan_id == basic_plan.id,
UserSubscription.is_active == True
).all()
logger.info(f"Found {len(user_subscriptions)} Basic plan user(s)")
pricing_service = PricingService(db)
reset_count = 0
for sub in user_subscriptions:
try:
# Get current billing period for this user
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
# Find usage summary for current period
usage_summary = db.query(UsageSummary).filter(
UsageSummary.user_id == sub.user_id,
UsageSummary.billing_period == current_period
).first()
if usage_summary:
# Store old values for logging
old_gemini = usage_summary.gemini_calls or 0
old_mistral = usage_summary.mistral_calls or 0
old_tokens = (usage_summary.mistral_tokens or 0) + (usage_summary.gemini_tokens or 0)
old_images = usage_summary.stability_calls or 0
old_total_calls = usage_summary.total_calls or 0
old_total_tokens = usage_summary.total_tokens or 0
# Reset all LLM provider counters
usage_summary.gemini_calls = 0
usage_summary.openai_calls = 0
usage_summary.anthropic_calls = 0
usage_summary.mistral_calls = 0
# Reset all token counters
usage_summary.gemini_tokens = 0
usage_summary.openai_tokens = 0
usage_summary.anthropic_tokens = 0
usage_summary.mistral_tokens = 0
# Reset image counter
usage_summary.stability_calls = 0
# Reset totals
usage_summary.total_calls = 0
usage_summary.total_tokens = 0
usage_summary.total_cost = 0.0
# Reset status to active
usage_summary.usage_status = UsageStatus.ACTIVE
usage_summary.updated_at = datetime.now(timezone.utc)
db.commit()
reset_count += 1
logger.info(f"\n✅ Reset usage for user {sub.user_id} (period {current_period}):")
logger.info(f" Calls: {old_gemini + old_mistral} (gemini: {old_gemini}, mistral: {old_mistral}) → 0")
logger.info(f" Tokens: {old_tokens} → 0")
logger.info(f" Images: {old_images} → 0")
logger.info(f" Total Calls: {old_total_calls} → 0")
logger.info(f" Total Tokens: {old_total_tokens} → 0")
else:
logger.info(f" No usage summary found for user {sub.user_id} (period {current_period}) - nothing to reset")
except Exception as reset_error:
logger.error(f" ❌ Error resetting usage for user {sub.user_id}: {reset_error}")
import traceback
logger.error(traceback.format_exc())
db.rollback()
if reset_count > 0:
logger.info(f"\n✅ Successfully reset usage counters for {reset_count} user(s)")
else:
logger.info("\n No usage counters to reset")
logger.info("\n" + "="*60)
logger.info("RESET COMPLETE")
logger.info("="*60)
logger.info("\n💡 Usage counters have been reset. Users can now use their new limits.")
logger.info(" Next API call will start counting from 0.")
return True
except Exception as e:
db.rollback()
logger.error(f"❌ Error resetting usage: {e}")
import traceback
logger.error(traceback.format_exc())
raise
finally:
db.close()
except Exception as e:
logger.error(f"❌ Failed to connect to database: {e}")
import traceback
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
logger.info("🚀 Starting Basic plan usage counter reset...")
logger.info("="*60)
logger.info("This will reset all usage counters (calls, tokens, images) to 0")
logger.info("for all Basic plan users in their current billing period.")
logger.info("="*60)
try:
success = reset_basic_plan_usage()
if success:
logger.info("\n✅ Script completed successfully!")
sys.exit(0)
else:
logger.error("\n❌ Script failed!")
sys.exit(1)
except KeyboardInterrupt:
logger.info("\n⚠️ Script cancelled by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Unexpected error: {e}")
sys.exit(1)

View File

@@ -0,0 +1,279 @@
"""
Script to update Basic plan subscription limits for testing rate limits and renewal flows.
Updates:
- LLM Calls (all providers): 10 calls (was 500-1000)
- LLM Tokens (all providers): 2000 tokens (was 200k-1M)
- Images: 5 images (was 50)
This script updates the SubscriptionPlan table, which automatically applies to all users
who have a Basic plan subscription via the plan_id foreign key.
"""
import sys
import os
from pathlib import Path
from datetime import datetime, timezone
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from loguru import logger
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageStatus
from services.database import DATABASE_URL
def update_basic_plan_limits():
"""Update Basic plan limits for testing rate limits and renewal."""
try:
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# Find Basic plan
basic_plan = db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.BASIC
).first()
if not basic_plan:
logger.error("❌ Basic plan not found in database!")
return False
# Store old values for logging
old_limits = {
'gemini_calls': basic_plan.gemini_calls_limit,
'mistral_calls': basic_plan.mistral_calls_limit,
'gemini_tokens': basic_plan.gemini_tokens_limit,
'mistral_tokens': basic_plan.mistral_tokens_limit,
'stability_calls': basic_plan.stability_calls_limit,
}
logger.info(f"📋 Current Basic plan limits:")
logger.info(f" Gemini Calls: {old_limits['gemini_calls']}")
logger.info(f" Mistral Calls: {old_limits['mistral_calls']}")
logger.info(f" Gemini Tokens: {old_limits['gemini_tokens']}")
logger.info(f" Mistral Tokens: {old_limits['mistral_tokens']}")
logger.info(f" Images (Stability): {old_limits['stability_calls']}")
# Update unified AI text generation limit to 10
basic_plan.ai_text_generation_calls_limit = 10
# Legacy per-provider limits (kept for backwards compatibility, but not used for enforcement)
basic_plan.gemini_calls_limit = 1000
basic_plan.openai_calls_limit = 500
basic_plan.anthropic_calls_limit = 200
basic_plan.mistral_calls_limit = 500
# Update all LLM provider token limits to 2000
basic_plan.gemini_tokens_limit = 2000
basic_plan.openai_tokens_limit = 2000
basic_plan.anthropic_tokens_limit = 2000
basic_plan.mistral_tokens_limit = 2000
# Update image generation limit to 5
basic_plan.stability_calls_limit = 5
# Update timestamp
basic_plan.updated_at = datetime.now(timezone.utc)
logger.info("\n📝 New Basic plan limits:")
logger.info(f" LLM Calls (all providers): 10")
logger.info(f" LLM Tokens (all providers): 2000")
logger.info(f" Images: 5")
# Count and get affected users
user_subscriptions = db.query(UserSubscription).filter(
UserSubscription.plan_id == basic_plan.id,
UserSubscription.is_active == True
).all()
affected_users = len(user_subscriptions)
logger.info(f"\n👥 Users affected: {affected_users}")
if affected_users > 0:
logger.info("\n📋 Affected user IDs:")
for sub in user_subscriptions:
logger.info(f" - {sub.user_id}")
else:
logger.info(" (No active Basic plan subscriptions found)")
# Commit plan limit changes first
db.commit()
logger.info("\n✅ Basic plan limits updated successfully!")
# Cap usage at new limits for all affected users (preserve historical data, but cap enforcement)
logger.info("\n🔄 Capping usage counters at new limits for Basic plan users...")
logger.info(" (Historical usage preserved, but capped to allow new calls within limits)")
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing_service = PricingService(db)
capped_count = 0
# New limits - use unified AI text generation limit if available
new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit
new_token_limit = basic_plan.gemini_tokens_limit # 2000
new_image_limit = basic_plan.stability_calls_limit # 5
for sub in user_subscriptions:
try:
# Get current billing period for this user
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
# Find usage summary for current period
usage_summary = db.query(UsageSummary).filter(
UsageSummary.user_id == sub.user_id,
UsageSummary.billing_period == current_period
).first()
if usage_summary:
# Store old values for logging
old_gemini = usage_summary.gemini_calls or 0
old_mistral = usage_summary.mistral_calls or 0
old_openai = usage_summary.openai_calls or 0
old_anthropic = usage_summary.anthropic_calls or 0
old_tokens = max(
usage_summary.gemini_tokens or 0,
usage_summary.openai_tokens or 0,
usage_summary.anthropic_tokens or 0,
usage_summary.mistral_tokens or 0
)
old_images = usage_summary.stability_calls or 0
# Cap LLM provider counters at new limits (don't reset, just cap)
# This allows historical data to remain but prevents blocking from old usage
usage_summary.gemini_calls = min(old_gemini, new_call_limit)
usage_summary.mistral_calls = min(old_mistral, new_call_limit)
usage_summary.openai_calls = min(old_openai, new_call_limit)
usage_summary.anthropic_calls = min(old_anthropic, new_call_limit)
# Cap token counters at new limits
usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit)
usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit)
usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit)
usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit)
# Cap image counter at new limit
usage_summary.stability_calls = min(old_images, new_image_limit)
# Update totals based on capped values (approximate)
# Recalculate total_calls and total_tokens based on capped provider values
total_capped_calls = (
usage_summary.gemini_calls +
usage_summary.mistral_calls +
usage_summary.openai_calls +
usage_summary.anthropic_calls +
usage_summary.stability_calls
)
total_capped_tokens = (
usage_summary.gemini_tokens +
usage_summary.mistral_tokens +
usage_summary.openai_tokens +
usage_summary.anthropic_tokens
)
usage_summary.total_calls = total_capped_calls
usage_summary.total_tokens = total_capped_tokens
# Reset status to active to allow new calls
usage_summary.usage_status = UsageStatus.ACTIVE
usage_summary.updated_at = datetime.now(timezone.utc)
db.commit()
capped_count += 1
logger.info(f" ✅ Capped usage for user {sub.user_id}:")
logger.info(f" Gemini Calls: {old_gemini}{usage_summary.gemini_calls} (limit: {new_call_limit})")
logger.info(f" Mistral Calls: {old_mistral}{usage_summary.mistral_calls} (limit: {new_call_limit})")
logger.info(f" Tokens: {old_tokens}{max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})")
logger.info(f" Images: {old_images}{usage_summary.stability_calls} (limit: {new_image_limit})")
else:
logger.info(f" No usage summary found for user {sub.user_id} (period {current_period})")
except Exception as cap_error:
logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}")
import traceback
logger.error(traceback.format_exc())
db.rollback()
if capped_count > 0:
logger.info(f"\n✅ Capped usage counters for {capped_count} user(s)")
logger.info(" Historical usage preserved, but capped at new limits to allow new calls")
else:
logger.info("\n No usage counters to cap")
# Note about cache clearing
logger.info("\n🔄 Cache Information:")
logger.info(" The subscription limits cache is per-instance and will refresh on next request.")
logger.info(" No manual cache clearing needed - limits will be read from database on next check.")
# Display final summary
logger.info("\n" + "="*60)
logger.info("BASIC PLAN UPDATE SUMMARY")
logger.info("="*60)
logger.info(f"\nPlan: {basic_plan.name} ({basic_plan.tier.value})")
logger.info(f"Price: ${basic_plan.price_monthly}/mo, ${basic_plan.price_yearly}/yr")
logger.info(f"\nUpdated Limits:")
logger.info(f" LLM Calls (gemini/openai/anthropic/mistral): {basic_plan.gemini_calls_limit}")
logger.info(f" LLM Tokens (gemini/openai/anthropic/mistral): {basic_plan.gemini_tokens_limit}")
logger.info(f" Images (stability): {basic_plan.stability_calls_limit}")
logger.info(f"\nUsers Affected: {affected_users}")
logger.info("\n" + "="*60)
logger.info("\n💡 Note: These limits apply immediately to all Basic plan users.")
logger.info(" Historical usage has been preserved but capped at new limits.")
logger.info(" Users can continue making new calls up to the new limits.")
logger.info(" Users will hit rate limits faster for testing purposes.")
return True
except Exception as e:
db.rollback()
logger.error(f"❌ Error updating Basic plan: {e}")
import traceback
logger.error(traceback.format_exc())
raise
finally:
db.close()
except Exception as e:
logger.error(f"❌ Failed to connect to database: {e}")
import traceback
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
logger.info("🚀 Starting Basic plan limits update...")
logger.info("="*60)
logger.info("This will update Basic plan limits for testing rate limits:")
logger.info(" - LLM Calls: 10 (all providers)")
logger.info(" - LLM Tokens: 2000 (all providers)")
logger.info(" - Images: 5")
logger.info("="*60)
# Ask for confirmation in non-interactive mode, proceed directly
# In interactive mode, you can add: input("\nPress Enter to continue or Ctrl+C to cancel...")
try:
success = update_basic_plan_limits()
if success:
logger.info("\n✅ Script completed successfully!")
sys.exit(0)
else:
logger.error("\n❌ Script failed!")
sys.exit(1)
except KeyboardInterrupt:
logger.info("\n⚠️ Script cancelled by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Unexpected error: {e}")
sys.exit(1)

View File

@@ -295,3 +295,55 @@ class ActiveStrategyService:
'cached_users': list(self._memory_cache.keys()), 'cached_users': list(self._memory_cache.keys()),
'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()} 'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()}
} }
def count_active_strategies_with_tasks(self) -> int:
"""
Count how many active strategies have monitoring tasks.
This is used for intelligent scheduling - if there are no active strategies
with tasks, the scheduler can check less frequently.
Returns:
Number of active strategies that have at least one active monitoring task
"""
try:
if not self.db_session:
logger.warning("Database session not available")
return 0
from sqlalchemy import func, and_
from models.monitoring_models import MonitoringTask
# Count distinct strategies that:
# 1. Have activation status = 'active'
# 2. Have at least one active monitoring task
count = self.db_session.query(
func.count(func.distinct(EnhancedContentStrategy.id))
).join(
StrategyActivationStatus,
EnhancedContentStrategy.id == StrategyActivationStatus.strategy_id
).join(
MonitoringTask,
EnhancedContentStrategy.id == MonitoringTask.strategy_id
).filter(
and_(
StrategyActivationStatus.status == 'active',
MonitoringTask.status == 'active'
)
).scalar()
return count or 0
except Exception as e:
logger.error(f"Error counting active strategies with tasks: {e}")
# On error, assume there are active strategies (safer to check more frequently)
return 1
def has_active_strategies_with_tasks(self) -> bool:
"""
Check if there are any active strategies with monitoring tasks.
Returns:
True if there are active strategies with tasks, False otherwise
"""
return self.count_active_strategies_with_tasks() > 0

View File

@@ -96,13 +96,13 @@ class BlogWriterService:
self.blog_rewriter = BlogRewriter(self.task_manager) self.blog_rewriter = BlogRewriter(self.task_manager)
# Research Methods # Research Methods
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse: async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
"""Conduct comprehensive research using Google Search grounding.""" """Conduct comprehensive research using Google Search grounding."""
return await self.research_service.research(request) return await self.research_service.research(request, user_id)
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse: async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
"""Conduct research with real-time progress updates.""" """Conduct research with real-time progress updates."""
return await self.research_service.research_with_progress(request, task_id) return await self.research_service.research_with_progress(request, task_id, user_id)
# Outline Methods # Outline Methods
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
@@ -204,11 +204,14 @@ class BlogWriterService:
except Exception as e: except Exception as e:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
async def seo_analyze(self, request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse: async def seo_analyze(self, request: BlogSEOAnalyzeRequest, user_id: str = None) -> BlogSEOAnalyzeResponse:
"""Analyze content for SEO optimization using comprehensive blog-specific analyzer.""" """Analyze content for SEO optimization using comprehensive blog-specific analyzer."""
try: try:
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
content = request.content or "" content = request.content or ""
target_keywords = request.keywords or [] target_keywords = request.keywords or []
@@ -231,7 +234,7 @@ class BlogWriterService:
# Use our comprehensive SEO analyzer # Use our comprehensive SEO analyzer
analyzer = BlogContentSEOAnalyzer() analyzer = BlogContentSEOAnalyzer()
analysis_results = await analyzer.analyze_blog_content(content, research_data) analysis_results = await analyzer.analyze_blog_content(content, research_data, user_id=user_id)
# Convert results to response format # Convert results to response format
recommendations = analysis_results.get('actionable_recommendations', []) recommendations = analysis_results.get('actionable_recommendations', [])
@@ -267,11 +270,14 @@ class BlogWriterService:
recommendations=[f"SEO analysis failed: {str(e)}"] recommendations=[f"SEO analysis failed: {str(e)}"]
) )
async def seo_metadata(self, request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse: async def seo_metadata(self, request: BlogSEOMetadataRequest, user_id: str = None) -> BlogSEOMetadataResponse:
"""Generate comprehensive SEO metadata for content.""" """Generate comprehensive SEO metadata for content."""
try: try:
from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
# Initialize metadata generator # Initialize metadata generator
metadata_generator = BlogSEOMetadataGenerator() metadata_generator = BlogSEOMetadataGenerator()
@@ -285,7 +291,8 @@ class BlogWriterService:
blog_title=request.title or "Untitled Blog Post", blog_title=request.title or "Untitled Blog Post",
research_data=request.research_data or {}, research_data=request.research_data or {},
outline=outline, outline=outline,
seo_analysis=seo_analysis seo_analysis=seo_analysis,
user_id=user_id
) )
# Convert to BlogSEOMetadataResponse format # Convert to BlogSEOMetadataResponse format

View File

@@ -163,13 +163,18 @@ class BlogWriterLogger:
context: Optional[Dict[str, Any]] = None context: Optional[Dict[str, Any]] = None
): ):
"""Log error with full context.""" """Log error with full context."""
# Safely format error message to avoid KeyError on format strings in error messages
error_str = str(error)
# Replace any curly braces that might be in the error message to avoid format string issues
safe_error_str = error_str.replace('{', '{{').replace('}', '}}')
logger.error( logger.error(
f"Error in {operation}: {str(error)}", f"Error in {operation}: {safe_error_str}",
extra={ extra={
"event_type": "error", "event_type": "error",
"operation": operation, "operation": operation,
"error_type": type(error).__name__, "error_type": type(error).__name__,
"error_message": str(error), "error_message": error_str, # Keep original in extra, but use safe version in format string
"context": context or {} "context": context or {}
}, },
exc_info=True exc_info=True

View File

@@ -11,7 +11,7 @@ from loguru import logger
class CompetitorAnalyzer: class CompetitorAnalyzer:
"""Analyzes competitors and market intelligence from research content.""" """Analyzes competitors and market intelligence from research content."""
def analyze(self, content: str) -> Dict[str, Any]: def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]:
"""Parse comprehensive competitor analysis from the research content using AI.""" """Parse comprehensive competitor analysis from the research content using AI."""
competitor_prompt = f""" competitor_prompt = f"""
Analyze the following research content and extract competitor insights: Analyze the following research content and extract competitor insights:
@@ -57,7 +57,8 @@ class CompetitorAnalyzer:
competitor_analysis = llm_text_gen( competitor_analysis = llm_text_gen(
prompt=competitor_prompt, prompt=competitor_prompt,
json_struct=competitor_schema json_struct=competitor_schema,
user_id=user_id
) )
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis: if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:

View File

@@ -11,7 +11,7 @@ from loguru import logger
class ContentAngleGenerator: class ContentAngleGenerator:
"""Generates strategic content angles from research content.""" """Generates strategic content angles from research content."""
def generate(self, content: str, topic: str, industry: str) -> List[str]: def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]:
"""Parse strategic content angles from the research content using AI.""" """Parse strategic content angles from the research content using AI."""
angles_prompt = f""" angles_prompt = f"""
Analyze the following research content and create strategic content angles for: {topic} in {industry} Analyze the following research content and create strategic content angles for: {topic} in {industry}
@@ -65,7 +65,8 @@ class ContentAngleGenerator:
angles_result = llm_text_gen( angles_result = llm_text_gen(
prompt=angles_prompt, prompt=angles_prompt,
json_struct=angles_schema json_struct=angles_schema,
user_id=user_id
) )
if isinstance(angles_result, dict) and 'content_angles' in angles_result: if isinstance(angles_result, dict) and 'content_angles' in angles_result:

View File

@@ -11,7 +11,7 @@ from loguru import logger
class KeywordAnalyzer: class KeywordAnalyzer:
"""Analyzes keywords from research content using AI-powered extraction.""" """Analyzes keywords from research content using AI-powered extraction."""
def analyze(self, content: str, original_keywords: List[str]) -> Dict[str, Any]: def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]:
"""Parse comprehensive keyword analysis from the research content using AI.""" """Parse comprehensive keyword analysis from the research content using AI."""
# Use AI to extract and analyze keywords from the rich research content # Use AI to extract and analyze keywords from the rich research content
keyword_prompt = f""" keyword_prompt = f"""
@@ -64,7 +64,8 @@ class KeywordAnalyzer:
keyword_analysis = llm_text_gen( keyword_analysis = llm_text_gen(
prompt=keyword_prompt, prompt=keyword_prompt,
json_struct=keyword_schema json_struct=keyword_schema,
user_id=user_id
) )
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis: if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:

View File

@@ -4,7 +4,8 @@ Research Service - Core research functionality for AI Blog Writer.
Handles Google Search grounding, caching, and research orchestration. Handles Google Search grounding, caching, and research orchestration.
""" """
from typing import Dict, Any, List from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger from loguru import logger
from models.blog_models import ( from models.blog_models import (
@@ -17,6 +18,7 @@ from models.blog_models import (
Citation, Citation,
) )
from services.blog_writer.logger_config import blog_writer_logger, log_function_call from services.blog_writer.logger_config import blog_writer_logger, log_function_call
from fastapi import HTTPException
from .keyword_analyzer import KeywordAnalyzer from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer from .competitor_analyzer import CompetitorAnalyzer
@@ -34,7 +36,7 @@ class ResearchService:
self.data_filter = ResearchDataFilter() self.data_filter = ResearchDataFilter()
@log_function_call("research_operation") @log_function_call("research_operation")
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse: async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
""" """
Stage 1: Research & Strategy (AI Orchestration) Stage 1: Research & Strategy (AI Orchestration)
Uses ONLY Gemini's native Google Search grounding - ONE API call for everything. Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
@@ -71,6 +73,10 @@ class ResearchService:
blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True) blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True)
return BlogResearchResponse(**cached_result) return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider)
if not user_id:
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call # Cache miss - proceed with API call
logger.info(f"Cache miss - making API call for keywords: {request.keywords}") logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research") blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
@@ -96,12 +102,15 @@ class ResearchService:
""" """
# Single Gemini call with native Google Search grounding - no fallbacks # Single Gemini call with native Google Search grounding - no fallbacks
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
import time import time
api_start_time = time.time() api_start_time = time.time()
gemini_result = await gemini.generate_grounded_content( gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt, prompt=research_prompt,
content_type="research", content_type="research",
max_tokens=2000 max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
) )
api_duration_ms = (time.time() - api_start_time) * 1000 api_duration_ms = (time.time() - api_start_time) * 1000
@@ -126,9 +135,9 @@ class ResearchService:
# Parse the comprehensive response for different analysis components # Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "") content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords) keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries") logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
@@ -179,6 +188,9 @@ class ResearchService:
return filtered_response return filtered_response
except HTTPException:
# Re-raise HTTPException (subscription errors) - let task manager handle it
raise
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
logger.error(f"Research failed: {error_message}") logger.error(f"Research failed: {error_message}")
@@ -244,7 +256,7 @@ class ResearchService:
) )
@log_function_call("research_with_progress") @log_function_call("research_with_progress")
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse: async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
""" """
Research method with progress updates for real-time feedback. Research method with progress updates for real-time feedback.
""" """
@@ -281,6 +293,11 @@ class ResearchService:
logger.info(f"Returning cached research result for keywords: {request.keywords}") logger.info(f"Returning cached research result for keywords: {request.keywords}")
return BlogResearchResponse(**cached_result) return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider)
if not user_id:
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call # Cache miss - proceed with API call
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...") await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
logger.info(f"Cache miss - making API call for keywords: {request.keywords}") logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
@@ -307,11 +324,20 @@ class ResearchService:
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...") await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
# Single Gemini call with native Google Search grounding - no fallbacks # Single Gemini call with native Google Search grounding - no fallbacks
gemini_result = await gemini.generate_grounded_content( # Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
prompt=research_prompt, try:
content_type="research", gemini_result = await gemini.generate_grounded_content(
max_tokens=2000 prompt=research_prompt,
) content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
except HTTPException as http_error:
# Re-raise HTTPException so it can be properly handled by task manager
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise # Re-raise HTTPException to preserve status code and error details
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...") await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
# Extract sources from grounding metadata # Extract sources from grounding metadata
@@ -327,9 +353,9 @@ class ResearchService:
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...") await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
# Parse the comprehensive response for different analysis components # Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "") content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords) keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
await task_manager.update_progress(task_id, "💾 Caching results for future use...") await task_manager.update_progress(task_id, "💾 Caching results for future use...")
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries") logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
@@ -373,6 +399,9 @@ class ResearchService:
return filtered_response return filtered_response
except HTTPException:
# Re-raise HTTPException (subscription errors) - let task manager handle it
raise
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
logger.error(f"Research failed: {error_message}") logger.error(f"Research failed: {error_message}")

View File

@@ -34,17 +34,21 @@ class BlogContentSEOAnalyzer:
logger.info("BlogContentSEOAnalyzer initialized") logger.info("BlogContentSEOAnalyzer initialized")
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None) -> Dict[str, Any]: async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None, user_id: str = None) -> Dict[str, Any]:
""" """
Main analysis method with parallel processing Main analysis method with parallel processing
Args: Args:
blog_content: The blog content to analyze blog_content: The blog content to analyze
research_data: Research data containing keywords and other insights research_data: Research data containing keywords and other insights
blog_title: Optional blog title
user_id: Clerk user ID for subscription checking (required)
Returns: Returns:
Comprehensive SEO analysis results Comprehensive SEO analysis results
""" """
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try: try:
logger.info("Starting blog content SEO analysis") logger.info("Starting blog content SEO analysis")
@@ -58,7 +62,7 @@ class BlogContentSEOAnalyzer:
# Phase 2: Single AI analysis for structured insights # Phase 2: Single AI analysis for structured insights
logger.info("Running AI analysis") logger.info("Running AI analysis")
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results) ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results, user_id=user_id)
# Phase 3: Compile and format results # Phase 3: Compile and format results
logger.info("Compiling results") logger.info("Compiling results")
@@ -599,8 +603,10 @@ class BlogContentSEOAnalyzer:
return recommendations return recommendations
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]: async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Run single AI analysis for structured insights (provider-agnostic)""" """Run single AI analysis for structured insights (provider-agnostic)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try: try:
# Prepare context for AI analysis # Prepare context for AI analysis
context = { context = {
@@ -658,7 +664,8 @@ class BlogContentSEOAnalyzer:
ai_response = llm_text_gen( ai_response = llm_text_gen(
prompt=prompt, prompt=prompt,
json_struct=schema, json_struct=schema,
system_prompt=None system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
) )
return ai_response return ai_response

View File

@@ -28,7 +28,8 @@ class BlogSEOMetadataGenerator:
blog_title: str, blog_title: str,
research_data: Dict[str, Any], research_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None, outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate comprehensive SEO metadata using maximum 2 AI calls Generate comprehensive SEO metadata using maximum 2 AI calls
@@ -39,10 +40,13 @@ class BlogSEOMetadataGenerator:
research_data: Research data containing keywords and insights research_data: Research data containing keywords and insights
outline: Outline structure with sections and headings outline: Outline structure with sections and headings
seo_analysis: SEO analysis results from previous phase seo_analysis: SEO analysis results from previous phase
user_id: Clerk user ID for subscription checking (required)
Returns: Returns:
Comprehensive metadata including all SEO elements Comprehensive metadata including all SEO elements
""" """
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try: try:
logger.info("Starting comprehensive SEO metadata generation") logger.info("Starting comprehensive SEO metadata generation")
@@ -53,13 +57,13 @@ class BlogSEOMetadataGenerator:
# Call 1: Generate core SEO metadata (parallel with Call 2) # Call 1: Generate core SEO metadata (parallel with Call 2)
logger.info("Generating core SEO metadata") logger.info("Generating core SEO metadata")
core_metadata_task = self._generate_core_metadata( core_metadata_task = self._generate_core_metadata(
blog_content, blog_title, keywords_data, outline, seo_analysis blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
) )
# Call 2: Generate social media and structured data (parallel with Call 1) # Call 2: Generate social media and structured data (parallel with Call 1)
logger.info("Generating social media and structured data") logger.info("Generating social media and structured data")
social_metadata_task = self._generate_social_metadata( social_metadata_task = self._generate_social_metadata(
blog_content, blog_title, keywords_data, outline, seo_analysis blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
) )
# Wait for both calls to complete # Wait for both calls to complete
@@ -114,9 +118,12 @@ class BlogSEOMetadataGenerator:
blog_title: str, blog_title: str,
keywords_data: Dict[str, Any], keywords_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None, outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Generate core SEO metadata (Call 1)""" """Generate core SEO metadata (Call 1)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try: try:
# Create comprehensive prompt for core metadata # Create comprehensive prompt for core metadata
prompt = self._create_core_metadata_prompt( prompt = self._create_core_metadata_prompt(
@@ -170,7 +177,8 @@ class BlogSEOMetadataGenerator:
ai_response_raw = llm_text_gen( ai_response_raw = llm_text_gen(
prompt=prompt, prompt=prompt,
json_struct=schema, json_struct=schema,
system_prompt=None system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
) )
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing) # Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
@@ -215,9 +223,12 @@ class BlogSEOMetadataGenerator:
blog_title: str, blog_title: str,
keywords_data: Dict[str, Any], keywords_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None, outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Generate social media and structured data (Call 2)""" """Generate social media and structured data (Call 2)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try: try:
# Create comprehensive prompt for social metadata # Create comprehensive prompt for social metadata
prompt = self._create_social_metadata_prompt( prompt = self._create_social_metadata_prompt(
@@ -274,7 +285,8 @@ class BlogSEOMetadataGenerator:
ai_response_raw = llm_text_gen( ai_response_raw = llm_text_gen(
prompt=prompt, prompt=prompt,
json_struct=schema, json_struct=schema,
system_prompt=None system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
) )
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing) # Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)

View File

@@ -20,8 +20,11 @@ class BlogSEORecommendationApplier:
def __init__(self): def __init__(self):
logger.debug("Initialized BlogSEORecommendationApplier") logger.debug("Initialized BlogSEORecommendationApplier")
async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]: async def apply_recommendations(self, payload: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Apply recommendations and return updated content.""" """Apply recommendations and return updated content."""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
title = payload.get("title", "Untitled Blog") title = payload.get("title", "Untitled Blog")
sections: List[Dict[str, Any]] = payload.get("sections", []) sections: List[Dict[str, Any]] = payload.get("sections", [])
@@ -88,6 +91,7 @@ class BlogSEORecommendationApplier:
prompt, prompt,
None, None,
schema, schema,
user_id, # Pass user_id for subscription checking
) )
if not result or result.get("error"): if not result or result.get("error"):

View File

@@ -56,7 +56,9 @@ class GeminiGroundedProvider:
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2048, max_tokens: int = 2048,
urls: Optional[List[str]] = None, urls: Optional[List[str]] = None,
mode: str = "polished" mode: str = "polished",
user_id: Optional[str] = None,
validate_subsequent_operations: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate grounded content using native Google Search grounding. Generate grounded content using native Google Search grounding.
@@ -66,12 +68,49 @@ class GeminiGroundedProvider:
content_type: Type of content to generate content_type: Type of content to generate
temperature: Creativity level (0.0-1.0) temperature: Creativity level (0.0-1.0)
max_tokens: Maximum tokens in response max_tokens: Maximum tokens in response
urls: Optional list of URLs for URL Context tool
mode: Content mode ("draft" or "polished")
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
Returns: Returns:
Dictionary containing generated content and grounding metadata Dictionary containing generated content and grounding metadata
""" """
try: try:
logger.info(f"Generating grounded content for {content_type} using native Google Search") # PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
# MUST happen BEFORE any API calls - return immediately if validation fails
if validate_subsequent_operations:
if not user_id:
raise ValueError("user_id is required when validate_subsequent_operations=True")
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_research_operations
from fastapi import HTTPException
import os
db = next(get_db())
try:
pricing_service = PricingService(db)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
# Validate ALL research operations before making ANY API calls
# This prevents wasteful external API calls if subsequent LLM calls would fail
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_research_operations(
pricing_service=pricing_service,
user_id=user_id,
gpt_provider=gpt_provider
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
# Build the grounded prompt # Build the grounded prompt
grounded_prompt = self._build_grounded_prompt(prompt, content_type) grounded_prompt = self._build_grounded_prompt(prompt, content_type)

View File

@@ -40,7 +40,38 @@ def _get_provider(provider_name: str):
raise ValueError(f"Unknown image provider: {provider_name}") raise ValueError(f"Unknown image provider: {provider_name}")
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult: def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
"""Generate image with pre-flight validation.
Args:
prompt: Image generation prompt
options: Image generation options (provider, model, width, height, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Generation] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Image Generation] ✅ Pre-flight validation passed - proceeding with image generation")
opts = options or {} opts = options or {}
provider_name = _select_provider(opts.get("provider")) provider_name = _select_provider(opts.get("provider"))

View File

@@ -7,6 +7,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
import os import os
import json import json
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger from loguru import logger
from ..onboarding.api_key_manager import APIKeyManager from ..onboarding.api_key_manager import APIKeyManager
@@ -14,7 +15,7 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str: def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str:
""" """
Generate text using Language Model (LLM) based on the provided prompt. Generate text using Language Model (LLM) based on the provided prompt.
@@ -22,9 +23,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
prompt (str): The prompt to generate text from. prompt (str): The prompt to generate text from.
system_prompt (str, optional): Custom system prompt to use instead of the default one. system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses. json_struct (dict, optional): JSON schema structure for structured responses.
user_id (str): Clerk user ID for subscription checking (required).
Returns: Returns:
str: Generated text based on the prompt. str: Generated text based on the prompt.
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
""" """
try: try:
logger.info("[llm_text_gen] Starting text generation") logger.info("[llm_text_gen] Starting text generation")
@@ -93,6 +98,75 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}") logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
# Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider
provider_enum = None
# Store actual provider name for logging (e.g., "huggingface", "gemini")
actual_provider_name = None
if gpt_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
elif gpt_provider == "huggingface":
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
actual_provider_name = "huggingface" # Keep actual provider name for logs
if not provider_enum:
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
# SUBSCRIPTION CHECK - Required and strict enforcement
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import UsageSummary
db = next(get_db())
try:
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens)
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
# This prevents underestimating total token usage
input_tokens = int(len(prompt.split()) * 1.3)
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
estimated_total_tokens = input_tokens + estimated_output_tokens
# Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider_enum,
tokens_requested=estimated_total_tokens,
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
)
if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
raise RuntimeError(f"Subscription limit exceeded: {message}")
# Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# No separate log here - we'll create unified log after API call and usage tracking
finally:
db.close()
except RuntimeError:
# Re-raise subscription limit errors
raise
except Exception as sub_error:
# STRICT: Fail on subscription check errors
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Construct the system prompt if not provided # Construct the system prompt if not provided
if system_prompt is None: if system_prompt is None:
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content. system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
@@ -117,10 +191,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_instructions = system_prompt system_instructions = system_prompt
# Generate response based on provider # Generate response based on provider
response_text = None
actual_provider_used = gpt_provider
try: try:
if gpt_provider == "google": if gpt_provider == "google":
if json_struct: if json_struct:
return gemini_structured_json_response( response_text = gemini_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
temperature=temperature, temperature=temperature,
@@ -130,7 +206,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions system_prompt=system_instructions
) )
else: else:
return gemini_text_response( response_text = gemini_text_response(
prompt=prompt, prompt=prompt,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
@@ -140,7 +216,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
) )
elif gpt_provider == "huggingface": elif gpt_provider == "huggingface":
if json_struct: if json_struct:
return huggingface_structured_json_response( response_text = huggingface_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
model=model, model=model,
@@ -149,7 +225,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions system_prompt=system_instructions
) )
else: else:
return huggingface_text_response( response_text = huggingface_text_response(
prompt=prompt, prompt=prompt,
model=model, model=model,
temperature=temperature, temperature=temperature,
@@ -160,6 +236,107 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
else: else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}") logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface") raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
# TRACK USAGE after successful API call
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens # Already calculated above
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens)
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
summary.total_calls = old_total_calls + 1
summary.total_tokens = old_total_tokens + tokens_total
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error: except Exception as provider_error:
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}") logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
@@ -171,9 +348,21 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
fallback_provider = fallback_providers[0] # Only try the first available fallback_provider = fallback_providers[0] # Only try the first available
try: try:
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}") logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
actual_provider_used = fallback_provider
# Update provider enum for fallback
if fallback_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
fallback_model = "gemini-2.0-flash-lite"
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "openai/gpt-oss-120b:groq"
if fallback_provider == "google": if fallback_provider == "google":
if json_struct: if json_struct:
return gemini_structured_json_response( response_text = gemini_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
temperature=temperature, temperature=temperature,
@@ -183,7 +372,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions system_prompt=system_instructions
) )
else: else:
return gemini_text_response( response_text = gemini_text_response(
prompt=prompt, prompt=prompt,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
@@ -193,7 +382,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
) )
elif fallback_provider == "huggingface": elif fallback_provider == "huggingface":
if json_struct: if json_struct:
return huggingface_structured_json_response( response_text = huggingface_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
model="openai/gpt-oss-120b:groq", model="openai/gpt-oss-120b:groq",
@@ -202,7 +391,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions system_prompt=system_instructions
) )
else: else:
return huggingface_text_response( response_text = huggingface_text_response(
prompt=prompt, prompt=prompt,
model="openai/gpt-oss-120b:groq", model="openai/gpt-oss-120b:groq",
temperature=temperature, temperature=temperature,
@@ -210,6 +399,96 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
top_p=top_p, top_p=top_p,
system_prompt=system_instructions system_prompt=system_instructions
) )
# TRACK USAGE after successful fallback call
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
return response_text
except Exception as fallback_error: except Exception as fallback_error:
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}") logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")

View File

@@ -55,6 +55,14 @@ class MonitoringDataService:
alert_threshold=task_data.get('alertThreshold', ''), alert_threshold=task_data.get('alertThreshold', ''),
status='active' status='active'
) )
# Initialize next_execution based on frequency
from services.scheduler.utils.frequency_calculator import calculate_next_execution
task.next_execution = calculate_next_execution(
frequency=task.frequency,
base_time=datetime.utcnow()
)
self.db.add(task) self.db.add(task)
# Save activation status # Save activation status
@@ -357,3 +365,80 @@ class MonitoringDataService:
logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}") logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}")
self.db.rollback() self.db.rollback()
return False return False
def get_user_execution_logs(
self,
user_id: int,
limit: Optional[int] = 50,
offset: Optional[int] = 0,
status_filter: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Get execution logs for a specific user.
Args:
user_id: User ID to filter execution logs
limit: Maximum number of logs to return
offset: Number of logs to skip (for pagination)
status_filter: Optional status filter ('success', 'failed', 'running', 'skipped')
Returns:
List of execution log dictionaries with task details
"""
try:
logger.info(f"Getting execution logs for user {user_id}")
# Build query for execution logs filtered by user_id
query = self.db.query(TaskExecutionLog).filter(
TaskExecutionLog.user_id == user_id
)
# Apply status filter if provided
if status_filter:
query = query.filter(TaskExecutionLog.status == status_filter)
# Order by execution date (most recent first)
query = query.order_by(desc(TaskExecutionLog.execution_date))
# Apply pagination
if limit:
query = query.limit(limit)
if offset:
query = query.offset(offset)
logs = query.all()
# Convert to dictionaries with task details
logs_data = []
for log in logs:
# Get task details if available
task = self.db.query(MonitoringTask).filter(
MonitoringTask.id == log.task_id
).first()
log_data = {
"id": log.id,
"task_id": log.task_id,
"user_id": log.user_id,
"execution_date": log.execution_date.isoformat() if log.execution_date else None,
"status": log.status,
"result_data": log.result_data,
"error_message": log.error_message,
"execution_time_ms": log.execution_time_ms,
"created_at": log.created_at.isoformat() if log.created_at else None,
"task": {
"title": task.task_title if task else None,
"description": task.task_description if task else None,
"assignee": task.assignee if task else None,
"frequency": task.frequency if task else None,
"strategy_id": task.strategy_id if task else None
} if task else None
}
logs_data.append(log_data)
logger.info(f"Retrieved {len(logs_data)} execution logs for user {user_id}")
return logs_data
except Exception as e:
logger.error(f"Error getting execution logs for user {user_id}: {e}")
return []

View File

@@ -0,0 +1,59 @@
"""
Task Scheduler Package
Modular, pluggable scheduler for ALwrity tasks.
"""
from .core.scheduler import TaskScheduler
from .core.executor_interface import TaskExecutor, TaskExecutionResult
from .core.exception_handler import (
SchedulerExceptionHandler, SchedulerException, SchedulerErrorType, SchedulerErrorSeverity,
TaskExecutionError, DatabaseError, TaskLoaderError, SchedulerConfigError
)
from .executors.monitoring_task_executor import MonitoringTaskExecutor
from .utils.task_loader import load_due_monitoring_tasks
# Global scheduler instance (initialized on first access)
_scheduler_instance: TaskScheduler = None
def get_scheduler() -> TaskScheduler:
"""
Get global scheduler instance (singleton pattern).
Returns:
TaskScheduler instance
"""
global _scheduler_instance
if _scheduler_instance is None:
_scheduler_instance = TaskScheduler(
check_interval_minutes=15,
max_concurrent_executions=10
)
# Register monitoring task executor
monitoring_executor = MonitoringTaskExecutor()
_scheduler_instance.register_executor(
'monitoring_task',
monitoring_executor,
load_due_monitoring_tasks
)
return _scheduler_instance
__all__ = [
'TaskScheduler',
'TaskExecutor',
'TaskExecutionResult',
'MonitoringTaskExecutor',
'get_scheduler',
# Exception handling
'SchedulerExceptionHandler',
'SchedulerException',
'SchedulerErrorType',
'SchedulerErrorSeverity',
'TaskExecutionError',
'DatabaseError',
'TaskLoaderError',
'SchedulerConfigError'
]

View File

@@ -0,0 +1,4 @@
"""
Core scheduler components.
"""

View File

@@ -0,0 +1,395 @@
"""
Comprehensive Exception Handling and Logging for Task Scheduler
Provides robust error handling, logging, and monitoring for the scheduler system.
"""
import traceback
import sys
from datetime import datetime
from typing import Dict, Any, Optional, Union
from enum import Enum
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from utils.logger_utils import get_service_logger
logger = get_service_logger("scheduler_exception_handler")
class SchedulerErrorType(Enum):
"""Error types for scheduler system."""
DATABASE_ERROR = "database_error"
TASK_EXECUTION_ERROR = "task_execution_error"
TASK_LOADER_ERROR = "task_loader_error"
SCHEDULER_CONFIG_ERROR = "scheduler_config_error"
RETRY_ERROR = "retry_error"
CONCURRENCY_ERROR = "concurrency_error"
TIMEOUT_ERROR = "timeout_error"
class SchedulerErrorSeverity(Enum):
"""Severity levels for scheduler errors."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class SchedulerException(Exception):
"""Base exception for scheduler system errors."""
def __init__(
self,
message: str,
error_type: SchedulerErrorType,
severity: SchedulerErrorSeverity = SchedulerErrorSeverity.MEDIUM,
user_id: Optional[int] = None,
task_id: Optional[int] = None,
task_type: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
original_error: Optional[Exception] = None
):
self.message = message
self.error_type = error_type
self.severity = severity
self.user_id = user_id
self.task_id = task_id
self.task_type = task_type
self.context = context or {}
self.original_error = original_error
self.timestamp = datetime.utcnow()
# Capture stack trace if original error provided
self.stack_trace = None
if self.original_error:
try:
exc_type, exc_value, exc_traceback = sys.exc_info()
if exc_traceback:
self.stack_trace = ''.join(traceback.format_exception(
exc_type, exc_value, exc_traceback
))
else:
self.stack_trace = traceback.format_exception(
type(self.original_error),
self.original_error,
self.original_error.__traceback__
)
except Exception:
self.stack_trace = str(self.original_error)
super().__init__(message)
def to_dict(self) -> Dict[str, Any]:
"""Convert exception to dictionary for logging/storage."""
return {
"message": self.message,
"error_type": self.error_type.value,
"severity": self.severity.value,
"user_id": self.user_id,
"task_id": self.task_id,
"task_type": self.task_type,
"context": self.context,
"timestamp": self.timestamp.isoformat() if isinstance(self.timestamp, datetime) else self.timestamp,
"original_error": str(self.original_error) if self.original_error else None,
"stack_trace": self.stack_trace
}
def __str__(self):
return f"[{self.error_type.value}] {self.message}"
class DatabaseError(SchedulerException):
"""Exception raised for database-related errors."""
def __init__(
self,
message: str,
user_id: Optional[int] = None,
task_id: Optional[int] = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SchedulerErrorType.DATABASE_ERROR,
severity=SchedulerErrorSeverity.CRITICAL,
user_id=user_id,
task_id=task_id,
context=context or {},
original_error=original_error
)
class TaskExecutionError(SchedulerException):
"""Exception raised for task execution failures."""
def __init__(
self,
message: str,
user_id: Optional[int] = None,
task_id: Optional[int] = None,
task_type: Optional[str] = None,
retry_count: int = 0,
execution_time_ms: Optional[int] = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
context = context or {}
context.update({
"retry_count": retry_count,
"execution_time_ms": execution_time_ms
})
super().__init__(
message=message,
error_type=SchedulerErrorType.TASK_EXECUTION_ERROR,
severity=SchedulerErrorSeverity.HIGH,
user_id=user_id,
task_id=task_id,
task_type=task_type,
context=context,
original_error=original_error
)
class TaskLoaderError(SchedulerException):
"""Exception raised for task loading failures."""
def __init__(
self,
message: str,
task_type: Optional[str] = None,
user_id: Optional[int] = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SchedulerErrorType.TASK_LOADER_ERROR,
severity=SchedulerErrorSeverity.HIGH,
user_id=user_id,
task_type=task_type,
context=context or {},
original_error=original_error
)
class SchedulerConfigError(SchedulerException):
"""Exception raised for scheduler configuration errors."""
def __init__(
self,
message: str,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SchedulerErrorType.SCHEDULER_CONFIG_ERROR,
severity=SchedulerErrorSeverity.CRITICAL,
context=context or {},
original_error=original_error
)
class SchedulerExceptionHandler:
"""Comprehensive exception handler for the scheduler system."""
def __init__(self, db: Session = None):
self.db = db
self.logger = logger
def handle_exception(
self,
error: Union[Exception, SchedulerException],
context: Dict[str, Any] = None,
log_level: str = "error"
) -> Dict[str, Any]:
"""Handle and log scheduler exceptions."""
context = context or {}
# Convert regular exceptions to SchedulerException
if not isinstance(error, SchedulerException):
error = SchedulerException(
message=str(error),
error_type=self._classify_error(error),
severity=self._determine_severity(error),
context=context,
original_error=error
)
# Log the error
error_data = error.to_dict()
error_data.update(context)
log_message = f"Scheduler Error: {error.message}"
if log_level == "critical" or error.severity == SchedulerErrorSeverity.CRITICAL:
self.logger.critical(log_message, extra={"error_data": error_data})
elif log_level == "error" or error.severity == SchedulerErrorSeverity.HIGH:
self.logger.error(log_message, extra={"error_data": error_data})
elif log_level == "warning" or error.severity == SchedulerErrorSeverity.MEDIUM:
self.logger.warning(log_message, extra={"error_data": error_data})
else:
self.logger.info(log_message, extra={"error_data": error_data})
# Store critical errors in database for alerting
if error.severity in [SchedulerErrorSeverity.HIGH, SchedulerErrorSeverity.CRITICAL]:
self._store_error_alert(error)
# Return formatted error response
return self._format_error_response(error)
def _classify_error(self, error: Exception) -> SchedulerErrorType:
"""Classify an exception into a scheduler error type."""
error_str = str(error).lower()
error_type_name = type(error).__name__.lower()
# Database errors
if isinstance(error, (SQLAlchemyError, OperationalError, IntegrityError)):
return SchedulerErrorType.DATABASE_ERROR
if "database" in error_str or "sql" in error_type_name or "connection" in error_str:
return SchedulerErrorType.DATABASE_ERROR
# Timeout errors
if "timeout" in error_str or "timed out" in error_str:
return SchedulerErrorType.TIMEOUT_ERROR
# Concurrency errors
if "concurrent" in error_str or "race" in error_str or "lock" in error_str:
return SchedulerErrorType.CONCURRENCY_ERROR
# Task execution errors
if "task" in error_str and "execut" in error_str:
return SchedulerErrorType.TASK_EXECUTION_ERROR
# Task loader errors
if "load" in error_str and "task" in error_str:
return SchedulerErrorType.TASK_LOADER_ERROR
# Retry errors
if "retry" in error_str:
return SchedulerErrorType.RETRY_ERROR
# Config errors
if "config" in error_str or "scheduler" in error_str and "init" in error_str:
return SchedulerErrorType.SCHEDULER_CONFIG_ERROR
# Default to task execution error for unknown errors
return SchedulerErrorType.TASK_EXECUTION_ERROR
def _determine_severity(self, error: Exception) -> SchedulerErrorSeverity:
"""Determine the severity of an error."""
error_str = str(error).lower()
error_type = type(error)
# Critical errors
if isinstance(error, (SQLAlchemyError, OperationalError, ConnectionError)):
return SchedulerErrorSeverity.CRITICAL
if "database" in error_str or "connection" in error_str:
return SchedulerErrorSeverity.CRITICAL
# High severity errors
if "timeout" in error_str or "concurrent" in error_str:
return SchedulerErrorSeverity.HIGH
if isinstance(error, (KeyError, AttributeError)) and "config" in error_str:
return SchedulerErrorSeverity.HIGH
# Medium severity errors
if "task" in error_str or "execution" in error_str:
return SchedulerErrorSeverity.MEDIUM
# Default to low
return SchedulerErrorSeverity.LOW
def _store_error_alert(self, error: SchedulerException):
"""Store critical errors in database for alerting."""
if not self.db:
return
try:
# Import here to avoid circular dependencies
from models.monitoring_models import TaskExecutionLog
# Store as failed execution log if we have task_id (even without user_id for system errors)
if error.task_id:
try:
execution_log = TaskExecutionLog(
task_id=error.task_id,
user_id=error.user_id, # Can be None for system-level errors
execution_date=error.timestamp,
status='failed',
error_message=error.message,
result_data={
"error_type": error.error_type.value,
"severity": error.severity.value,
"context": error.context,
"stack_trace": error.stack_trace,
"task_type": error.task_type
}
)
self.db.add(execution_log)
self.db.commit()
self.logger.info(f"Stored error alert in execution log for task {error.task_id}")
except Exception as e:
self.logger.error(f"Failed to store error in execution log: {e}")
self.db.rollback()
# Note: For errors without task_id, we rely on structured logging only
# Future: Could create a separate scheduler_error_logs table for system-level errors
except Exception as e:
self.logger.error(f"Failed to store error alert: {e}")
def _format_error_response(self, error: SchedulerException) -> Dict[str, Any]:
"""Format error for API response or logging."""
response = {
"success": False,
"error": {
"type": error.error_type.value,
"message": error.message,
"severity": error.severity.value,
"timestamp": error.timestamp.isoformat() if isinstance(error.timestamp, datetime) else str(error.timestamp),
"user_id": error.user_id,
"task_id": error.task_id,
"task_type": error.task_type
}
}
# Add context for debugging (non-sensitive info only)
if error.context:
safe_context = {
k: v for k, v in error.context.items()
if k not in ["password", "token", "key", "secret", "credential"]
}
response["error"]["context"] = safe_context
# Add user-friendly message based on error type
user_messages = {
SchedulerErrorType.DATABASE_ERROR:
"A database error occurred while processing the task. Please try again later.",
SchedulerErrorType.TASK_EXECUTION_ERROR:
"The task failed to execute. Please check the task configuration and try again.",
SchedulerErrorType.TASK_LOADER_ERROR:
"Failed to load tasks. The scheduler may be experiencing issues.",
SchedulerErrorType.SCHEDULER_CONFIG_ERROR:
"The scheduler configuration is invalid. Contact support.",
SchedulerErrorType.RETRY_ERROR:
"Task retry failed. The task will be rescheduled.",
SchedulerErrorType.CONCURRENCY_ERROR:
"A concurrency issue occurred. The task will be retried.",
SchedulerErrorType.TIMEOUT_ERROR:
"The task execution timed out. The task will be retried."
}
response["error"]["user_message"] = user_messages.get(
error.error_type,
"An error occurred while processing the task."
)
return response

View File

@@ -0,0 +1,75 @@
"""
Task Executor Interface
Abstract base class for all task executors.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy.orm import Session
@dataclass
class TaskExecutionResult:
"""Result of task execution."""
success: bool
error_message: Optional[str] = None
result_data: Optional[Dict[str, Any]] = None
execution_time_ms: Optional[int] = None
retryable: bool = True
retry_delay: int = 300 # seconds
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
'success': self.success,
'error_message': self.error_message,
'result_data': self.result_data,
'execution_time_ms': self.execution_time_ms,
'retryable': self.retryable,
'retry_delay': self.retry_delay
}
class TaskExecutor(ABC):
"""
Abstract base class for task executors.
Each task type must implement this interface to be schedulable.
"""
@abstractmethod
async def execute_task(self, task: Any, db: Session) -> TaskExecutionResult:
"""
Execute a task.
Args:
task: Task instance from database
db: Database session
Returns:
TaskExecutionResult with execution details
"""
pass
@abstractmethod
def calculate_next_execution(
self,
task: Any,
frequency: str,
last_execution: Optional[datetime] = None
) -> datetime:
"""
Calculate next execution time based on frequency.
Args:
task: Task instance
frequency: Task frequency (e.g., 'Daily', 'Weekly')
last_execution: Last execution datetime
Returns:
Next execution datetime
"""
pass

View File

@@ -0,0 +1,628 @@
"""
Core Task Scheduler Service
Pluggable task scheduler that can work with any task model.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List, Callable
from datetime import datetime
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy.orm import Session
from .executor_interface import TaskExecutor, TaskExecutionResult
from .task_registry import TaskRegistry
from .exception_handler import (
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
TaskLoaderError, SchedulerConfigError
)
from services.database import get_db_session
from utils.logger_utils import get_service_logger
logger = get_service_logger("task_scheduler")
class TaskScheduler:
"""
Pluggable task scheduler that can work with any task model.
Features:
- Async task execution
- Plugin-based executor system
- Database-backed task persistence
- Configurable check intervals
- Automatic retry logic
"""
def __init__(
self,
check_interval_minutes: int = 15,
max_concurrent_executions: int = 10,
enable_retries: bool = True,
max_retries: int = 3
):
"""
Initialize the task scheduler.
Args:
check_interval_minutes: How often to check for due tasks
max_concurrent_executions: Maximum concurrent task executions
enable_retries: Whether to retry failed tasks
max_retries: Maximum retry attempts
"""
self.check_interval_minutes = check_interval_minutes
self.max_concurrent_executions = max_concurrent_executions
self.enable_retries = enable_retries
self.max_retries = max_retries
# Initialize APScheduler
self.scheduler = AsyncIOScheduler(
timezone='UTC',
job_defaults={
'coalesce': True,
'max_instances': 1,
'misfire_grace_time': 300 # 5 minutes grace period
}
)
# Task executor registry
self.registry = TaskRegistry()
# Track running executions
self.active_executions: Dict[str, asyncio.Task] = {}
# Exception handler for robust error handling
self.exception_handler = SchedulerExceptionHandler()
# Intelligent scheduling configuration
self.min_check_interval_minutes = 15 # Check every 15min when active strategies exist
self.max_check_interval_minutes = 60 # Check every 60min when no active strategies
self.current_check_interval_minutes = check_interval_minutes # Current interval
# Statistics
self.stats = {
'total_checks': 0,
'tasks_found': 0,
'tasks_executed': 0,
'tasks_failed': 0,
'tasks_skipped': 0,
'last_check': None,
'per_user_stats': {}, # Track metrics per user for user isolation
'active_strategies_count': 0, # Track active strategies with tasks
'last_interval_adjustment': None # Track when interval was last adjusted
}
self._running = False
def _get_trigger_for_interval(self, interval_minutes: int):
"""
Get the appropriate trigger for the given interval.
For intervals >= 60 minutes, use IntervalTrigger.
For intervals < 60 minutes, use CronTrigger.
Args:
interval_minutes: Interval in minutes
Returns:
Appropriate APScheduler trigger
"""
if interval_minutes >= 60:
# Use IntervalTrigger for intervals >= 60 minutes
return IntervalTrigger(minutes=interval_minutes)
else:
# Use CronTrigger for intervals < 60 minutes (valid range: 0-59)
return CronTrigger(minute=f'*/{interval_minutes}')
def register_executor(
self,
task_type: str,
executor: TaskExecutor,
task_loader: Callable[[Session], List[Any]]
):
"""
Register a task executor for a specific task type.
Args:
task_type: Unique identifier for task type (e.g., 'monitoring_task')
executor: TaskExecutor instance that handles execution
task_loader: Function that loads due tasks from database
"""
self.registry.register(task_type, executor, task_loader)
logger.info(f"Registered executor for task type: {task_type}")
async def start(self):
"""Start the scheduler with intelligent interval adjustment."""
if self._running:
logger.warning("Scheduler is already running")
return
try:
# Determine initial check interval based on active strategies
initial_interval = await self._determine_optimal_interval()
self.current_check_interval_minutes = initial_interval
# Add periodic job to check for due tasks
self.scheduler.add_job(
self._check_and_execute_due_tasks,
trigger=self._get_trigger_for_interval(initial_interval),
id='check_due_tasks',
replace_existing=True
)
self.scheduler.start()
self._running = True
logger.info(
f"Task scheduler started | "
f"check_interval={initial_interval}min | "
f"registered_types={self.registry.get_registered_types()}"
)
except Exception as e:
logger.error(f"Failed to start scheduler: {e}")
raise
async def stop(self):
"""Stop the scheduler gracefully."""
if not self._running:
return
try:
# Cancel all active executions
for task_id, execution_task in self.active_executions.items():
execution_task.cancel()
# Wait for active executions to complete (with timeout)
if self.active_executions:
await asyncio.wait(
self.active_executions.values(),
timeout=30
)
# Shutdown scheduler
self.scheduler.shutdown(wait=True)
self._running = False
logger.info("Task scheduler stopped gracefully")
except Exception as e:
logger.error(f"Error stopping scheduler: {e}")
raise
async def _check_and_execute_due_tasks(self):
"""
Main scheduler loop: check for due tasks and execute them.
This runs periodically with intelligent interval adjustment based on active strategies.
"""
self.stats['total_checks'] += 1
self.stats['last_check'] = datetime.utcnow().isoformat()
logger.debug("Checking for due tasks...")
db = None
try:
db = get_db_session()
if db is None:
logger.error("Failed to get database session")
return
# Check for active strategies and adjust interval intelligently
await self._adjust_check_interval_if_needed(db)
# Check each registered task type
for task_type in self.registry.get_registered_types():
await self._process_task_type(task_type, db)
except Exception as e:
error = DatabaseError(
message=f"Error checking for due tasks: {str(e)}",
original_error=e
)
self.exception_handler.handle_exception(error)
finally:
if db:
db.close()
async def _determine_optimal_interval(self) -> int:
"""
Determine optimal check interval based on active strategies.
Returns:
Optimal check interval in minutes
"""
db = None
try:
db = get_db_session()
if db:
from services.active_strategy_service import ActiveStrategyService
active_strategy_service = ActiveStrategyService(db_session=db)
active_count = active_strategy_service.count_active_strategies_with_tasks()
self.stats['active_strategies_count'] = active_count
if active_count > 0:
logger.info(f"Found {active_count} active strategies with tasks - using {self.min_check_interval_minutes}min interval")
return self.min_check_interval_minutes
else:
logger.info(f"No active strategies with tasks - using {self.max_check_interval_minutes}min interval")
return self.max_check_interval_minutes
except Exception as e:
logger.warning(f"Error determining optimal interval: {e}, using default {self.min_check_interval_minutes}min")
finally:
if db:
db.close()
# Default to shorter interval on error (safer)
return self.min_check_interval_minutes
async def _adjust_check_interval_if_needed(self, db: Session):
"""
Intelligently adjust check interval based on active strategies.
If there are active strategies with tasks, check more frequently.
If there are no active strategies, check less frequently.
Args:
db: Database session
"""
try:
from services.active_strategy_service import ActiveStrategyService
active_strategy_service = ActiveStrategyService(db_session=db)
active_count = active_strategy_service.count_active_strategies_with_tasks()
self.stats['active_strategies_count'] = active_count
# Determine optimal interval
if active_count > 0:
optimal_interval = self.min_check_interval_minutes
else:
optimal_interval = self.max_check_interval_minutes
# Only reschedule if interval needs to change
if optimal_interval != self.current_check_interval_minutes:
logger.info(
f"Adjusting scheduler interval: {self.current_check_interval_minutes}min → {optimal_interval}min | "
f"active_strategies={active_count}"
)
# Reschedule the job with new interval
self.scheduler.modify_job(
'check_due_tasks',
trigger=self._get_trigger_for_interval(optimal_interval)
)
self.current_check_interval_minutes = optimal_interval
self.stats['last_interval_adjustment'] = datetime.utcnow().isoformat()
logger.info(f"Scheduler interval adjusted to {optimal_interval}min")
except Exception as e:
logger.warning(f"Error adjusting check interval: {e}")
async def trigger_interval_adjustment(self):
"""
Trigger immediate interval adjustment check.
This should be called when a strategy is activated or deactivated
to immediately adjust the scheduler interval based on current active strategies.
"""
if not self._running:
logger.debug("Scheduler not running, skipping interval adjustment")
return
try:
db = get_db_session()
if db:
await self._adjust_check_interval_if_needed(db)
else:
logger.warning("Could not get database session for interval adjustment")
except Exception as e:
logger.warning(f"Error triggering interval adjustment: {e}")
async def _process_task_type(self, task_type: str, db: Session):
"""Process due tasks for a specific task type."""
try:
# Get task loader for this type
try:
task_loader = self.registry.get_task_loader(task_type)
except Exception as e:
error = TaskLoaderError(
message=f"Failed to get task loader for type {task_type}: {str(e)}",
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
return
# Load due tasks (with error handling)
try:
due_tasks = task_loader(db)
except Exception as e:
error = TaskLoaderError(
message=f"Failed to load due tasks for type {task_type}: {str(e)}",
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
return
if not due_tasks:
return
self.stats['tasks_found'] += len(due_tasks)
logger.info(f"Found {len(due_tasks)} due tasks for type: {task_type}")
# Execute tasks (with concurrency limit)
execution_tasks = []
for task in due_tasks:
if len(self.active_executions) >= self.max_concurrent_executions:
logger.warning(
f"Max concurrent executions reached ({self.max_concurrent_executions}), "
f"skipping {len(due_tasks) - len(execution_tasks)} tasks"
)
break
# Execute task asynchronously
# Note: Each task gets its own database session to prevent concurrent access issues
execution_task = asyncio.create_task(
self._execute_task_async(task_type, task)
)
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
self.active_executions[task_id] = execution_task
execution_tasks.append(execution_task)
# Wait for executions to complete (with timeout per task)
if execution_tasks:
await asyncio.wait(execution_tasks, timeout=300)
except Exception as e:
error = TaskLoaderError(
message=f"Error processing task type {task_type}: {str(e)}",
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
async def _execute_task_async(self, task_type: str, task: Any):
"""
Execute a single task asynchronously with user isolation.
Each task gets its own database session to prevent concurrent access issues,
as SQLAlchemy sessions are not async-safe or concurrent-safe.
User context is extracted and tracked for user isolation.
Args:
task_type: Type of task
task: Task instance from database (detached from original session)
"""
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
db = None
user_id = None
try:
# Extract user context if available (for user isolation tracking)
try:
if hasattr(task, 'strategy') and task.strategy:
user_id = getattr(task.strategy, 'user_id', None)
elif hasattr(task, 'strategy_id') and task.strategy_id:
# Will query user_id after we have db session
pass
except Exception as e:
logger.debug(f"Could not extract user_id before execution for task {task_id}: {e}")
logger.info(f"Executing task: {task_id} | user_id: {user_id}")
# Create a new database session for this async task
# SQLAlchemy sessions are not async-safe and cannot be shared across concurrent tasks
db = get_db_session()
if db is None:
error = DatabaseError(
message=f"Failed to get database session for task {task_id}",
user_id=user_id,
task_id=getattr(task, 'id', None),
task_type=task_type
)
self.exception_handler.handle_exception(error, log_level="error")
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
return
# Set database session for exception handler
self.exception_handler.db = db
# Merge the detached task object into this session
# The task object was loaded in a different session and is now detached
from sqlalchemy.orm import object_session
if object_session(task) is None:
# Task is detached, need to merge it into this session
task = db.merge(task)
# Extract user_id after merge if not already available
if user_id is None and hasattr(task, 'strategy'):
try:
if task.strategy:
user_id = getattr(task.strategy, 'user_id', None)
elif hasattr(task, 'strategy_id'):
# Query strategy if relationship not loaded
from models.enhanced_strategy_models import EnhancedContentStrategy
strategy = db.query(EnhancedContentStrategy).filter(
EnhancedContentStrategy.id == task.strategy_id
).first()
if strategy:
user_id = strategy.user_id
except Exception as e:
logger.debug(f"Could not extract user_id after merge for task {task_id}: {e}")
# Get executor for this task type
try:
executor = self.registry.get_executor(task_type)
except Exception as e:
from .exception_handler import SchedulerConfigError
error = SchedulerConfigError(
message=f"Failed to get executor for task type {task_type}: {str(e)}",
user_id=user_id,
context={
"task_id": getattr(task, 'id', None),
"task_type": task_type
},
original_error=e
)
self.exception_handler.handle_exception(error)
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
return
# Execute task with its own session (with error handling)
try:
result = await executor.execute_task(task, db)
# Handle result and update statistics
if result.success:
self.stats['tasks_executed'] += 1
self._update_user_stats(user_id, success=True)
logger.info(f"Task executed successfully: {task_id} | user_id: {user_id}")
else:
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
# Create structured error for failed execution
error = TaskExecutionError(
message=result.error_message or "Task execution failed",
user_id=user_id,
task_id=getattr(task, 'id', None),
task_type=task_type,
execution_time_ms=result.execution_time_ms,
context={"result_data": result.result_data}
)
self.exception_handler.handle_exception(error, log_level="warning")
# Retry logic if enabled
if self.enable_retries and result.retryable:
await self._schedule_retry(task, result.retry_delay)
except SchedulerException as e:
# Re-raise scheduler exceptions (they're already handled)
raise
except Exception as e:
# Wrap unexpected exceptions
error = TaskExecutionError(
message=f"Unexpected error during task execution: {str(e)}",
user_id=user_id,
task_id=getattr(task, 'id', None),
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
except SchedulerException as e:
# Handle scheduler exceptions
self.exception_handler.handle_exception(e)
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
except Exception as e:
# Handle any other unexpected errors
error = TaskExecutionError(
message=f"Unexpected error in task execution wrapper: {str(e)}",
user_id=user_id,
task_id=getattr(task, 'id', None),
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
self.stats['tasks_failed'] += 1
self._update_user_stats(user_id, success=False)
finally:
# Clean up database session
if db:
try:
db.close()
except Exception as e:
logger.error(f"Error closing database session for task {task_id}: {e}")
# Remove from active executions
if task_id in self.active_executions:
del self.active_executions[task_id]
def _update_user_stats(self, user_id: Optional[int], success: bool):
"""
Update per-user statistics for user isolation tracking.
Args:
user_id: User ID (None if user context not available)
success: Whether task execution was successful
"""
if user_id is None:
return
if user_id not in self.stats['per_user_stats']:
self.stats['per_user_stats'][user_id] = {
'executed': 0,
'failed': 0,
'success_rate': 0.0
}
user_stats = self.stats['per_user_stats'][user_id]
if success:
user_stats['executed'] += 1
else:
user_stats['failed'] += 1
# Calculate success rate
total = user_stats['executed'] + user_stats['failed']
if total > 0:
user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0
async def _schedule_retry(self, task: Any, delay_seconds: int):
"""Schedule a retry for a failed task."""
# This would update the task's next_execution time
# For now, just log - could be enhanced to update next_execution
logger.debug(f"Scheduling retry for task in {delay_seconds}s")
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
"""
Get scheduler statistics with optional user filtering.
Args:
user_id: Optional user ID to filter statistics for specific user
Returns:
Dictionary with scheduler statistics
"""
base_stats = {
**{k: v for k, v in self.stats.items() if k not in ['per_user_stats']},
'active_executions': len(self.active_executions),
'registered_types': self.registry.get_registered_types(),
'running': self._running,
'check_interval_minutes': self.current_check_interval_minutes,
'min_check_interval_minutes': self.min_check_interval_minutes,
'max_check_interval_minutes': self.max_check_interval_minutes,
'intelligent_scheduling': True
}
# Include per-user stats (all users or filtered)
if user_id is not None:
if user_id in self.stats['per_user_stats']:
base_stats['user_stats'] = self.stats['per_user_stats'][user_id]
else:
base_stats['user_stats'] = {
'executed': 0,
'failed': 0,
'success_rate': 0.0
}
else:
# Include all per-user stats (for admin/debugging)
base_stats['per_user_stats'] = self.stats['per_user_stats']
return base_stats
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running

View File

@@ -0,0 +1,59 @@
"""
Task Registry
Manages registration of task executors and loaders.
"""
import logging
from typing import Dict, Callable, List, Any
from sqlalchemy.orm import Session
from .executor_interface import TaskExecutor
logger = logging.getLogger(__name__)
class TaskRegistry:
"""Registry for task executors and loaders."""
def __init__(self):
self.executors: Dict[str, TaskExecutor] = {}
self.task_loaders: Dict[str, Callable[[Session], List[Any]]] = {}
def register(
self,
task_type: str,
executor: TaskExecutor,
task_loader: Callable[[Session], List[Any]]
):
"""
Register a task executor and loader.
Args:
task_type: Unique identifier for task type
executor: TaskExecutor instance
task_loader: Function that loads due tasks from database
"""
if task_type in self.executors:
logger.warning(f"Overwriting existing executor for task type: {task_type}")
self.executors[task_type] = executor
self.task_loaders[task_type] = task_loader
logger.info(f"Registered task type: {task_type}")
def get_executor(self, task_type: str) -> TaskExecutor:
"""Get executor for task type."""
if task_type not in self.executors:
raise ValueError(f"No executor registered for task type: {task_type}")
return self.executors[task_type]
def get_task_loader(self, task_type: str) -> Callable[[Session], List[Any]]:
"""Get task loader for task type."""
if task_type not in self.task_loaders:
raise ValueError(f"No task loader registered for task type: {task_type}")
return self.task_loaders[task_type]
def get_registered_types(self) -> List[str]:
"""Get list of registered task types."""
return list(self.executors.keys())

View File

@@ -0,0 +1,4 @@
"""
Task executor implementations.
"""

View File

@@ -0,0 +1,266 @@
"""
Monitoring Task Executor
Handles execution of content strategy monitoring tasks.
"""
import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
from ..core.executor_interface import TaskExecutor, TaskExecutionResult
from ..core.exception_handler import TaskExecutionError, DatabaseError, SchedulerExceptionHandler
from ..utils.frequency_calculator import calculate_next_execution
from models.monitoring_models import MonitoringTask, TaskExecutionLog
from models.enhanced_strategy_models import EnhancedContentStrategy
from utils.logger_utils import get_service_logger
logger = get_service_logger("monitoring_task_executor")
class MonitoringTaskExecutor(TaskExecutor):
"""
Executor for content strategy monitoring tasks.
Handles:
- ALwrity tasks (automated execution)
- Human tasks (notifications/queuing)
"""
def __init__(self):
self.logger = logger
self.exception_handler = SchedulerExceptionHandler()
async def execute_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
"""
Execute a monitoring task with user isolation.
Args:
task: MonitoringTask instance (with strategy relationship loaded)
db: Database session
Returns:
TaskExecutionResult
"""
start_time = time.time()
# Extract user_id from strategy relationship for user isolation
user_id = None
try:
if task.strategy and hasattr(task.strategy, 'user_id'):
user_id = task.strategy.user_id
elif task.strategy_id:
# Fallback: query strategy if relationship not loaded
strategy = db.query(EnhancedContentStrategy).filter(
EnhancedContentStrategy.id == task.strategy_id
).first()
if strategy:
user_id = strategy.user_id
except Exception as e:
self.logger.warning(f"Could not extract user_id for task {task.id}: {e}")
try:
self.logger.info(
f"Executing monitoring task: {task.id} | "
f"user_id: {user_id} | "
f"assignee: {task.assignee} | "
f"frequency: {task.frequency}"
)
# Create execution log with user_id for user isolation tracking
execution_log = TaskExecutionLog(
task_id=task.id,
user_id=user_id,
execution_date=datetime.utcnow(),
status='running'
)
db.add(execution_log)
db.flush()
# Execute based on assignee
if task.assignee == 'ALwrity':
result = await self._execute_alwrity_task(task, db)
else:
result = await self._execute_human_task(task, db)
# Update execution log
execution_time_ms = int((time.time() - start_time) * 1000)
execution_log.status = 'success' if result.success else 'failed'
execution_log.result_data = result.result_data
execution_log.error_message = result.error_message
execution_log.execution_time_ms = execution_time_ms
# Update task
task.last_executed = datetime.utcnow()
task.next_execution = self.calculate_next_execution(
task,
task.frequency,
task.last_executed
)
if result.success:
task.status = 'completed'
else:
task.status = 'failed'
db.commit()
return result
except Exception as e:
execution_time_ms = int((time.time() - start_time) * 1000)
# Set database session for exception handler
self.exception_handler.db = db
# Create structured error
error = TaskExecutionError(
message=f"Error executing monitoring task {task.id}: {str(e)}",
user_id=user_id,
task_id=task.id,
task_type="monitoring_task",
execution_time_ms=execution_time_ms,
context={
"assignee": task.assignee,
"frequency": task.frequency,
"component": task.component_name
},
original_error=e
)
# Handle exception with structured logging
self.exception_handler.handle_exception(error)
# Update execution log with error (include user_id for isolation)
try:
execution_log = TaskExecutionLog(
task_id=task.id,
user_id=user_id,
execution_date=datetime.utcnow(),
status='failed',
error_message=str(e),
execution_time_ms=execution_time_ms,
result_data={
"error_type": error.error_type.value,
"severity": error.severity.value,
"context": error.context
}
)
db.add(execution_log)
task.status = 'failed'
task.last_executed = datetime.utcnow()
db.commit()
except Exception as commit_error:
db_error = DatabaseError(
message=f"Error saving execution log: {str(commit_error)}",
user_id=user_id,
task_id=task.id,
original_error=commit_error
)
self.exception_handler.handle_exception(db_error)
db.rollback()
return TaskExecutionResult(
success=False,
error_message=str(e),
execution_time_ms=execution_time_ms,
retryable=True,
retry_delay=300
)
async def _execute_alwrity_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
"""
Execute an ALwrity (automated) monitoring task.
This is where the actual monitoring logic would go.
For now, we'll implement a placeholder that can be extended.
"""
try:
self.logger.info(f"Executing ALwrity task: {task.task_title}")
# TODO: Implement actual monitoring logic based on:
# - task.metric
# - task.measurement_method
# - task.success_criteria
# - task.alert_threshold
# Placeholder: Simulate task execution
result_data = {
'metric_value': 0,
'status': 'measured',
'message': f"Task {task.task_title} executed successfully",
'timestamp': datetime.utcnow().isoformat()
}
return TaskExecutionResult(
success=True,
result_data=result_data
)
except Exception as e:
self.logger.error(f"Error in ALwrity task execution: {e}")
return TaskExecutionResult(
success=False,
error_message=str(e),
retryable=True
)
async def _execute_human_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
"""
Execute a Human monitoring task (notification/queuing).
For human tasks, we don't execute the task directly,
but rather queue it for human review or send notifications.
"""
try:
self.logger.info(f"Queuing human task: {task.task_title}")
# TODO: Implement notification/queuing system:
# - Send email notification
# - Add to user's task queue
# - Create in-app notification
result_data = {
'status': 'queued',
'message': f"Task {task.task_title} queued for human review",
'timestamp': datetime.utcnow().isoformat()
}
return TaskExecutionResult(
success=True,
result_data=result_data
)
except Exception as e:
self.logger.error(f"Error queuing human task: {e}")
return TaskExecutionResult(
success=False,
error_message=str(e),
retryable=True
)
def calculate_next_execution(
self,
task: MonitoringTask,
frequency: str,
last_execution: Optional[datetime] = None
) -> datetime:
"""
Calculate next execution time based on frequency.
Args:
task: MonitoringTask instance
frequency: Frequency string (Daily, Weekly, Monthly, Quarterly)
last_execution: Last execution datetime (defaults to now)
Returns:
Next execution datetime
"""
return calculate_next_execution(
frequency=frequency,
base_time=last_execution or datetime.utcnow()
)

View File

@@ -0,0 +1,4 @@
"""
Scheduler utilities.
"""

View File

@@ -0,0 +1,33 @@
"""
Frequency Calculator Utility
Calculates next execution time based on frequency string.
"""
from datetime import datetime, timedelta
from typing import Optional
def calculate_next_execution(frequency: str, base_time: Optional[datetime] = None) -> datetime:
"""
Calculate next execution time based on frequency.
Args:
frequency: Frequency string ('Daily', 'Weekly', 'Monthly', 'Quarterly')
base_time: Base time to calculate from (defaults to now if None)
Returns:
Next execution datetime
"""
if base_time is None:
base_time = datetime.utcnow()
frequency_map = {
'Daily': timedelta(days=1),
'Weekly': timedelta(weeks=1),
'Monthly': timedelta(days=30),
'Quarterly': timedelta(days=90)
}
delta = frequency_map.get(frequency, timedelta(days=1))
return base_time + delta

View File

@@ -0,0 +1,60 @@
"""
Task Loader Utilities
Functions to load due tasks from database.
"""
from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_
from models.monitoring_models import MonitoringTask
from models.enhanced_strategy_models import EnhancedContentStrategy
def load_due_monitoring_tasks(
db: Session,
user_id: Optional[int] = None
) -> List[MonitoringTask]:
"""
Load all monitoring tasks that are due for execution.
Criteria:
- status == 'active'
- next_execution <= now (or is None for first execution)
- Optional: user_id filter for specific user (for future admin features)
Note: Strategy relationship is eagerly loaded to ensure user_id is accessible
during task execution for user isolation.
Args:
db: Database session
user_id: Optional user ID to filter tasks (if None, loads all users' tasks)
Returns:
List of due MonitoringTask instances with strategy relationship loaded
"""
now = datetime.utcnow()
# Join with strategy to ensure relationship is loaded and support user filtering
query = db.query(MonitoringTask).join(
EnhancedContentStrategy,
MonitoringTask.strategy_id == EnhancedContentStrategy.id
).options(
joinedload(MonitoringTask.strategy) # Eagerly load strategy relationship
).filter(
and_(
MonitoringTask.status == 'active',
or_(
MonitoringTask.next_execution <= now,
MonitoringTask.next_execution.is_(None)
)
)
)
# Apply user filter if provided
if user_id is not None:
query = query.filter(EnhancedContentStrategy.user_id == user_id)
return query.all()

View File

@@ -0,0 +1,189 @@
"""
Pre-flight Validation Utility for Multi-Operation Workflows
Provides transparent validation for operations that involve multiple API calls.
Services can use this to validate entire workflows before making any external API calls.
"""
from typing import Dict, Any, List, Optional, Tuple
from fastapi import HTTPException
from loguru import logger
from services.subscription.pricing_service import PricingService
from models.subscription_models import APIProvider
def validate_research_operations(
pricing_service: PricingService,
user_id: str,
gpt_provider: str = "google"
) -> None:
"""
Validate all operations for a research workflow before making ANY API calls.
This prevents wasteful external API calls (like Google Grounding) if subsequent
LLM calls would fail due to token or call limits.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
gpt_provider: GPT provider from env var (defaults to "google")
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, raises HTTPException with 429 status
"""
try:
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
gpt_provider_lower = gpt_provider.lower()
if gpt_provider_lower == "huggingface":
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
llm_provider_name = "huggingface"
else:
llm_provider_enum = APIProvider.GEMINI
llm_provider_name = "gemini"
# Estimate tokens for each operation in research workflow
# Google Grounding call: ~2000 tokens (input + output)
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
operations_to_validate = [
{
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
'tokens_requested': 2000,
'actual_provider_name': 'gemini',
'operation_type': 'google_grounding'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'keyword_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'competitor_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'content_angle_generation'
}
]
logger.info(f"[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
logger.info(f" ├─ User: {user_id}")
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
operation_type = usage_info.get('operation_type', 'unknown')
logger.error(f"[Pre-flight Validator] ❌ RESEARCH WORKFLOW BLOCKED")
logger.error(f" ├─ User: {user_id}")
logger.error(f" ├─ Blocked at: {operation_type}")
logger.error(f" ├─ Provider: {provider}")
logger.error(f" └─ Reason: {message}")
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
logger.info(f" ├─ User: {user_id}")
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating research operations: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate operations: {str(e)}",
'message': f"Failed to validate operations: {str(e)}"
}
)
def validate_image_generation_operations(
pricing_service: PricingService,
user_id: str
) -> None:
"""
Validate image generation operation before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, raises HTTPException with 429 status
"""
try:
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_generation'
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image generation: {str(e)}",
'message': f"Failed to validate image generation: {str(e)}"
}
)

View File

@@ -3,10 +3,11 @@ Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits. Manages API pricing, cost calculation, and subscription limits.
""" """
from typing import Dict, Any, Optional, List, Tuple from typing import Dict, Any, Optional, List, Tuple, Union
from decimal import Decimal, ROUND_HALF_UP from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger from loguru import logger
from models.subscription_models import ( from models.subscription_models import (
@@ -17,13 +18,17 @@ from models.subscription_models import (
class PricingService: class PricingService:
"""Service for managing API pricing and cost calculations.""" """Service for managing API pricing and cost calculations."""
# Class-level cache shared across all instances (critical for cache invalidation on subscription renewal)
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
_limits_cache: Dict[str, Dict[str, Any]] = {}
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
self._pricing_cache = {} self._pricing_cache = {}
self._plans_cache = {} self._plans_cache = {}
# Lightweight in-process cache for limit checks # Cache for schema feature detection (ai_text_generation_calls_limit column)
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime } self._ai_text_gen_col_checked: bool = False
self._limits_cache: Dict[str, Dict[str, Any]] = {} self._ai_text_gen_col_available: bool = False
# ------------------- Billing period helpers ------------------- # ------------------- Billing period helpers -------------------
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime: def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
@@ -68,6 +73,15 @@ class PricingService:
self._ensure_subscription_current(subscription) self._ensure_subscription_current(subscription)
# Continue to use YYYY-MM for summaries # Continue to use YYYY-MM for summaries
return datetime.now().strftime("%Y-%m") return datetime.now().strftime("%Y-%m")
@classmethod
def clear_user_cache(cls, user_id: str) -> int:
"""Clear all cached limit checks for a specific user. Returns number of entries cleared."""
keys_to_remove = [key for key in cls._limits_cache.keys() if key.startswith(f"{user_id}:")]
for key in keys_to_remove:
del cls._limits_cache[key]
logger.info(f"Cleared {len(keys_to_remove)} cache entries for user {user_id}")
return len(keys_to_remove)
def initialize_default_pricing(self): def initialize_default_pricing(self):
"""Initialize default pricing for all API providers.""" """Initialize default pricing for all API providers."""
@@ -292,7 +306,8 @@ class PricingService:
"tier": SubscriptionTier.BASIC, "tier": SubscriptionTier.BASIC,
"price_monthly": 29.0, "price_monthly": 29.0,
"price_yearly": 290.0, "price_yearly": 290.0,
"gemini_calls_limit": 1000, "ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
"openai_calls_limit": 500, "openai_calls_limit": 500,
"anthropic_calls_limit": 200, "anthropic_calls_limit": 200,
"mistral_calls_limit": 500, "mistral_calls_limit": 500,
@@ -300,11 +315,11 @@ class PricingService:
"serper_calls_limit": 200, "serper_calls_limit": 200,
"metaphor_calls_limit": 100, "metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100, "firecrawl_calls_limit": 100,
"stability_calls_limit": 50, "stability_calls_limit": 5,
"gemini_tokens_limit": 1000000, "gemini_tokens_limit": 2000,
"openai_tokens_limit": 500000, "openai_tokens_limit": 2000,
"anthropic_tokens_limit": 200000, "anthropic_tokens_limit": 2000,
"mistral_tokens_limit": 500000, "mistral_tokens_limit": 2000,
"monthly_cost_limit": 50.0, "monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"], "features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams" "description": "Great for individuals and small teams"
@@ -426,21 +441,60 @@ class PricingService:
self._ensure_subscription_current(subscription) self._ensure_subscription_current(subscription)
return self._plan_to_limits_dict(subscription.plan) return self._plan_to_limits_dict(subscription.plan)
def _ensure_ai_text_gen_column_detection(self) -> None:
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
if self._ai_text_gen_col_checked:
return
try:
# Try to query the column - if it exists, this will work
self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
self._ai_text_gen_col_available = True
except Exception:
self._ai_text_gen_col_available = False
finally:
self._ai_text_gen_col_checked = True
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]: def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
"""Convert subscription plan to limits dictionary.""" """Convert subscription plan to limits dictionary."""
# Detect if unified AI text generation limit column exists
self._ensure_ai_text_gen_column_detection()
# Use unified AI text generation limit if column exists and is set
ai_text_gen_limit = None
if self._ai_text_gen_col_available:
try:
ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
# If 0, treat as not set (unlimited for Enterprise or use fallback)
if ai_text_gen_limit == 0:
ai_text_gen_limit = None
except (AttributeError, Exception):
# Column exists but access failed - use fallback
ai_text_gen_limit = None
return { return {
'plan_name': plan.name, 'plan_name': plan.name,
'tier': plan.tier.value, 'tier': plan.tier.value,
'limits': { 'limits': {
# Unified AI text generation limit (applies to all LLM providers)
# If not set, fall back to first non-zero legacy limit for backwards compatibility
'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
plan.openai_calls_limit if plan.openai_calls_limit > 0 else
plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
),
# Legacy per-provider limits (for backwards compatibility and analytics)
'gemini_calls': plan.gemini_calls_limit, 'gemini_calls': plan.gemini_calls_limit,
'openai_calls': plan.openai_calls_limit, 'openai_calls': plan.openai_calls_limit,
'anthropic_calls': plan.anthropic_calls_limit, 'anthropic_calls': plan.anthropic_calls_limit,
'mistral_calls': plan.mistral_calls_limit, 'mistral_calls': plan.mistral_calls_limit,
# Other API limits
'tavily_calls': plan.tavily_calls_limit, 'tavily_calls': plan.tavily_calls_limit,
'serper_calls': plan.serper_calls_limit, 'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit, 'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit, 'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit, 'stability_calls': plan.stability_calls_limit,
# Token limits
'gemini_tokens': plan.gemini_tokens_limit, 'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit, 'openai_tokens': plan.openai_tokens_limit,
'anthropic_tokens': plan.anthropic_tokens_limit, 'anthropic_tokens': plan.anthropic_tokens_limit,
@@ -451,101 +505,293 @@ class PricingService:
} }
def check_usage_limits(self, user_id: str, provider: APIProvider, def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]: tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.""" """Check if user can make an API call within their limits.
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
# Get user limits
limits = self.get_user_limits(user_id)
if not limits:
return False, "No subscription plan found", {}
# Get current usage for this billing period Args:
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") user_id: User ID
usage = self.db.query(UsageSummary).filter( provider: APIProvider enum (may be MISTRAL for HuggingFace)
UsageSummary.user_id == user_id, tokens_requested: Estimated tokens for the request
UsageSummary.billing_period == current_period actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
).first() """
try:
if not usage: # Use actual_provider_name if provided, otherwise use enum value
# First usage this period, create summary # This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
usage = UsageSummary( display_provider_name = actual_provider_name or provider.value
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
# Check call limits
provider_name = provider.value
current_calls = getattr(usage, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0 and current_calls >= call_limit:
result = (False, f"API call limit reached for {provider_name}", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check token limits for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit: logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
result = (False, f"Token limit would be exceeded for {provider_name}", {
'current_tokens': current_tokens, # Short TTL cache to reduce DB reads under sustained traffic
'requested_tokens': tokens_requested, cache_key = f"{user_id}:{provider.value}"
'limit': token_limit, now = datetime.utcnow()
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100 cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
try:
limits = self.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
try:
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
}) })
self._limits_cache[cache_key] = { self._limits_cache[cache_key] = {
'result': result, 'result': result,
'expires_at': now + timedelta(seconds=30) 'expires_at': now + timedelta(seconds=30)
} }
return result return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
# Check cost limits except Exception as e:
cost_limit = limits['limits'].get('monthly_cost', 0) logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
if cost_limit > 0 and usage.total_cost >= cost_limit: # STRICT: Fail closed - deny requests if subscription system fails
result = (False, "Monthly cost limit reached", { return False, f"Subscription check error: {str(e)}", {}
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Calculate usage percentages for warnings
call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_calls,
'call_limit': call_limit,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
def estimate_tokens(self, text: str, provider: APIProvider) -> int: def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider.""" """Estimate token count for text based on provider."""
@@ -581,6 +827,236 @@ class PricingService:
if not pricing: if not pricing:
return None return None
def check_comprehensive_limits(
self,
user_id: str,
operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
"""
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call.
Args:
user_id: User ID
operations: List of operations to validate, each with:
- 'provider': APIProvider enum
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded
"""
try:
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
return False, f"Failed to create usage summary: {str(create_error)}", {}
# Get user limits
limits_dict = self.get_user_limits(user_id)
if not limits_dict:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
limits_dict = self._plan_to_limits_dict(free_plan)
else:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
# Track cumulative usage across all operations
total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
total_llm_tokens = {}
total_images = usage.stability_calls or 0
# Log current usage summary
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
logger.info(f" └─ Image Calls: {total_images}")
# Validate each operation
for op_idx, operation in enumerate(operations):
provider = operation.get('provider')
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
tokens_requested = operation.get('tokens_requested', 0)
actual_provider_name = operation.get('actual_provider_name')
operation_type = operation.get('operation_type', 'unknown')
display_provider_name = actual_provider_name or provider_name
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
# Check if this is an LLM provider
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
# Check unified AI text generation limit for LLM providers
if is_llm_provider:
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
if ai_text_gen_limit == 0:
# Fallback to provider-specific limit
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# Check token limits for this provider
# Use cumulative projected tokens from previous operations, or current from DB if first operation
provider_tokens_key = f"{provider_name}_tokens"
if provider_tokens_key in total_llm_tokens:
# Use cumulative projected tokens from previous operations
current_provider_tokens = total_llm_tokens[provider_tokens_key]
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
else:
# First operation for this provider - get current from database
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
total_llm_tokens[provider_tokens_key] = current_provider_tokens
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
if projected_tokens > token_limit:
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
error_info = {
'current_tokens': current_provider_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Current: {current_provider_tokens}/{token_limit}, "
f"Requested: {tokens_requested}, "
f"Would exceed by: {projected_tokens - token_limit} tokens "
f"({usage_percentage:.1f}% of limit)"
)
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
return False, error_msg, {
'error_type': 'token_limit',
'usage_info': error_info
}
else:
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
# Update cumulative counts for next operation
total_llm_calls = projected_total_llm_calls
total_llm_tokens[provider_tokens_key] += tokens_requested
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
# Check image generation limits
elif provider == APIProvider.STABILITY:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
'provider': 'stability',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
'error_type': 'image_limit',
'usage_info': error_info
}
total_images = projected_images
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
call_limit = limits.get(provider_calls_key, 0) or 0
if call_limit > 0:
projected_calls = current_provider_calls + 1
if projected_calls > call_limit:
error_info = {
'current_calls': current_provider_calls,
'limit': call_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
return True, None, None
except Exception as e:
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
return False, f"Failed to validate limits: {str(e)}", {}
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
"""Get pricing configuration for a specific provider and model."""
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name
).first()
if not pricing:
return None
return { return {
'provider': pricing.provider.value, 'provider': pricing.provider.value,
'model_name': pricing.model_name, 'model_name': pricing.model_name,

View File

@@ -502,7 +502,7 @@ class UsageTrackingService:
return result return result
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]: async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
"""Reset usage status for the current billing period (after plan change).""" """Reset usage status and counters for the current billing period (after plan renewal/change)."""
try: try:
billing_period = datetime.now().strftime("%Y-%m") billing_period = datetime.now().strftime("%Y-%m")
summary = self.db.query(UsageSummary).filter( summary = self.db.query(UsageSummary).filter(
@@ -514,11 +514,52 @@ class UsageTrackingService:
# Nothing to reset # Nothing to reset
return {"reset": False, "reason": "no_summary"} return {"reset": False, "reason": "no_summary"}
# Clear LIMIT_REACHED so the user can resume; keep counters intact # CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
# Clear LIMIT_REACHED status
summary.usage_status = UsageStatus.ACTIVE summary.usage_status = UsageStatus.ACTIVE
# Reset all LLM provider call counters
summary.gemini_calls = 0
summary.openai_calls = 0
summary.anthropic_calls = 0
summary.mistral_calls = 0
# Reset all LLM provider token counters
summary.gemini_tokens = 0
summary.openai_tokens = 0
summary.anthropic_tokens = 0
summary.mistral_tokens = 0
# Reset search/research provider counters
summary.tavily_calls = 0
summary.serper_calls = 0
summary.metaphor_calls = 0
summary.firecrawl_calls = 0
# Reset image generation counters
summary.stability_calls = 0
# Reset cost counters
summary.gemini_cost = 0.0
summary.openai_cost = 0.0
summary.anthropic_cost = 0.0
summary.mistral_cost = 0.0
summary.tavily_cost = 0.0
summary.serper_cost = 0.0
summary.metaphor_cost = 0.0
summary.firecrawl_cost = 0.0
summary.stability_cost = 0.0
# Reset totals
summary.total_calls = 0
summary.total_tokens = 0
summary.total_cost = 0.0
summary.updated_at = datetime.utcnow() summary.updated_at = datetime.utcnow()
self.db.commit() self.db.commit()
return {"reset": True}
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
return {"reset": True, "counters_reset": True}
except Exception as e: except Exception as e:
self.db.rollback() self.db.rollback()
logger.error(f"Error resetting usage status: {e}") logger.error(f"Error resetting usage status: {e}")

View File

@@ -58,19 +58,25 @@ const InitialRouteHandler: React.FC = () => {
error: null, error: null,
}); });
// Check subscription on mount // Check subscription on mount (non-blocking - don't wait for it to route)
useEffect(() => { useEffect(() => {
checkSubscription().catch((err) => { // Delay subscription check slightly to allow auth token getter to be installed first
console.error('Error checking subscription:', err); const timeoutId = setTimeout(() => {
checkSubscription().catch((err) => {
// Check if it's a connection error - handle it locally console.error('Error checking subscription (non-blocking):', err);
if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
setConnectionError({ // Check if it's a connection error - handle it locally
hasError: true, if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
error: err, setConnectionError({
}); hasError: true,
} error: err,
}); });
}
// Don't block routing on subscription check errors - allow graceful degradation
});
}, 100); // Small delay to ensure TokenInstaller has run
return () => clearTimeout(timeoutId);
}, []); // Remove checkSubscription dependency to prevent loop }, []); // Remove checkSubscription dependency to prevent loop
// Initialize onboarding only after subscription is confirmed // Initialize onboarding only after subscription is confirmed
@@ -125,9 +131,10 @@ const InitialRouteHandler: React.FC = () => {
); );
} }
// Loading state - ensure we wait for onboarding init after subscription is confirmed // Loading state - only wait for onboarding init, not subscription check
const waitingForOnboardingInit = !!subscription && subscription.active && !subscriptionLoading && (loading || !data); // Subscription check is non-blocking and happens in background
if (subscriptionLoading || loading || waitingForOnboardingInit) { const waitingForOnboardingInit = loading || !data;
if (loading || waitingForOnboardingInit) {
return ( return (
<Box <Box
display="flex" display="flex"
@@ -167,29 +174,79 @@ const InitialRouteHandler: React.FC = () => {
); );
} }
if (!subscription) { // Decision tree for SIGNED-IN users:
return null; // Should not happen, but just in case // Priority: Subscription → Onboarding → Dashboard (as per user flow: Landing → Subscription → Onboarding → Dashboard)
// 1. If subscription is still loading, show loading state
if (subscriptionLoading) {
return (
<Box
display="flex"
flexDirection="column"
alignItems="center"
justifyContent="center"
minHeight="100vh"
gap={2}
>
<CircularProgress size={60} />
<Typography variant="h6" color="textSecondary">
Checking subscription...
</Typography>
</Box>
);
} }
// Decision tree for SIGNED-IN users: // 2. No subscription data yet - handle gracefully
// Priority: Subscription → Onboarding → Dashboard // If onboarding is complete, allow access to dashboard (user already went through flow)
// If onboarding not complete, check if subscription check is still loading or failed
// Check if user is new (no subscription record at all) if (!subscription) {
if (isOnboardingComplete) {
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
return <Navigate to="/dashboard" replace />;
}
// Onboarding not complete and no subscription data
// If subscription check is still loading, show loading state
if (subscriptionLoading) {
return (
<Box
display="flex"
flexDirection="column"
alignItems="center"
justifyContent="center"
minHeight="100vh"
gap={2}
>
<CircularProgress size={60} />
<Typography variant="h6" color="textSecondary">
Checking subscription...
</Typography>
</Box>
);
}
// Subscription check completed but returned null/undefined
// This likely means no subscription - redirect to pricing
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
return <Navigate to="/pricing" replace />;
}
// 3. Check subscription status first
const isNewUser = !subscription || subscription.plan === 'none'; const isNewUser = !subscription || subscription.plan === 'none';
// 1. No active subscription? → Must subscribe first (even if onboarding is complete) // No active subscription → Must subscribe first
if (isNewUser || !subscription.active) { if (isNewUser || !subscription.active) {
console.log('InitialRouteHandler: No active subscription → Pricing page'); console.log('InitialRouteHandler: No active subscription → Pricing page');
return <Navigate to="/pricing" replace />; return <Navigate to="/pricing" replace />;
} }
// 2. Has active subscription, check onboarding status // 4. Has active subscription, check onboarding status
if (!isOnboardingComplete) { if (!isOnboardingComplete) {
console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding'); console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding');
return <Navigate to="/onboarding" replace />; return <Navigate to="/onboarding" replace />;
} }
// 3. Has subscription AND completed onboarding → Dashboard // 5. Has subscription AND completed onboarding → Dashboard
console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard'); console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard');
return <Navigate to="/dashboard" replace />; return <Navigate to="/dashboard" replace />;
}; };

View File

@@ -7,6 +7,24 @@ export const setGlobalSubscriptionErrorHandler = (handler: (error: any) => boole
globalSubscriptionErrorHandler = handler; globalSubscriptionErrorHandler = handler;
}; };
// Export a function to trigger subscription error handler from outside axios interceptors
export const triggerSubscriptionError = (error: any) => {
const status = error?.response?.status;
console.log('triggerSubscriptionError: Received error', {
hasHandler: !!globalSubscriptionErrorHandler,
status,
dataKeys: error?.response?.data ? Object.keys(error.response.data) : null
});
if (globalSubscriptionErrorHandler) {
console.log('triggerSubscriptionError: Calling global subscription error handler');
return globalSubscriptionErrorHandler(error);
}
console.warn('triggerSubscriptionError: No global subscription error handler registered');
return false;
};
// Optional token getter installed from within the app after Clerk is available // Optional token getter installed from within the app after Clerk is available
let authTokenGetter: (() => Promise<string | null>) | null = null; let authTokenGetter: (() => Promise<string | null>) | null = null;
@@ -64,13 +82,27 @@ apiClient.interceptors.request.use(
async (config) => { async (config) => {
console.log(`Making ${config.method?.toUpperCase()} request to ${config.url}`); console.log(`Making ${config.method?.toUpperCase()} request to ${config.url}`);
try { try {
const token = authTokenGetter ? await authTokenGetter() : null; if (!authTokenGetter) {
console.warn(`[apiClient] ⚠️ authTokenGetter not set for ${config.url} - request may fail authentication`);
console.warn(`[apiClient] This usually means TokenInstaller hasn't run yet. Request will likely fail with 401.`);
} else {
try {
const token = await authTokenGetter();
if (token) { if (token) {
config.headers = config.headers || {}; config.headers = config.headers || {};
(config.headers as any)['Authorization'] = `Bearer ${token}`; (config.headers as any)['Authorization'] = `Bearer ${token}`;
console.log(`[apiClient] ✅ Added auth token to request: ${config.url}`);
} else {
console.warn(`[apiClient] ⚠️ authTokenGetter returned null for ${config.url} - user may not be signed in`);
console.warn(`[apiClient] User ID from localStorage: ${localStorage.getItem('user_id') || 'none'}`);
}
} catch (tokenError) {
console.error(`[apiClient] ❌ Error getting auth token for ${config.url}:`, tokenError);
}
} }
} catch (e) { } catch (e) {
// non-fatal console.error(`[apiClient] ❌ Unexpected error in request interceptor for ${config.url}:`, e);
// non-fatal - let the request proceed, backend will return 401 if needed
} }
return config; return config;
}, },
@@ -138,13 +170,17 @@ apiClient.interceptors.response.use(
console.error('Token refresh failed:', retryError); console.error('Token refresh failed:', retryError);
} }
// If retry failed and not in onboarding, redirect // If retry failed, don't redirect during app initialization (root route)
const isOnboardingRoute = window.location.pathname.includes('/onboarding') || // Only redirect if we're on a protected route and definitely authenticated
window.location.pathname === '/'; const isOnboardingRoute = window.location.pathname.includes('/onboarding');
if (!isOnboardingRoute) { const isRootRoute = window.location.pathname === '/';
// Don't redirect from root route during app initialization - allow InitialRouteHandler to work
if (!isRootRoute && !isOnboardingRoute) {
// Only redirect if we're definitely not just initializing
try { window.location.assign('/'); } catch {} try { window.location.assign('/'); } catch {}
} else { } else {
console.warn('401 Unauthorized - token refresh failed'); console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)');
} }
} }
@@ -204,12 +240,14 @@ aiApiClient.interceptors.response.use(
console.error('Token refresh failed:', retryError); console.error('Token refresh failed:', retryError);
} }
const isOnboardingRoute = window.location.pathname.includes('/onboarding') || const isOnboardingRoute = window.location.pathname.includes('/onboarding');
window.location.pathname === '/'; const isRootRoute = window.location.pathname === '/';
if (!isOnboardingRoute) {
// Don't redirect from root route during app initialization
if (!isRootRoute && !isOnboardingRoute) {
try { window.location.assign('/'); } catch {} try { window.location.assign('/'); } catch {}
} else { } else {
console.warn('401 Unauthorized - token refresh failed'); console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)');
} }
} }
@@ -254,13 +292,15 @@ longRunningApiClient.interceptors.response.use(
}, },
(error) => { (error) => {
if (error?.response?.status === 401) { if (error?.response?.status === 401) {
// Only redirect on 401 if we're not in onboarding flow // Only redirect on 401 if we're not in onboarding flow or root route
const isOnboardingRoute = window.location.pathname.includes('/onboarding') || const isOnboardingRoute = window.location.pathname.includes('/onboarding');
window.location.pathname === '/'; const isRootRoute = window.location.pathname === '/';
if (!isOnboardingRoute) {
// Don't redirect from root route during app initialization
if (!isRootRoute && !isOnboardingRoute) {
try { window.location.assign('/'); } catch {} try { window.location.assign('/'); } catch {}
} else { } else {
console.warn('401 Unauthorized during onboarding - token may need refresh'); console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)');
} }
} }
// Check if it's a subscription-related error and handle it globally // Check if it's a subscription-related error and handle it globally
@@ -304,13 +344,15 @@ pollingApiClient.interceptors.response.use(
}, },
(error) => { (error) => {
if (error?.response?.status === 401) { if (error?.response?.status === 401) {
// Only redirect on 401 if we're not in onboarding flow // Only redirect on 401 if we're not in onboarding flow or root route
const isOnboardingRoute = window.location.pathname.includes('/onboarding') || const isOnboardingRoute = window.location.pathname.includes('/onboarding');
window.location.pathname === '/'; const isRootRoute = window.location.pathname === '/';
if (!isOnboardingRoute) {
// Don't redirect from root route during app initialization
if (!isRootRoute && !isOnboardingRoute) {
try { window.location.assign('/'); } catch {} try { window.location.assign('/'); } catch {}
} else { } else {
console.warn('401 Unauthorized during onboarding - token may need refresh'); console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)');
} }
} }
// Check if it's a subscription-related error and handle it globally // Check if it's a subscription-related error and handle it globally

View File

@@ -66,6 +66,7 @@ export const BlogWriter: React.FC = () => {
contentConfirmed, contentConfirmed,
flowAnalysisCompleted, flowAnalysisCompleted,
flowAnalysisResults, flowAnalysisResults,
sectionImages,
setOutline, setOutline,
setTitleOptions, setTitleOptions,
setSelectedTitle, setSelectedTitle,
@@ -78,6 +79,7 @@ export const BlogWriter: React.FC = () => {
setContentConfirmed, setContentConfirmed,
setFlowAnalysisCompleted, setFlowAnalysisCompleted,
setFlowAnalysisResults, setFlowAnalysisResults,
setSectionImages,
handleResearchComplete, handleResearchComplete,
handleOutlineComplete, handleOutlineComplete,
handleOutlineError, handleOutlineError,
@@ -670,6 +672,8 @@ export const BlogWriter: React.FC = () => {
flowAnalysisResults={flowAnalysisResults} flowAnalysisResults={flowAnalysisResults}
outlineGenRef={outlineGenRef} outlineGenRef={outlineGenRef}
blogWriterApi={blogWriterApi} blogWriterApi={blogWriterApi}
sectionImages={sectionImages}
setSectionImages={setSectionImages}
contentConfirmed={contentConfirmed} contentConfirmed={contentConfirmed}
seoAnalysis={seoAnalysis} seoAnalysis={seoAnalysis}
seoMetadata={seoMetadata} seoMetadata={seoMetadata}

View File

@@ -31,6 +31,8 @@ interface PhaseContentProps {
seoMetadata: any; seoMetadata: any;
onTitleSelect: any; onTitleSelect: any;
onCustomTitle: any; onCustomTitle: any;
sectionImages?: Record<string, string>;
setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void;
} }
export const PhaseContent: React.FC<PhaseContentProps> = ({ export const PhaseContent: React.FC<PhaseContentProps> = ({
@@ -58,7 +60,9 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
seoAnalysis, seoAnalysis,
seoMetadata, seoMetadata,
onTitleSelect, onTitleSelect,
onCustomTitle onCustomTitle,
sectionImages,
setSectionImages
}) => { }) => {
return ( return (
<div style={{ display: 'flex', flex: 1, overflow: 'hidden' }}> <div style={{ display: 'flex', flex: 1, overflow: 'hidden' }}>
@@ -100,6 +104,8 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
optimizationResults={optimizationResults} optimizationResults={optimizationResults}
researchCoverage={researchCoverage} researchCoverage={researchCoverage}
onRefine={(op: any, id: any, payload: any) => blogWriterApi.refineOutline({ outline, operation: op, section_id: id, payload }).then((res: any) => setOutline(res.outline))} onRefine={(op: any, id: any, payload: any) => blogWriterApi.refineOutline({ outline, operation: op, section_id: id, payload }).then((res: any) => setOutline(res.outline))}
sectionImages={sectionImages}
setSectionImages={setSectionImages}
/> />
</> </>
) : ( ) : (
@@ -126,6 +132,7 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
onSave={handleContentSave} onSave={handleContentSave}
continuityRefresh={continuityRefresh || undefined} continuityRefresh={continuityRefresh || undefined}
flowAnalysisResults={flowAnalysisResults} flowAnalysisResults={flowAnalysisResults}
sectionImages={sectionImages}
/> />
) : ( ) : (
<div style={{ padding: '20px', textAlign: 'center' }}> <div style={{ padding: '20px', textAlign: 'center' }}>
@@ -151,6 +158,7 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
onSave={handleContentSave} onSave={handleContentSave}
continuityRefresh={continuityRefresh || undefined} continuityRefresh={continuityRefresh || undefined}
flowAnalysisResults={flowAnalysisResults} flowAnalysisResults={flowAnalysisResults}
sectionImages={sectionImages}
/> />
) : ( ) : (
<div style={{ padding: '20px', textAlign: 'center' }}> <div style={{ padding: '20px', textAlign: 'center' }}>

View File

@@ -0,0 +1,168 @@
import React, { useState, useEffect } from 'react';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
Button,
Typography,
Box,
CircularProgress,
Alert
} from '@mui/material';
import { usePlatformConnections } from '../../../components/OnboardingWizard/common/usePlatformConnections';
interface WixConnectModalProps {
isOpen: boolean;
onClose: () => void;
onConnectionSuccess?: () => void;
}
export const WixConnectModal: React.FC<WixConnectModalProps> = ({
isOpen,
onClose,
onConnectionSuccess
}) => {
const { handleConnect, isLoading } = usePlatformConnections();
const [error, setError] = useState<string | null>(null);
const [isConnecting, setIsConnecting] = useState(false);
// Handle OAuth success via postMessage (same pattern as onboarding)
useEffect(() => {
if (!isOpen) return;
const handler = (event: MessageEvent) => {
const trusted = [window.location.origin, 'https://littery-sonny-unscrutinisingly.ngrok-free.dev'];
if (!trusted.includes(event.origin)) return;
if (!event.data || typeof event.data !== 'object') return;
if (event.data.type === 'WIX_OAUTH_SUCCESS') {
console.log('Wix OAuth success in modal');
setIsConnecting(false);
setError(null);
// Close modal and notify parent
if (onConnectionSuccess) {
onConnectionSuccess();
}
onClose();
}
if (event.data.type === 'WIX_OAUTH_ERROR') {
console.error('Wix OAuth error in modal:', event.data.error);
setIsConnecting(false);
setError(event.data.error || 'Wix connection failed. Please try again.');
}
};
window.addEventListener('message', handler);
return () => window.removeEventListener('message', handler);
}, [isOpen, onClose, onConnectionSuccess]);
// Also check for URL param (fallback for same-tab redirect)
useEffect(() => {
if (!isOpen) return;
const params = new URLSearchParams(window.location.search);
if (params.get('wix_connected') === 'true') {
console.log('Wix connected via URL param in modal');
setIsConnecting(false);
setError(null);
if (onConnectionSuccess) {
onConnectionSuccess();
}
onClose();
// Clean URL
const clean = window.location.pathname + window.location.hash;
window.history.replaceState({}, document.title, clean || '/');
}
}, [isOpen, onClose, onConnectionSuccess]);
const handleConnectClick = async () => {
try {
setIsConnecting(true);
setError(null);
await handleConnect('wix');
// OAuth will redirect, so we don't need to do anything else here
// The postMessage handler or URL param handler will close the modal
} catch (err: any) {
console.error('Error connecting to Wix:', err);
setIsConnecting(false);
setError(err?.message || 'Failed to start Wix connection. Please try again.');
}
};
return (
<Dialog
open={isOpen}
onClose={onClose}
maxWidth="sm"
fullWidth
PaperProps={{
sx: {
borderRadius: 2,
boxShadow: '0 4px 20px rgba(0,0,0,0.15)'
}
}}
>
<DialogTitle sx={{ pb: 1 }}>
<Typography variant="h6" sx={{ fontWeight: 600, color: '#1e293b' }}>
Connect Your Wix Account
</Typography>
</DialogTitle>
<DialogContent>
<Box sx={{ py: 1 }}>
<Typography variant="body2" color="text.secondary" paragraph>
Connect your Wix account to publish blog posts directly to your website.
</Typography>
{error && (
<Alert severity="error" sx={{ mb: 2 }}>
{error}
</Alert>
)}
{isConnecting && (
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2, py: 2 }}>
<CircularProgress size={20} />
<Typography variant="body2" color="text.secondary">
Opening Wix authorization page...
</Typography>
</Box>
)}
<Box sx={{ mt: 2, p: 2, bgcolor: '#f8fafc', borderRadius: 1 }}>
<Typography variant="caption" color="text.secondary">
<strong>What happens next:</strong>
</Typography>
<Typography variant="caption" component="div" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
<ol style={{ margin: '8px 0 0 20px', padding: 0 }}>
<li>You'll be redirected to Wix to authorize ALwrity</li>
<li>Grant permissions for blog creation and publishing</li>
<li>You'll be redirected back to ALwrity</li>
<li>Your blog post will be published automatically</li>
</ol>
</Typography>
</Box>
</Box>
</DialogContent>
<DialogActions sx={{ px: 3, pb: 2 }}>
<Button onClick={onClose} disabled={isConnecting}>
Cancel
</Button>
<Button
variant="contained"
onClick={handleConnectClick}
disabled={isConnecting || isLoading}
startIcon={isConnecting ? <CircularProgress size={16} /> : undefined}
>
{isConnecting ? 'Connecting...' : 'Connect to Wix'}
</Button>
</DialogActions>
</Dialog>
);
};
export default WixConnectModal;

View File

@@ -12,6 +12,8 @@ interface Props {
groundingInsights?: GroundingInsights | null; groundingInsights?: GroundingInsights | null;
optimizationResults?: OptimizationResults | null; optimizationResults?: OptimizationResults | null;
researchCoverage?: ResearchCoverage | null; researchCoverage?: ResearchCoverage | null;
sectionImages?: Record<string, string>;
setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void;
} }
const EnhancedOutlineEditor: React.FC<Props> = ({ const EnhancedOutlineEditor: React.FC<Props> = ({
@@ -21,14 +23,15 @@ const EnhancedOutlineEditor: React.FC<Props> = ({
sourceMappingStats, sourceMappingStats,
groundingInsights, groundingInsights,
optimizationResults, optimizationResults,
researchCoverage researchCoverage,
sectionImages = {},
setSectionImages
}) => { }) => {
const [editingSection, setEditingSection] = useState<string | null>(null); const [editingSection, setEditingSection] = useState<string | null>(null);
const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set()); const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set());
const [hoveredSection, setHoveredSection] = useState<string | null>(null); const [hoveredSection, setHoveredSection] = useState<string | null>(null);
const [showAddSection, setShowAddSection] = useState(false); const [showAddSection, setShowAddSection] = useState(false);
const [imageModalState, setImageModalState] = useState<{ open: boolean; sectionId?: string }>(() => ({ open: false })); const [imageModalState, setImageModalState] = useState<{ open: boolean; sectionId?: string }>(() => ({ open: false }));
const [sectionImages, setSectionImages] = useState<Record<string, string>>({});
const [newSectionData, setNewSectionData] = useState({ const [newSectionData, setNewSectionData] = useState({
heading: '', heading: '',
subheadings: '', subheadings: '',
@@ -117,8 +120,8 @@ const EnhancedOutlineEditor: React.FC<Props> = ({
}; };
})()} })()}
onImageGenerated={(imageBase64, sectionId) => { onImageGenerated={(imageBase64, sectionId) => {
if (sectionId) { if (sectionId && setSectionImages) {
setSectionImages(prev => ({ ...prev, [sectionId]: imageBase64 })); setSectionImages((prev: Record<string, string>) => ({ ...prev, [sectionId]: imageBase64 }));
} }
}} }}
/> />

View File

@@ -1,7 +1,10 @@
import React, { useState, useEffect } from 'react'; import React, { useState, useEffect } from 'react';
import { useCopilotAction } from '@copilotkit/react-core'; import { useCopilotAction } from '@copilotkit/react-core';
import { blogWriterApi, BlogSEOMetadataResponse } from '../../services/blogWriterApi'; import { BlogSEOMetadataResponse } from '../../services/blogWriterApi';
import { apiClient } from '../../api/client'; import { apiClient } from '../../api/client';
import { wordpressAPI, WordPressSite, WordPressPublishRequest } from '../../api/wordpress';
import { validateAndRefreshWixTokens } from '../../utils/wixTokenUtils';
import WixConnectModal from './BlogWriterUtils/WixConnectModal';
interface PublisherProps { interface PublisherProps {
buildFullMarkdown: () => string; buildFullMarkdown: () => string;
@@ -26,10 +29,15 @@ export const Publisher: React.FC<PublisherProps> = ({
}) => { }) => {
const [wixConnectionStatus, setWixConnectionStatus] = useState<WixConnectionStatus | null>(null); const [wixConnectionStatus, setWixConnectionStatus] = useState<WixConnectionStatus | null>(null);
const [checkingWixStatus, setCheckingWixStatus] = useState(false); const [checkingWixStatus, setCheckingWixStatus] = useState(false);
const [wordpressSites, setWordpressSites] = useState<WordPressSite[]>([]);
const [checkingWordPressStatus, setCheckingWordPressStatus] = useState(false);
const [showWixConnectModal, setShowWixConnectModal] = useState(false);
const [pendingWixPublish, setPendingWixPublish] = useState<(() => Promise<any>) | null>(null);
// Check Wix connection status on component mount // Check platform connection statuses on component mount
useEffect(() => { useEffect(() => {
checkWixConnectionStatus(); checkWixConnectionStatus();
checkWordPressConnectionStatus();
}, []); }, []);
const checkWixConnectionStatus = async () => { const checkWixConnectionStatus = async () => {
@@ -48,6 +56,137 @@ export const Publisher: React.FC<PublisherProps> = ({
setCheckingWixStatus(false); setCheckingWixStatus(false);
} }
}; };
const checkWordPressConnectionStatus = async () => {
setCheckingWordPressStatus(true);
try {
const status = await wordpressAPI.getStatus();
setWordpressSites(status.sites || []);
} catch (error) {
console.error('Failed to check WordPress connection status:', error);
setWordpressSites([]);
} finally {
setCheckingWordPressStatus(false);
}
};
// Helper function to publish to Wix
const publishToWix = async (md: string, metadata: BlogSEOMetadataResponse | null, accessToken?: string): Promise<any> => {
// Get access token if not provided
if (!accessToken) {
const tokenResult = await validateAndRefreshWixTokens();
if (!tokenResult.accessToken) {
return {
success: false,
message: 'Wix tokens not available. Please connect your Wix account.',
action_required: 'connect_wix'
};
}
accessToken = tokenResult.accessToken;
}
// Extract title from SEO metadata or markdown
const title = metadata?.seo_title || (() => {
const titleMatch = md.match(/^#\s+(.+)$/m);
return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity';
})();
// Extract cover image URL, skip if base64 (Wix needs HTTP URL)
let coverImageUrl: string | undefined = undefined;
if (metadata?.open_graph?.image) {
const imageUrl = metadata.open_graph.image;
// Skip base64 images - Wix import_image needs HTTP/HTTPS URL
if (typeof imageUrl === 'string' && (imageUrl.startsWith('http://') || imageUrl.startsWith('https://'))) {
coverImageUrl = imageUrl;
} else {
console.warn('Skipping cover image - Wix requires HTTP/HTTPS URL, received:', imageUrl?.substring(0, 50));
}
}
try {
// Publish using same endpoint as WixTestPage
// Note: Wix requires category/tag IDs (UUIDs), not names
// For now, skip categories/tags until we implement ID lookup/creation
const response = await apiClient.post('/api/wix/test/publish/real', {
title: title,
content: md, // Use markdown, backend converts it
cover_image_url: coverImageUrl,
// TODO: Lookup/create category IDs from metadata?.blog_categories
// TODO: Lookup/create tag IDs from metadata?.blog_tags
category_ids: undefined,
tag_ids: undefined,
publish: true,
access_token: accessToken,
member_id: undefined // Let backend derive from token
});
if (response.data.success) {
return {
success: true,
url: response.data.url,
post_id: response.data.post_id,
message: 'Blog post published successfully to Wix!'
};
} else {
return {
success: false,
message: response.data.error || 'Failed to publish to Wix'
};
}
} catch (error: any) {
// If auth error, token may be invalid - try refreshing or reconnect
if (error.response?.status === 401 || error.response?.status === 403) {
// Try to refresh one more time
const tokenResult = await validateAndRefreshWixTokens();
if (tokenResult.needsReconnect) {
const publishFunction = async () => {
return await publishToWix(md, metadata);
};
setPendingWixPublish(() => publishFunction);
setShowWixConnectModal(true);
return {
success: false,
message: 'Wix tokens expired. Please reconnect your Wix account.',
action_required: 'reconnect_wix'
};
}
// If refresh worked, retry once
if (tokenResult.accessToken) {
return await publishToWix(md, metadata, tokenResult.accessToken);
}
}
return {
success: false,
message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}`
};
}
};
// Handle Wix connection success - retry publish
const handleWixConnectionSuccess = async () => {
if (pendingWixPublish) {
const publishFn = pendingWixPublish;
setPendingWixPublish(null);
// Small delay to ensure tokens are saved in sessionStorage
setTimeout(async () => {
try {
// Retry the publish - this will be executed and return result
// Note: The result won't show in CopilotKit UI since we're outside the action handler
// But the publish will succeed and user will see their blog on Wix
const result = await publishFn();
console.log('Wix publish after connection:', result);
// Optionally show a success notification
if (result.success) {
// Publish succeeded - user's blog is now on Wix
console.log('Blog published to Wix successfully after connection');
}
} catch (error) {
console.error('Error retrying publish after connection:', error);
}
}, 500);
}
};
// Enhanced publish action with Wix support // Enhanced publish action with Wix support
useCopilotActionTyped({ useCopilotActionTyped({
name: 'publishToPlatform', name: 'publishToPlatform',
@@ -61,58 +200,101 @@ export const Publisher: React.FC<PublisherProps> = ({
const html = convertMarkdownToHTML(md); const html = convertMarkdownToHTML(md);
if (platform === 'wix') { if (platform === 'wix') {
// Check Wix connection status first // Proactively validate and refresh tokens
if (!wixConnectionStatus?.connected) { const tokenResult = await validateAndRefreshWixTokens();
return {
success: false, if (tokenResult.needsReconnect || !tokenResult.accessToken) {
message: 'Wix account not connected. Please connect your Wix account first using the Wix Test Page.', // Store the publish function to retry after connection
const publishFunction = async () => {
return await publishToWix(md, seoMetadata);
};
setPendingWixPublish(() => publishFunction);
setShowWixConnectModal(true);
return {
success: false,
message: 'Wix account not connected. Please connect your Wix account to publish.',
action_required: 'connect_wix' action_required: 'connect_wix'
}; };
} }
if (!wixConnectionStatus?.has_permissions) { // We have a valid access token, proceed with publishing
return await publishToWix(md, seoMetadata, tokenResult.accessToken);
} else if (platform === 'wordpress') {
// WordPress publishing
if (!seoMetadata) {
return { return {
success: false, success: false,
message: 'Insufficient Wix permissions. Please reconnect your Wix account.', message: 'Generate SEO metadata first. Use the "Next: Generate SEO Metadata" suggestion to create metadata before publishing.'
action_required: 'reconnect_wix'
}; };
} }
// Extract title from markdown (first heading or use default) // Check if user has connected WordPress sites
const titleMatch = md.match(/^#\s+(.+)$/m); if (wordpressSites.length === 0) {
const title = titleMatch ? titleMatch[1] : 'Blog Post from ALwrity'; return {
success: false,
message: 'No WordPress sites connected. Please connect a WordPress site first. Go to Settings > Integrations to add your WordPress site.',
action_required: 'connect_wordpress'
};
}
// Find first active site, or use first site if none are active
const activeSite = wordpressSites.find(site => site.is_active) || wordpressSites[0];
if (!activeSite) {
return {
success: false,
message: 'No active WordPress sites found. Please activate a WordPress site connection.',
action_required: 'activate_wordpress'
};
}
// Extract title from SEO metadata or markdown
const title = seoMetadata.seo_title || (() => {
const titleMatch = md.match(/^#\s+(.+)$/m);
return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity';
})();
// Extract excerpt from SEO metadata
const excerpt = seoMetadata.meta_description || '';
// Build WordPress publish request
const publishRequest: WordPressPublishRequest = {
site_id: activeSite.id,
title: title,
content: html,
excerpt: excerpt,
status: 'publish',
meta_description: seoMetadata.meta_description || excerpt,
tags: seoMetadata.blog_tags || [],
categories: seoMetadata.blog_categories || []
};
try { try {
const response = await apiClient.post('/api/wix/publish', { const result = await wordpressAPI.publishContent(publishRequest);
title: title,
content: md,
publish: true
});
if (response.data.success) { if (result.success) {
return { return {
success: true, success: true,
url: response.data.url, url: result.post_url || `${activeSite.site_url}/?p=${result.post_id}`,
post_id: response.data.post_id, post_id: result.post_id,
message: 'Blog post published successfully to Wix!' message: `Blog post published successfully to WordPress site "${activeSite.site_name}"!`
}; };
} else { } else {
return { return {
success: false, success: false,
message: response.data.error || 'Failed to publish to Wix' message: result.error || 'Failed to publish to WordPress'
}; };
} }
} catch (error: any) { } catch (error: any) {
return { return {
success: false, success: false,
message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}` message: `Failed to publish to WordPress: ${error.response?.data?.detail || error.message || 'Unknown error'}`
}; };
} }
} else { } else {
// WordPress or other platforms return {
if (!seoMetadata) return { success: false, message: 'Generate SEO metadata first' }; success: false,
const res = await blogWriterApi.publish({ platform, html, metadata: seoMetadata, schedule_time }); message: `Unsupported platform: ${platform}. Supported platforms are 'wix' and 'wordpress'.`
return { success: true, url: res.url }; };
} }
}, },
render: ({ status, result }: any) => { render: ({ status, result }: any) => {
@@ -153,6 +335,13 @@ export const Publisher: React.FC<PublisherProps> = ({
</a> </a>
</div> </div>
)} )}
{(result?.action_required === 'connect_wordpress' || result?.action_required === 'activate_wordpress') && (
<div style={{ marginTop: 8 }}>
<a href="/settings/integrations" target="_blank" rel="noopener noreferrer">
Manage WordPress Connections
</a>
</div>
)}
</div> </div>
); );
} }
@@ -161,7 +350,18 @@ export const Publisher: React.FC<PublisherProps> = ({
} }
}); });
return null; // This component only provides the copilot action return (
<>
<WixConnectModal
isOpen={showWixConnectModal}
onClose={() => {
setShowWixConnectModal(false);
setPendingWixPublish(null);
}}
onConnectionSuccess={handleWixConnectionSuccess}
/>
</>
);
}; };
export default Publisher; export default Publisher;

View File

@@ -145,11 +145,7 @@ export const useSuggestions = ({
priority: 'high' priority: 'high'
}); });
items.push({ items.push({
title: 'Content Analysis', title: '📊 Content Analysis',
message: 'Analyze the flow and quality of my blog content to get improvement suggestions'
});
items.push({
title: 'Content Analysis',
message: 'Analyze the flow and quality of my blog content to get improvement suggestions' message: 'Analyze the flow and quality of my blog content to get improvement suggestions'
}); });
} else if (seoAnalysis && !seoRecommendationsApplied) { } else if (seoAnalysis && !seoRecommendationsApplied) {
@@ -160,7 +156,7 @@ export const useSuggestions = ({
priority: 'high' priority: 'high'
}); });
items.push({ items.push({
title: 'Content Analysis', title: '📊 Content Analysis',
message: 'Run analyzeContentQuality to review narrative flow and get final improvement suggestions before publishing.' message: 'Run analyzeContentQuality to review narrative flow and get final improvement suggestions before publishing.'
}); });
items.push({ items.push({
@@ -175,33 +171,21 @@ export const useSuggestions = ({
message: 'SEO recommendations are applied. Execute generateSEOMetadata immediately so we can prepare titles, descriptions, and schema without further prompts.', message: 'SEO recommendations are applied. Execute generateSEOMetadata immediately so we can prepare titles, descriptions, and schema without further prompts.',
priority: 'high' priority: 'high'
}); });
} else {
items.push({ items.push({
title: 'Next: Publish', title: '📊 Content Analysis',
message: 'The blog is SEO-optimized. Use publishToPlatform with your preferred destination (wix|wordpress) right away—no additional confirmation needed.', message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.'
priority: 'high'
}); });
} } else {
// SEO metadata is ready - show publishing options
items.push({
title: 'Content Analysis',
message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.'
});
items.push({
title: 'Publish',
message: seoMetadata
? 'Publish my blog to your preferred platform using publishToPlatform.'
: 'Generate SEO metadata first, then publish your blog.'
});
if (seoMetadata) {
items.push({ items.push({
title: '🚀 Publish to Wix', title: '🚀 Publish to Wix',
message: 'Publish my blog to Wix using publishToPlatform with platform "wix".' message: 'Publish my blog to Wix using publishToPlatform with platform "wix".',
priority: 'high'
}); });
items.push({ items.push({
title: '🌐 Publish to WordPress', title: '🌐 Publish to WordPress',
message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".' message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".',
priority: 'high'
}); });
} }
} }

View File

@@ -30,6 +30,7 @@ interface BlogEditorProps {
onSave?: (content: any) => void; onSave?: (content: any) => void;
continuityRefresh?: number; continuityRefresh?: number;
flowAnalysisResults?: any; flowAnalysisResults?: any;
sectionImages?: Record<string, string>;
} }
const BlogEditor: React.FC<BlogEditorProps> = ({ const BlogEditor: React.FC<BlogEditorProps> = ({
@@ -43,7 +44,8 @@ const BlogEditor: React.FC<BlogEditorProps> = ({
onContentUpdate, onContentUpdate,
onSave, onSave,
continuityRefresh, continuityRefresh,
flowAnalysisResults flowAnalysisResults,
sectionImages = {}
}) => { }) => {
const [blogTitle, setBlogTitle] = useState(initialTitle || 'Your Amazing Blog Title'); const [blogTitle, setBlogTitle] = useState(initialTitle || 'Your Amazing Blog Title');
const [sections, setSections] = useState<any[]>([]); const [sections, setSections] = useState<any[]>([]);
@@ -143,17 +145,25 @@ const BlogEditor: React.FC<BlogEditorProps> = ({
<Divider sx={{ mt: 3, opacity: 0.3 }} /> <Divider sx={{ mt: 3, opacity: 0.3 }} />
</div> </div>
<div> <div>
{sections.map((section) => ( {sections.map((section, index) => {
<BlogSection // Robust image mapping: prefer outline index id (order is consistent across phases)
key={section.id} const imageIdByIndex = outline[index]?.id;
{...section} const outlineSection = outline.find(s => (s.id === section.id) || (s.heading === section.title));
onContentUpdate={onContentUpdate} const imageId = imageIdByIndex || outlineSection?.id || section.id;
expandedSections={expandedSections} const sectionImage = sectionImages?.[imageId] || null;
toggleSectionExpansion={toggleSectionExpansion} return (
refreshToken={continuityRefresh} <BlogSection
flowAnalysisResults={flowAnalysisResults} key={section.id}
/> {...section}
))} onContentUpdate={onContentUpdate}
expandedSections={expandedSections}
toggleSectionExpansion={toggleSectionExpansion}
refreshToken={continuityRefresh}
flowAnalysisResults={flowAnalysisResults}
sectionImage={sectionImage}
/>
);
})}
</div> </div>
</Paper> </Paper>
</div> </div>

View File

@@ -40,6 +40,7 @@ interface BlogSectionProps {
toggleSectionExpansion: (sectionId: any) => void; toggleSectionExpansion: (sectionId: any) => void;
refreshToken?: number; refreshToken?: number;
flowAnalysisResults?: any; flowAnalysisResults?: any;
sectionImage?: string;
} }
const BlogSection: React.FC<BlogSectionProps> = ({ const BlogSection: React.FC<BlogSectionProps> = ({
@@ -53,7 +54,8 @@ const BlogSection: React.FC<BlogSectionProps> = ({
expandedSections, expandedSections,
toggleSectionExpansion, toggleSectionExpansion,
refreshToken, refreshToken,
flowAnalysisResults flowAnalysisResults,
sectionImage
}) => { }) => {
const [isEditing, setIsEditing] = useState(false); const [isEditing, setIsEditing] = useState(false);
const [sectionTitle, setSectionTitle] = useState(title); const [sectionTitle, setSectionTitle] = useState(title);
@@ -181,6 +183,31 @@ const BlogSection: React.FC<BlogSectionProps> = ({
)} )}
</div> </div>
{/* Section Image Display */}
{sectionImage && (
<div style={{ marginBottom: '16px', marginTop: '8px' }}>
<div style={{
border: '1px solid #e0e0e0',
borderRadius: '8px',
overflow: 'hidden',
maxWidth: '100%',
backgroundColor: '#fff'
}}>
<img
src={`data:image/png;base64,${sectionImage}`}
alt={`Cover image for ${sectionTitle}`}
style={{
width: '100%',
height: 'auto',
display: 'block',
maxHeight: '400px',
objectFit: 'contain'
}}
/>
</div>
</div>
)}
<div <div
className="relative" className="relative"

View File

@@ -119,25 +119,44 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
const fetchDetailedStats = async () => { const fetchDetailedStats = async () => {
try { try {
const response = await apiClient.get('/api/content-planning/monitoring/api-stats'); const response = await apiClient.get('/api/content-planning/monitoring/api-stats');
const result = response.data; const result = response?.data;
if (result.status === 'success') {
setDetailedStats(result.data); // Validate response structure
if (result.data?.cache_performance) { if (!result || result.status !== 'success' || !result.data) {
setCachePerf(result.data.cache_performance); console.warn('Invalid response structure from api-stats endpoint:', result);
setChartData([]);
return;
}
const data = result.data;
setDetailedStats(data);
if (data?.cache_performance) {
setCachePerf(data.cache_performance);
} }
// Generate chart data // Generate chart data - safely handle missing top_endpoints
const chartData = result.data.top_endpoints.slice(0, 5).map((endpoint: any, index: number) => ({ if (data?.top_endpoints && Array.isArray(data.top_endpoints) && data.top_endpoints.length > 0) {
name: endpoint.endpoint.split(' ')[1].split('/').pop() || 'API', try {
requests: endpoint.count, const chartData = data.top_endpoints.slice(0, 5).map((endpoint: any) => ({
avgTime: endpoint.avg_time, name: endpoint?.endpoint?.split(' ')[1]?.split('/').pop() || 'API',
errors: endpoint.errors, requests: endpoint?.count || 0,
hitRate: endpoint.cache_hit_rate avgTime: endpoint?.avg_time || 0,
errors: endpoint?.errors || 0,
hitRate: endpoint?.cache_hit_rate || 0
})); }));
setChartData(chartData); setChartData(chartData);
} catch (mapError) {
console.error('Error mapping chart data:', mapError);
setChartData([]);
}
} else {
// If top_endpoints is missing or not an array, set empty chart data
setChartData([]);
} }
} catch (err) { } catch (err) {
console.error('Error fetching detailed stats:', err); console.error('Error fetching detailed stats:', err);
setChartData([]);
} }
}; };
@@ -353,7 +372,7 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
)} )}
{/* Recent Errors Section */} {/* Recent Errors Section */}
{detailedStats?.recent_errors && detailedStats.recent_errors.length > 0 && ( {detailedStats?.recent_errors && Array.isArray(detailedStats.recent_errors) && detailedStats.recent_errors.length > 0 && (
<motion.div <motion.div
initial={{ opacity: 0, y: 20 }} initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }} animate={{ opacity: 1, y: 0 }}
@@ -395,6 +414,8 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
> >
Close Close
</Button> </Button>
<Tooltip title={loading ? "Refreshing data..." : "Refresh monitoring data"}>
<span>
<Button <Button
onClick={fetchDetailedStats} onClick={fetchDetailedStats}
variant="contained" variant="contained"
@@ -403,6 +424,8 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
> >
Refresh Data Refresh Data
</Button> </Button>
</span>
</Tooltip>
</DialogActions> </DialogActions>
</Dialog> </Dialog>
</> </>

View File

@@ -56,18 +56,10 @@ export interface PromptSuggestion {
} }
export async function fetchPromptSuggestions(payload: any): Promise<PromptSuggestion[]> { export async function fetchPromptSuggestions(payload: any): Promise<PromptSuggestion[]> {
const res = await fetch('/api/images/suggest-prompts', { // Use apiClient directly (same pattern as SEO analysis in SEOAnalysisModal.tsx)
method: 'POST', // The apiClient interceptor will handle auth token injection automatically
headers: { 'Content-Type': 'application/json' }, const response = await apiClient.post('/api/images/suggest-prompts', payload);
credentials: 'include', return response.data.suggestions || [];
body: JSON.stringify(payload)
});
if (!res.ok) {
const text = await res.text();
throw new Error(text || 'Failed to fetch prompt suggestions');
}
const data = await res.json();
return data.suggestions || [];
} }

View File

@@ -28,6 +28,7 @@ import {
Modal, Modal,
Fade, Fade,
Backdrop, Backdrop,
Snackbar,
} from '@mui/material'; } from '@mui/material';
import { import {
Check as CheckIcon, Check as CheckIcon,
@@ -35,6 +36,7 @@ import {
Star as StarIcon, Star as StarIcon,
WorkspacePremium as PremiumIcon, WorkspacePremium as PremiumIcon,
Info as InfoIcon, Info as InfoIcon,
Warning,
Psychology, Psychology,
Search, Search,
FactCheck, FactCheck,
@@ -83,6 +85,7 @@ const PricingPage: React.FC = () => {
const [subscribing, setSubscribing] = useState(false); const [subscribing, setSubscribing] = useState(false);
const [paymentModalOpen, setPaymentModalOpen] = useState(false); const [paymentModalOpen, setPaymentModalOpen] = useState(false);
const [showSignInPrompt, setShowSignInPrompt] = useState(false); const [showSignInPrompt, setShowSignInPrompt] = useState(false);
const [successSnackbar, setSuccessSnackbar] = useState({ open: false, message: '', countdown: 3 });
const [knowMoreModal, setKnowMoreModal] = useState<{ open: boolean; title: string; content: React.ReactNode }>({ const [knowMoreModal, setKnowMoreModal] = useState<{ open: boolean; title: string; content: React.ReactNode }>({
open: false, open: false,
title: '', title: '',
@@ -172,27 +175,70 @@ const PricingPage: React.FC = () => {
setSubscribing(true); setSubscribing(true);
const userId = localStorage.getItem('user_id') || 'anonymous'; const userId = localStorage.getItem('user_id') || 'anonymous';
await apiClient.post(`/api/subscription/subscribe/${userId}`, { const response = await apiClient.post(`/api/subscription/subscribe/${userId}`, {
plan_id: selectedPlan, plan_id: selectedPlan,
billing_cycle: yearlyBilling ? 'yearly' : 'monthly' billing_cycle: yearlyBilling ? 'yearly' : 'monthly'
}); });
// Refresh subscription status console.log('Subscription renewed successfully:', response.data);
// Refresh subscription status immediately
window.dispatchEvent(new CustomEvent('subscription-updated')); window.dispatchEvent(new CustomEvent('subscription-updated'));
// Also trigger user authenticated event to refresh subscription context
window.dispatchEvent(new CustomEvent('user-authenticated'));
setPaymentModalOpen(false); setPaymentModalOpen(false);
// After subscription, check if onboarding is complete // Get plan name for success message
// If not complete, redirect to onboarding; otherwise to dashboard const planName = plans.find(p => p.id === selectedPlan)?.name || 'subscription';
const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true';
if (onboardingComplete) { // Show success message with countdown
navigate('/dashboard'); setSuccessSnackbar({
} else { open: true,
navigate('/onboarding'); message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in 3 seconds...`,
} countdown: 3
});
// Countdown timer
let countdown = 3;
const countdownInterval = setInterval(() => {
countdown -= 1;
if (countdown > 0) {
setSuccessSnackbar(prev => ({
...prev,
message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in ${countdown} second${countdown !== 1 ? 's' : ''}...`,
countdown
}));
} else {
clearInterval(countdownInterval);
}
}, 1000);
// Auto-redirect after 3 seconds
setTimeout(() => {
clearInterval(countdownInterval);
// After subscription, check if onboarding is complete
// If not complete, redirect to onboarding; otherwise to dashboard
const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true';
if (onboardingComplete) {
// Try to go back to where the user was (e.g., blog writer)
// If no history, go to dashboard
const referrer = sessionStorage.getItem('subscription_referrer');
if (referrer && referrer !== '/pricing') {
navigate(referrer);
} else {
navigate('/dashboard');
}
} else {
navigate('/onboarding');
}
}, 3000);
} catch (err) { } catch (err) {
console.error('Error subscribing:', err); console.error('Error subscribing:', err);
setError('Failed to process subscription'); setError('Failed to process subscription');
setSuccessSnackbar({ open: false, message: '', countdown: 0 });
} finally { } finally {
setSubscribing(false); setSubscribing(false);
} }
@@ -900,32 +946,71 @@ const PricingPage: React.FC = () => {
top: '50%', top: '50%',
left: '50%', left: '50%',
transform: 'translate(-50%, -50%)', transform: 'translate(-50%, -50%)',
width: 400, width: 450,
bgcolor: 'background.paper', bgcolor: 'background.paper',
border: '2px solid #000', border: '2px solid #000',
boxShadow: 24, boxShadow: 24,
p: 4, p: 4,
borderRadius: 2, borderRadius: 2,
}}> }}>
<Typography variant="h6" component="h2" gutterBottom> <Typography variant="h6" component="h2" gutterBottom sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
<Warning sx={{ color: 'warning.main' }} />
Alpha Testing Subscription Alpha Testing Subscription
</Typography> </Typography>
<Typography variant="body1" sx={{ mb: 3 }}>
Thank you for participating in our alpha testing! For the Basic plan, we're crediting $29 to your account. {/* Alpha Testing Notice */}
</Typography> <Alert severity="warning" sx={{ mb: 2 }}>
<Typography variant="body2" color="text.secondary" sx={{ mb: 3 }}> <Typography variant="body2" sx={{ fontWeight: 600, mb: 0.5 }}>
In production, this would integrate with Stripe/Paddle for real payment processing. ⚠️ Alpha Testing Mode - No Payment Required
</Typography>
<Typography variant="caption" sx={{ display: 'block' }}>
Payment integration is coming soon. For now, subscriptions are activated without charge.
</Typography>
</Alert>
<Typography variant="body1" sx={{ mb: 2 }}>
Thank you for participating in our alpha testing! We're crediting the Basic plan ($29 value) to your account.
</Typography> </Typography>
{/* TODO: Payment Integration Notice */}
<Box sx={{
p: 2,
mb: 3,
bgcolor: 'info.lighter',
borderRadius: 1,
border: '1px solid',
borderColor: 'info.light'
}}>
<Typography variant="body2" color="info.dark">
<strong>Coming in Production:</strong>
</Typography>
<Typography variant="caption" color="info.dark" sx={{ display: 'block', mt: 0.5 }}>
Secure Stripe/PayPal payment processing<br />
Automatic renewal management<br />
Payment verification & receipts<br />
Upgrade/downgrade options
</Typography>
</Box>
{/* Note: Current behavior allows renewal without payment verification */}
{/* This is intentional for alpha testing but will be secured in production */}
<Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 2 }}> <Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 2 }}>
<Button onClick={() => setPaymentModalOpen(false)}> <Button onClick={() => setPaymentModalOpen(false)} variant="outlined">
Cancel Cancel
</Button> </Button>
<Button <Button
variant="contained" variant="contained"
onClick={handlePaymentConfirm} onClick={handlePaymentConfirm}
disabled={subscribing} disabled={subscribing}
sx={{
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
'&:hover': {
background: 'linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%)',
}
}}
> >
{subscribing ? <CircularProgress size={20} /> : 'Confirm Subscription'} {subscribing ? <CircularProgress size={20} sx={{ color: 'white' }} /> : 'Confirm Subscription'}
</Button> </Button>
</Box> </Box>
</Box> </Box>
@@ -981,6 +1066,37 @@ const PricingPage: React.FC = () => {
</Button> </Button>
</DialogActions> </DialogActions>
</Dialog> </Dialog>
{/* Success Snackbar */}
<Snackbar
open={successSnackbar.open}
autoHideDuration={3000}
onClose={() => setSuccessSnackbar({ open: false, message: '', countdown: 0 })}
anchorOrigin={{ vertical: 'top', horizontal: 'center' }}
sx={{
top: { xs: 16, sm: 24 },
'& .MuiSnackbarContent-root': {
minWidth: { xs: '90vw', sm: '500px' }
}
}}
>
<Alert
severity="success"
variant="filled"
onClose={() => setSuccessSnackbar({ open: false, message: '', countdown: 0 })}
sx={{
width: '100%',
fontSize: '1rem',
alignItems: 'center',
boxShadow: '0 8px 24px rgba(76, 175, 80, 0.4)',
'& .MuiAlert-icon': {
fontSize: '2rem'
}
}}
>
{successSnackbar.message}
</Alert>
</Snackbar>
</Container> </Container>
); );
}; };

View File

@@ -39,6 +39,25 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
subscriptionData, subscriptionData,
errorData errorData
}) => { }) => {
// Debug logging to verify modal state
React.useEffect(() => {
if (open) {
console.log('SubscriptionExpiredModal: Modal opened', {
open,
errorData,
hasUsageInfo: !!errorData?.usage_info
});
}
}, [open, errorData]);
const handleDialogClose = (_event: object, reason?: string) => {
if (reason === 'backdropClick') {
console.log('SubscriptionExpiredModal: Ignoring backdrop click close');
return;
}
onClose();
};
const handleRenewClick = () => { const handleRenewClick = () => {
onRenewSubscription(); onRenewSubscription();
onClose(); onClose();
@@ -47,16 +66,21 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
return ( return (
<Dialog <Dialog
open={open} open={open}
onClose={onClose} onClose={handleDialogClose}
maxWidth="sm" maxWidth="sm"
fullWidth fullWidth
disableEscapeKeyDown
PaperProps={{ PaperProps={{
sx: { sx: {
borderRadius: 3, borderRadius: 3,
background: 'linear-gradient(135deg, #fff 0%, #f8fafc 100%)', background: 'linear-gradient(135deg, #fff 0%, #f8fafc 100%)',
boxShadow: '0 25px 50px -12px rgba(0, 0, 0, 0.25)', boxShadow: '0 25px 50px -12px rgba(0, 0, 0, 0.25)',
zIndex: 9999, // Ensure modal appears above everything
} }
}} }}
sx={{
zIndex: 9999, // Ensure modal backdrop appears above everything
}}
> >
<DialogTitle sx={{ textAlign: 'center', pb: 1 }}> <DialogTitle sx={{ textAlign: 'center', pb: 1 }}>
<Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', gap: 2 }}> <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', gap: 2 }}>
@@ -93,56 +117,156 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
borderRadius: 2 borderRadius: 2
}} }}
> >
<Typography variant="body1" sx={{ mb: 2, color: 'text.secondary' }}> {/* Main error message */}
<Typography variant="body1" sx={{ mb: 2, color: 'text.secondary', lineHeight: 1.6 }}>
{errorData?.message || (errorData?.usage_info {errorData?.message || (errorData?.usage_info
? 'You\'ve reached your monthly usage limit for this plan. Upgrade your plan to get higher limits.' ? 'You\'ve reached your monthly usage limit for this plan. Upgrade your plan to get higher limits.'
: 'To continue using Alwrity and access all features, you need to renew your subscription.' : 'To continue using Alwrity and access all features, you need to renew your subscription.'
)} )}
</Typography> </Typography>
{/* Detailed usage information */}
{errorData?.usage_info && ( {errorData?.usage_info && (
<Box sx={{ mb: 2, p: 2, background: 'rgba(255,255,255,0.7)', borderRadius: 1 }}> <Box sx={{ mb: 2, p: 2.5, background: 'rgba(255,255,255,0.9)', borderRadius: 2, border: '1px solid #e2e8f0' }}>
<Typography variant="body2" sx={{ fontWeight: 600, mb: 1, color: 'text.primary' }}> <Typography variant="subtitle2" sx={{ fontWeight: 700, mb: 2, color: 'text.primary', display: 'flex', alignItems: 'center', gap: 1 }}>
<Warning sx={{ fontSize: 18, color: 'warning.main' }} />
Usage Information: Usage Information:
</Typography> </Typography>
{errorData.usage_info.call_usage_percentage && (
<Typography variant="body2" sx={{ color: 'text.secondary' }}> {/* Provider and operation type */}
You've used {errorData.usage_info.call_usage_percentage.toFixed(1)}% of your monthly limit <Box sx={{ display: 'flex', gap: 2, mb: 2, flexWrap: 'wrap' }}>
</Typography> {errorData.provider && (
<Box sx={{
flex: '1 1 auto',
px: 2,
py: 1.5,
background: 'linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%)',
borderRadius: 1.5,
border: '1px solid #a5b4fc'
}}>
<Typography variant="caption" sx={{ color: '#4338ca', fontWeight: 600, display: 'block', mb: 0.5 }}>
Provider:
</Typography>
<Typography variant="body2" sx={{ color: '#312e81', fontWeight: 700 }}>
{errorData.provider}
</Typography>
</Box>
)}
{errorData.usage_info.operation_type && (
<Box sx={{
flex: '1 1 auto',
px: 2,
py: 1.5,
background: 'linear-gradient(135deg, #fef3c7 0%, #fde68a 100%)',
borderRadius: 1.5,
border: '1px solid #fbbf24'
}}>
<Typography variant="caption" sx={{ color: '#92400e', fontWeight: 600, display: 'block', mb: 0.5 }}>
Operation:
</Typography>
<Typography variant="body2" sx={{ color: '#78350f', fontWeight: 700, textTransform: 'capitalize' }}>
{errorData.usage_info.operation_type.replace(/_/g, ' ')}
</Typography>
</Box>
)}
</Box>
{/* Token usage details (if available) */}
{(errorData.usage_info.current_tokens !== undefined || errorData.usage_info.current_calls !== undefined) && (
<Box sx={{
p: 2,
background: 'linear-gradient(135deg, #fee2e2 0%, #fecaca 100%)',
borderRadius: 1.5,
border: '1px solid #f87171',
mb: 2
}}>
{errorData.usage_info.current_tokens !== undefined && (
<>
<Typography variant="body2" sx={{ color: '#7f1d1d', fontWeight: 600, mb: 1 }}>
Token Usage:
</Typography>
<Box sx={{ display: 'flex', alignItems: 'baseline', gap: 1, mb: 0.5 }}>
<Typography variant="h6" sx={{ color: '#991b1b', fontWeight: 700 }}>
{errorData.usage_info.current_tokens?.toLocaleString() || 0}
</Typography>
<Typography variant="body2" sx={{ color: '#7f1d1d' }}>
/ {errorData.usage_info.limit?.toLocaleString() || 0}
</Typography>
<Typography variant="caption" sx={{ color: '#7f1d1d', ml: 'auto' }}>
({((errorData.usage_info.current_tokens / errorData.usage_info.limit) * 100).toFixed(1)}% used)
</Typography>
</Box>
{errorData.usage_info.requested_tokens && (
<Typography variant="caption" sx={{ color: '#7f1d1d', display: 'block', mt: 1 }}>
Requested: {errorData.usage_info.requested_tokens.toLocaleString()} tokens
{errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens > errorData.usage_info.limit && (
<span style={{ fontWeight: 700, marginLeft: 4 }}>
(Would exceed by: {((errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens) - errorData.usage_info.limit).toLocaleString()} tokens)
</span>
)}
</Typography>
)}
</>
)}
{errorData.usage_info.current_calls !== undefined && (
<>
<Typography variant="body2" sx={{ color: '#7f1d1d', fontWeight: 600, mb: 1, mt: errorData.usage_info.current_tokens !== undefined ? 2 : 0 }}>
API Call Usage:
</Typography>
<Box sx={{ display: 'flex', alignItems: 'baseline', gap: 1 }}>
<Typography variant="h6" sx={{ color: '#991b1b', fontWeight: 700 }}>
{errorData.usage_info.current_calls?.toLocaleString() || 0}
</Typography>
<Typography variant="body2" sx={{ color: '#7f1d1d' }}>
/ {errorData.usage_info.call_limit?.toLocaleString() || 0}
</Typography>
<Typography variant="caption" sx={{ color: '#7f1d1d', ml: 'auto' }}>
({((errorData.usage_info.current_calls / errorData.usage_info.call_limit) * 100).toFixed(1)}% used)
</Typography>
</Box>
</>
)}
</Box>
)} )}
{errorData.provider && (
<Typography variant="body2" sx={{ color: 'text.secondary' }}> {/* Error type badge */}
Provider: {errorData.provider} {errorData.usage_info.error_type && (
</Typography> <Box sx={{ display: 'flex', justifyContent: 'center' }}>
<Box sx={{
px: 2,
py: 0.5,
background: '#dc2626',
borderRadius: 1,
display: 'inline-block'
}}>
<Typography variant="caption" sx={{ color: 'white', fontWeight: 700, textTransform: 'uppercase', letterSpacing: 0.5 }}>
{errorData.usage_info.error_type.replace(/_/g, ' ')}
</Typography>
</Box>
</Box>
)} )}
</Box> </Box>
)} )}
{/* Current plan information */}
{subscriptionData && ( {subscriptionData && (
<Box sx={{ display: 'flex', justifyContent: 'center', gap: 2, flexWrap: 'wrap' }}> <Box sx={{ display: 'flex', justifyContent: 'center', gap: 2, flexWrap: 'wrap' }}>
{subscriptionData.plan && ( {subscriptionData.plan && (
<Box sx={{ <Box sx={{
px: 2, px: 3,
py: 1, py: 1.5,
background: 'rgba(255,255,255,0.7)', background: 'rgba(255,255,255,0.9)',
borderRadius: 1, borderRadius: 1.5,
border: '1px solid #e2e8f0' border: '2px solid #e2e8f0'
}}> }}>
<Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 500 }}> <Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 600, display: 'block', mb: 0.5 }}>
Current Plan: {subscriptionData.plan} Current Plan:
</Typography> </Typography>
</Box> <Typography variant="body2" sx={{ color: 'text.primary', fontWeight: 700, textTransform: 'capitalize' }}>
)} {subscriptionData.plan}
{subscriptionData.tier && subscriptionData.tier !== subscriptionData.plan && (
<Box sx={{
px: 2,
py: 1,
background: 'rgba(255,255,255,0.7)',
borderRadius: 1,
border: '1px solid #e2e8f0'
}}>
<Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 500 }}>
Tier: {subscriptionData.tier}
</Typography> </Typography>
</Box> </Box>
)} )}

View File

@@ -105,12 +105,13 @@ const DashboardHeader: React.FC<DashboardHeaderProps> = ({
/* Enhanced Start Button with Phase 1 Improvements */ /* Enhanced Start Button with Phase 1 Improvements */
<Box sx={{ position: 'relative', display: 'inline-flex' }}> <Box sx={{ position: 'relative', display: 'inline-flex' }}>
<Tooltip title={tooltipMessage} arrow placement="bottom"> <Tooltip title={tooltipMessage} arrow placement="bottom">
<Button <span>
variant="contained" <Button
size={isFirstVisit ? "medium" : "small"} variant="contained"
startIcon={<PlayArrow />} size={isFirstVisit ? "medium" : "small"}
onClick={workflowControls.onStartWorkflow} startIcon={<PlayArrow />}
disabled={workflowControls.isLoading} onClick={workflowControls.onStartWorkflow}
disabled={workflowControls.isLoading}
sx={{ sx={{
position: 'relative', position: 'relative',
overflow: 'hidden', overflow: 'hidden',
@@ -180,8 +181,9 @@ const DashboardHeader: React.FC<DashboardHeaderProps> = ({
}, },
}} }}
> >
{isFirstVisit ? '🚀 Start Journey' : 'Start'} {isFirstVisit ? '🚀 Start Journey' : 'Start'}
</Button> </Button>
</span>
</Tooltip> </Tooltip>
<Box <Box
sx={{ sx={{

View File

@@ -1,4 +1,4 @@
import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback } from 'react'; import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback, useRef } from 'react';
import { apiClient, setGlobalSubscriptionErrorHandler } from '../api/client'; import { apiClient, setGlobalSubscriptionErrorHandler } from '../api/client';
import SubscriptionExpiredModal from '../components/SubscriptionExpiredModal'; import SubscriptionExpiredModal from '../components/SubscriptionExpiredModal';
@@ -60,6 +60,8 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
// New: Grace window after plan changes to avoid noisy UX // New: Grace window after plan changes to avoid noisy UX
const [graceUntil, setGraceUntil] = useState<number>(0); const [graceUntil, setGraceUntil] = useState<number>(0);
const [planSignature, setPlanSignature] = useState<string>(""); const [planSignature, setPlanSignature] = useState<string>("");
// Flag to track if current modal is a usage limit modal (should never be auto-closed)
const [isUsageLimitModal, setIsUsageLimitModal] = useState<boolean>(false);
const checkSubscription = useCallback(async () => { const checkSubscription = useCallback(async () => {
// Throttle subscription checks to prevent excessive API calls // Throttle subscription checks to prevent excessive API calls
@@ -86,6 +88,10 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
return; return;
} }
// Wait a moment to ensure auth token getter is installed
// This prevents 401 errors during app initialization
await new Promise(resolve => setTimeout(resolve, 200));
console.log('SubscriptionContext: Checking subscription for user:', userId); console.log('SubscriptionContext: Checking subscription for user:', userId);
const response = await apiClient.get(`/api/subscription/status/${userId}`); const response = await apiClient.get(`/api/subscription/status/${userId}`);
const subscriptionData = response.data.data; const subscriptionData = response.data.data;
@@ -101,29 +107,42 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
setPlanSignature(newSignature); setPlanSignature(newSignature);
setGraceUntil(Date.now() + 5 * 60 * 1000); setGraceUntil(Date.now() + 5 * 60 * 1000);
// Close any existing modal as plan just changed // Close any existing modal as plan just changed
if (showModal) { // BUT: Don't close usage limit modals - they're important even after plan changes
if (showModal && !isUsageLimitModal) {
console.log('SubscriptionContext: Plan changed, closing non-usage-limit modal');
setShowModal(false); setShowModal(false);
setModalErrorData(null); setModalErrorData(null);
} else if (showModal && isUsageLimitModal) {
console.log('SubscriptionContext: Plan changed but usage limit modal is open, keeping it open');
} }
} }
} catch (_e) {} } catch (_e) {}
// If we have a valid subscription and the modal is open, close it // If we have a valid subscription and the modal is open, close it
// BUT: NEVER close usage limit modals - user needs to see they hit a limit even with active subscription
if (subscriptionData && subscriptionData.active && showModal) { if (subscriptionData && subscriptionData.active && showModal) {
console.log('SubscriptionContext: Valid subscription detected, closing modal'); // Check if this is a usage limit modal (using flag or checking error data)
setShowModal(false); const hasUsageInfo = modalErrorData?.usage_info ||
setModalErrorData(null); (modalErrorData?.current_tokens !== undefined) ||
setLastModalShowTime(0); // Reset the cooldown timer (modalErrorData?.current_calls !== undefined) ||
} (modalErrorData?.limit !== undefined) ||
(modalErrorData?.requested_tokens !== undefined);
// Also check if this is a usage limit error that should be suppressed
if (subscriptionData && subscriptionData.active && modalErrorData) { const isUsageLimit = isUsageLimitModal || hasUsageInfo;
const now = Date.now();
const timeSinceLastModal = now - lastModalShowTime; if (isUsageLimit) {
console.log('SubscriptionContext: Usage limit modal detected - KEEPING OPEN (never auto-close usage limit modals)', {
// If it's been less than 10 minutes since modal was shown for usage limits, keep it closed isUsageLimitModal,
if (timeSinceLastModal < 600000 && modalErrorData.usage_info) { hasUsageInfo,
console.log('SubscriptionContext: Recent usage limit modal, keeping it closed'); modalErrorDataKeys: modalErrorData ? Object.keys(modalErrorData) : []
});
// Do NOT close - usage limit modals should stay open until user dismisses them
} else {
console.log('SubscriptionContext: Non-usage-limit modal detected, closing since subscription is active');
setShowModal(false);
setModalErrorData(null);
setIsUsageLimitModal(false);
setLastModalShowTime(0); // Reset the cooldown timer
} }
} }
@@ -156,7 +175,7 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
setLastModalShowTime(now); setLastModalShowTime(now);
} }
} }
} catch (err) { } catch (err: any) {
console.error('Error checking subscription:', err); console.error('Error checking subscription:', err);
// Check if it's a connection error that should be handled at the app level // Check if it's a connection error that should be handled at the app level
@@ -165,6 +184,16 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
throw err; throw err;
} }
// Handle 401 errors gracefully during initialization - don't block routing
// 401 might happen if auth token getter isn't ready yet
if (err?.response?.status === 401) {
console.warn('Subscription check failed with 401 - auth may not be ready yet, will retry later');
setError(null); // Don't set error for 401 during init
setLoading(false);
// Don't throw - allow routing to proceed, subscription check will retry later
return;
}
setError(err instanceof Error ? err.message : 'Failed to check subscription'); setError(err instanceof Error ? err.message : 'Failed to check subscription');
// Don't default to free tier on error - preserve existing subscription or leave null // Don't default to free tier on error - preserve existing subscription or leave null
@@ -173,21 +202,30 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
} finally { } finally {
setLoading(false); setLoading(false);
} }
}, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil]); }, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil, isUsageLimitModal]);
const refreshSubscription = useCallback(async () => { const refreshSubscription = useCallback(async () => {
await checkSubscription(); await checkSubscription();
}, [checkSubscription]); }, [checkSubscription]);
const showExpiredModal = useCallback(() => { const showExpiredModal = useCallback(() => {
setIsUsageLimitModal(false);
setShowModal(true); setShowModal(true);
}, []); }, []);
const hideExpiredModal = useCallback(() => { const hideExpiredModal = useCallback(() => {
console.log('SubscriptionExpiredModal: User manually closed modal');
setShowModal(false); setShowModal(false);
setIsUsageLimitModal(false); // Reset flag when user closes modal
setModalErrorData(null);
}, []); }, []);
const handleRenewSubscription = useCallback(() => { const handleRenewSubscription = useCallback(() => {
// Save current location so we can return after renewal
const currentPath = window.location.pathname;
sessionStorage.setItem('subscription_referrer', currentPath);
console.log('SubscriptionContext: Navigating to pricing page, saved referrer:', currentPath);
window.location.href = '/pricing'; window.location.href = '/pricing';
}, []); }, []);
@@ -203,42 +241,131 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
const now = Date.now(); const now = Date.now();
// If we have subscription data and it's active, always suppress modal for usage limits // Check if this is a usage limit error (status 429) vs subscription expired (402)
if (subscription && subscription.active) { let errorData = error.response?.data || {};
console.log('SubscriptionContext: Active subscription; suppressing usage-limit modal');
return true; // Do not show modal for active plan usage limits // DEBUG: Log the raw error data structure
console.log('SubscriptionContext: Raw error data', {
type: typeof errorData,
isArray: Array.isArray(errorData),
data: errorData,
stringified: JSON.stringify(errorData)
});
// If errorData is an array, extract the first element (common FastAPI response format)
if (Array.isArray(errorData)) {
console.log('SubscriptionContext: errorData is array, extracting first element');
errorData = errorData[0] || {};
} }
// If we don't have subscription data yet, defer the decision // Check for usage_info in various possible locations
if (!subscription) { const usageInfo = errorData.usage_info ||
console.log('SubscriptionContext: No subscription data yet, deferring modal decision'); (errorData.current_calls !== undefined ? errorData : null) ||
setDeferredError(error); null;
return true; // Handle the error but don't show modal yet
} // Usage limit error: 429 status with usage info OR 429 status without explicit expiration
const isUsageLimitError = status === 429 && (usageInfo || errorData.provider || errorData.message);
// If subscription is not active, show modal immediately const isSubscriptionExpired = status === 402 || (status === 429 && !isUsageLimitError);
if (!subscription.active) {
console.log('SubscriptionContext: Inactive subscription, showing modal immediately'); console.log('SubscriptionContext: Error analysis', {
const errorData = error.response?.data || {}; status,
setModalErrorData({ isUsageLimitError,
provider: errorData.provider, isSubscriptionExpired,
usage_info: errorData.usage_info, hasUsageInfo: !!usageInfo,
message: errorData.message || errorData.error errorDataType: typeof errorData,
errorDataKeys: typeof errorData === 'object' && !Array.isArray(errorData) ? Object.keys(errorData) : 'not-an-object',
errorData: errorData
});
// For usage limit errors (429 with usage_info), always show modal - even for active subscriptions
// Ignore grace window and cooldown for usage limit errors (user needs to know immediately)
if (isUsageLimitError) {
const modalData = {
provider: errorData.provider || usageInfo?.provider || 'unknown',
usage_info: usageInfo || errorData,
message: errorData.message || errorData.error || 'You have reached your usage limit.'
};
console.log('SubscriptionContext: Usage limit exceeded, showing modal (ignoring grace window/cooldown)', {
modalData,
errorData: Object.keys(errorData),
usageInfo: usageInfo ? Object.keys(usageInfo) : null
}); });
// Set flag to mark this as a usage limit modal (should never be auto-closed)
setIsUsageLimitModal(true);
setModalErrorData(modalData);
setShowModal(true); setShowModal(true);
setLastModalShowTime(now); setLastModalShowTime(now);
console.log('SubscriptionContext: Modal state updated - showModal should be true, isUsageLimitModal = true');
return true; return true;
} }
// For subscription expired errors, handle based on subscription status
if (isSubscriptionExpired) {
// If we have subscription data and it's active, this shouldn't happen but suppress anyway
if (subscription && subscription.active) {
console.log('SubscriptionContext: Active subscription but got expired error, suppressing modal');
return true;
}
// If we don't have subscription data yet, defer the decision
if (!subscription) {
console.log('SubscriptionContext: No subscription data yet, deferring modal decision');
setDeferredError(error);
return true; // Handle the error but don't show modal yet
}
// If subscription is not active, show modal immediately
if (!subscription.active) {
console.log('SubscriptionContext: Inactive subscription, showing modal immediately');
setIsUsageLimitModal(false);
setModalErrorData({
provider: errorData.provider,
usage_info: errorData.usage_info,
message: errorData.message || errorData.error
});
setShowModal(true);
setLastModalShowTime(now);
return true;
}
}
} }
return false; // Not a subscription error return false; // Not a subscription error
}, [subscription]); }, [subscription]);
// Register the global error handler with the API client // Register the global error handler with the API client
// Use a ref to ensure the latest handler is always used
const handlerRef = useRef(globalSubscriptionErrorHandler);
useEffect(() => {
handlerRef.current = globalSubscriptionErrorHandler;
}, [globalSubscriptionErrorHandler]);
useEffect(() => { useEffect(() => {
console.log('SubscriptionContext: Registering global subscription error handler'); console.log('SubscriptionContext: Registering global subscription error handler');
setGlobalSubscriptionErrorHandler(globalSubscriptionErrorHandler); setGlobalSubscriptionErrorHandler((error: any) => {
}, [globalSubscriptionErrorHandler]); // Always use the latest handler from ref
return handlerRef.current(error);
});
// Cleanup: Don't remove the handler on unmount - it should persist
// This ensures errors can still be caught even during component transitions
}, []); // Empty deps - only register once, but handler ref updates automatically
useEffect(() => {
const eventHandler = (event: Event) => {
const customEvent = event as CustomEvent;
console.log('SubscriptionContext: Received subscription-error event fallback', customEvent.detail);
handlerRef.current(customEvent.detail);
};
window.addEventListener('subscription-error', eventHandler as EventListener);
return () => {
window.removeEventListener('subscription-error', eventHandler as EventListener);
};
}, []);
useEffect(() => { useEffect(() => {
// Check subscription on mount // Check subscription on mount

View File

@@ -33,6 +33,9 @@ export const useBlogWriterState = () => {
// Content confirmation state // Content confirmation state
const [contentConfirmed, setContentConfirmed] = useState<boolean>(false); const [contentConfirmed, setContentConfirmed] = useState<boolean>(false);
// Section images state - persists images generated in outline phase to content phase
const [sectionImages, setSectionImages] = useState<Record<string, string>>({});
// Cache recovery - restore most recent research on page load // Cache recovery - restore most recent research on page load
useEffect(() => { useEffect(() => {
const cachedEntries = researchCache.getAllCachedEntries(); const cachedEntries = researchCache.getAllCachedEntries();
@@ -211,6 +214,7 @@ export const useBlogWriterState = () => {
contentConfirmed, contentConfirmed,
flowAnalysisCompleted, flowAnalysisCompleted,
flowAnalysisResults, flowAnalysisResults,
sectionImages,
// Setters // Setters
setResearch, setResearch,
@@ -233,6 +237,7 @@ export const useBlogWriterState = () => {
setContentConfirmed, setContentConfirmed,
setFlowAnalysisCompleted, setFlowAnalysisCompleted,
setFlowAnalysisResults, setFlowAnalysisResults,
setSectionImages,
// Handlers // Handlers
handleResearchComplete, handleResearchComplete,

View File

@@ -1,5 +1,6 @@
import { useState, useEffect, useCallback, useRef } from 'react'; import { useState, useEffect, useCallback, useRef } from 'react';
import { blogWriterApi, TaskStatusResponse } from '../services/blogWriterApi'; import { blogWriterApi, TaskStatusResponse } from '../services/blogWriterApi';
import { triggerSubscriptionError } from '../api/client';
export interface UsePollingOptions { export interface UsePollingOptions {
interval?: number; // Polling interval in milliseconds interval?: number; // Polling interval in milliseconds
@@ -108,6 +109,43 @@ export function usePolling(
console.log('❌ Task failed - stopping polling immediately'); console.log('❌ Task failed - stopping polling immediately');
setError(status.error || 'Task failed'); setError(status.error || 'Task failed');
onError?.(status.error || 'Task failed'); onError?.(status.error || 'Task failed');
// Check if this is a subscription error and trigger modal
if (status.error_status === 429 || status.error_status === 402) {
console.log('usePolling: Detected subscription error in task status', {
error_status: status.error_status,
error_data: status.error_data,
error: status.error
});
// Create a mock error object with the subscription error data
const errorData = status.error_data || {};
// Ensure usage_info is properly nested - it might be at the top level or nested
const usageInfo = errorData.usage_info ||
(errorData.current_calls !== undefined ? errorData : null) ||
errorData;
const mockError = {
response: {
status: status.error_status,
data: {
error: errorData.error || status.error || 'Subscription limit exceeded',
message: errorData.message || errorData.error || status.error || 'You have reached your usage limit.',
provider: errorData.provider || usageInfo?.provider || 'unknown',
usage_info: usageInfo
}
}
};
console.log('usePolling: Triggering subscription error handler with:', mockError);
const handled = triggerSubscriptionError(mockError);
if (!handled) {
console.warn('usePolling: Subscription error handler did not handle the error');
}
}
stopPolling(); stopPolling();
return; // Exit early to prevent further processing return; // Exit early to prevent further processing
} }
@@ -117,6 +155,38 @@ export function usePolling(
const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred'; const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred';
console.error('Polling error:', errorMessage); console.error('Polling error:', errorMessage);
// Check if this is an axios error with subscription limit status
// This is a fallback in case the interceptor doesn't catch it
const axiosError = err as any;
if (axiosError?.response?.status === 429 || axiosError?.response?.status === 402) {
console.log('usePolling: Detected subscription error in axios error response', {
status: axiosError.response.status,
data: axiosError.response.data
});
// Trigger subscription error handler (modal will show)
const handled = triggerSubscriptionError(axiosError);
console.log('usePolling: triggerSubscriptionError returned', handled);
if (handled) {
console.log('usePolling: Subscription error handled, stopping polling');
const errorMsg = axiosError.response?.data?.message ||
axiosError.response?.data?.error ||
'Subscription limit exceeded';
setError(errorMsg);
onError?.(errorMsg);
stopPolling();
return; // Exit early - don't continue processing
} else {
console.warn('usePolling: Subscription error not handled by global handler, dispatching fallback event');
try {
window.dispatchEvent(new CustomEvent('subscription-error', { detail: axiosError }));
} catch (eventError) {
console.error('usePolling: Failed to dispatch subscription-error event', eventError);
}
}
}
// Stop polling for task failures and rate limiting // Stop polling for task failures and rate limiting
if (errorMessage.includes('404') || errorMessage.includes('Task not found')) { if (errorMessage.includes('404') || errorMessage.includes('Task not found')) {
setError('Task not found - it may have expired or been cleaned up'); setError('Task not found - it may have expired or been cleaned up');

View File

@@ -219,9 +219,22 @@ export interface BlogSEOMetadataResponse {
success: boolean; success: boolean;
title_options: string[]; title_options: string[];
meta_descriptions: string[]; meta_descriptions: string[];
seo_title?: string;
meta_description?: string;
url_slug?: string;
blog_tags: string[];
blog_categories: string[];
social_hashtags: string[];
open_graph: Record<string, any>; open_graph: Record<string, any>;
twitter_card: Record<string, any>; twitter_card: Record<string, any>;
schema: Record<string, any>; json_ld_schema?: Record<string, any>;
schema?: Record<string, any>; // Legacy field name
canonical_url?: string;
reading_time?: number;
focus_keyword?: string;
generated_at?: string;
optimization_score?: number;
error?: string;
} }
export interface BlogPublishResponse { export interface BlogPublishResponse {
@@ -241,6 +254,26 @@ export interface TaskStatusResponse {
}>; }>;
result?: BlogResearchResponse; result?: BlogResearchResponse;
error?: string; error?: string;
// Subscription error details (set by backend when subscription limit is exceeded)
error_status?: number; // HTTP status code (429 for usage limit, 402 for subscription expired)
error_data?: {
error?: string;
message?: string;
provider?: string;
usage_info?: {
provider?: string;
current_calls?: number;
limit?: number;
type?: string;
breakdown?: {
gemini?: number;
openai?: number;
anthropic?: number;
mistral?: number;
};
};
[key: string]: any; // Allow additional fields
};
} }
export const blogWriterApi = { export const blogWriterApi = {

View File

@@ -0,0 +1,198 @@
/**
* Wix Token Utilities
* Functions for validating and refreshing Wix OAuth tokens
*/
import { apiClient } from '../api/client';
interface WixTokens {
accessToken?: {
value: string;
expiresAt?: string;
};
refreshToken?: {
value: string;
};
access_token?: string;
refresh_token?: string;
expires_in?: number;
}
interface TokenValidationResult {
valid: boolean;
accessToken: string | null;
needsRefresh: boolean;
needsReconnect: boolean;
}
/**
* Get Wix tokens from sessionStorage
*/
export function getWixTokens(): WixTokens | null {
try {
const tokensRaw = sessionStorage.getItem('wix_tokens');
if (!tokensRaw) return null;
return JSON.parse(tokensRaw);
} catch (error) {
console.error('Error parsing Wix tokens:', error);
return null;
}
}
/**
* Extract access token from token structure
*/
export function extractAccessToken(tokens: WixTokens | null): string | null {
if (!tokens) return null;
return tokens.accessToken?.value || tokens.access_token || null;
}
/**
* Extract refresh token from token structure
*/
export function extractRefreshToken(tokens: WixTokens | null): string | null {
if (!tokens) return null;
return tokens.refreshToken?.value || tokens.refresh_token || null;
}
/**
* Refresh Wix access token using refresh token
*/
export async function refreshWixToken(refreshToken: string): Promise<WixTokens | null> {
try {
const response = await apiClient.post('/api/wix/refresh-token', {
refresh_token: refreshToken
});
if (response.data.success) {
// Create new token structure matching Wix SDK format
const newTokens: WixTokens = {
accessToken: {
value: response.data.access_token
},
refreshToken: {
value: response.data.refresh_token || refreshToken // Keep old refresh token if new one not provided
},
access_token: response.data.access_token,
refresh_token: response.data.refresh_token || refreshToken
};
// Update sessionStorage
try {
sessionStorage.setItem('wix_tokens', JSON.stringify(newTokens));
sessionStorage.setItem('wix_connected', 'true');
} catch (e) {
console.error('Error saving refreshed tokens:', e);
}
return newTokens;
}
return null;
} catch (error: any) {
console.error('Error refreshing Wix token:', error);
return null;
}
}
/**
* Check if token is expired based on expiresAt timestamp
*/
function isTokenExpired(tokens: WixTokens): boolean {
if (tokens.accessToken?.expiresAt) {
try {
const expiresAt = new Date(tokens.accessToken.expiresAt);
return expiresAt < new Date();
} catch (e) {
// If we can't parse, assume not expired (will validate during publish)
return false;
}
}
// If no expiration info, we can't tell - assume valid for now
// Real validation happens during actual API call
return false;
}
/**
* Validate and refresh Wix tokens proactively
* Returns access token if valid, or null if needs reconnection
*
* Strategy:
* 1. Check if tokens exist
* 2. Check if token is expired (if expiration info available)
* 3. If expired, attempt refresh
* 4. If refresh fails or no refresh token, needs reconnection
* 5. Real validation happens during actual publish (we catch 401/403 errors)
*/
export async function validateAndRefreshWixTokens(): Promise<TokenValidationResult> {
const tokens = getWixTokens();
if (!tokens) {
return {
valid: false,
accessToken: null,
needsRefresh: false,
needsReconnect: true
};
}
const accessToken = extractAccessToken(tokens);
const refreshToken = extractRefreshToken(tokens);
if (!accessToken) {
return {
valid: false,
accessToken: null,
needsRefresh: false,
needsReconnect: true
};
}
// Check if token is expired (if we have expiration info)
const expired = isTokenExpired(tokens);
if (!expired) {
// Token appears valid (not expired or no expiration info)
// We'll do real validation during publish
return {
valid: true,
accessToken: accessToken,
needsRefresh: false,
needsReconnect: false
};
}
// Token is expired, try to refresh
if (!refreshToken) {
return {
valid: false,
accessToken: null,
needsRefresh: false,
needsReconnect: true
};
}
// Attempt to refresh token
const refreshedTokens = await refreshWixToken(refreshToken);
if (refreshedTokens) {
const newAccessToken = extractAccessToken(refreshedTokens);
if (newAccessToken) {
return {
valid: true,
accessToken: newAccessToken,
needsRefresh: true,
needsReconnect: false
};
}
}
// Refresh failed, needs reconnection
return {
valid: false,
accessToken: null,
needsRefresh: false,
needsReconnect: true
};
}