Subscription dashboard improvements, AI text generation limit, and other fixes.
This commit is contained in:
@@ -5,10 +5,11 @@ Main router for blog writing operations including research, outline generation,
|
||||
content creation, SEO analysis, and publishing.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -64,10 +65,21 @@ class SEOApplyRecommendationsRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/seo/apply-recommendations")
|
||||
async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]:
|
||||
async def apply_seo_recommendations(
|
||||
request: SEOApplyRecommendationsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply actionable SEO recommendations and return updated content."""
|
||||
try:
|
||||
result = await recommendation_applier.apply_recommendations(request.dict())
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
result = await recommendation_applier.apply_recommendations(request.dict(), user_id=user_id)
|
||||
if not result.get("success"):
|
||||
raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations"))
|
||||
return result
|
||||
@@ -87,13 +99,24 @@ async def health() -> Dict[str, Any]:
|
||||
|
||||
# Research Endpoints
|
||||
@router.post("/research/start")
|
||||
async def start_research(request: BlogResearchRequest) -> Dict[str, Any]:
|
||||
async def start_research(
|
||||
request: BlogResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Start a research operation and return a task ID for polling."""
|
||||
try:
|
||||
# TODO: Get user_id from authentication context
|
||||
user_id = "anonymous" # This should come from auth middleware
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
task_id = await task_manager.start_research_task(request, user_id)
|
||||
return {"task_id": task_id, "status": "started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start research: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -107,6 +130,50 @@ async def get_research_status(task_id: str) -> Dict[str, Any]:
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
error_status = status.get('error_status', 429)
|
||||
|
||||
if not isinstance(error_data, dict):
|
||||
logger.warning(f"Research task {task_id} error_data not dict: {error_data}")
|
||||
error_data = {'error': str(error_data)}
|
||||
|
||||
# Determine provider and usage info
|
||||
stored_error_message = status.get('error', error_data.get('error'))
|
||||
provider = error_data.get('provider', 'unknown')
|
||||
usage_info = error_data.get('usage_info')
|
||||
|
||||
if not usage_info:
|
||||
usage_info = {
|
||||
'provider': provider,
|
||||
'message': stored_error_message,
|
||||
'error_type': error_data.get('error_type', 'unknown')
|
||||
}
|
||||
# Include any known fields from error_data
|
||||
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
|
||||
if key in error_data:
|
||||
usage_info[key] = error_data[key]
|
||||
|
||||
# Build error message for detail
|
||||
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
|
||||
|
||||
# Log the subscription error with all context
|
||||
logger.warning(f"Research task {task_id} failed with subscription error {error_status}: {error_msg}")
|
||||
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
|
||||
|
||||
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=error_status,
|
||||
content={
|
||||
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
|
||||
'message': error_msg,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Research status request for {task_id}: {status['status']} with {len(status.get('progress_messages', []))} progress messages")
|
||||
return status
|
||||
except HTTPException:
|
||||
@@ -310,20 +377,46 @@ async def hallucination_check(request: HallucinationCheckRequest) -> Hallucinati
|
||||
|
||||
# SEO Endpoints
|
||||
@router.post("/seo/analyze", response_model=BlogSEOAnalyzeResponse)
|
||||
async def seo_analyze(request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse:
|
||||
async def seo_analyze(
|
||||
request: BlogSEOAnalyzeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> BlogSEOAnalyzeResponse:
|
||||
"""Analyze content for SEO optimization opportunities."""
|
||||
try:
|
||||
return await service.seo_analyze(request)
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
return await service.seo_analyze(request, user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform SEO analysis: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/seo/metadata", response_model=BlogSEOMetadataResponse)
|
||||
async def seo_metadata(request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse:
|
||||
async def seo_metadata(
|
||||
request: BlogSEOMetadataRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> BlogSEOMetadataResponse:
|
||||
"""Generate SEO metadata for the blog post."""
|
||||
try:
|
||||
return await service.seo_metadata(request)
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
return await service.seo_metadata(request, user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate SEO metadata: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -85,6 +86,10 @@ class TaskManager:
|
||||
response["result"] = task["result"]
|
||||
elif task["status"] == "failed":
|
||||
response["error"] = task["error"]
|
||||
if "error_status" in task:
|
||||
response["error_status"] = task["error_status"]
|
||||
if "error_data" in task:
|
||||
response["error_data"] = task["error_data"]
|
||||
|
||||
return response
|
||||
|
||||
@@ -109,14 +114,17 @@ class TaskManager:
|
||||
|
||||
logger.info(f"Progress update for task {task_id}: {message}")
|
||||
|
||||
async def start_research_task(self, request: BlogResearchRequest, user_id: str = "anonymous") -> str:
|
||||
async def start_research_task(self, request: BlogResearchRequest, user_id: str) -> str:
|
||||
"""Start a research operation and return a task ID."""
|
||||
if self.use_database:
|
||||
return await self.db_manager.start_research_task(request, user_id)
|
||||
else:
|
||||
task_id = self.create_task("research")
|
||||
# Store user_id in task for subscription checks
|
||||
if task_id in self.task_storage:
|
||||
self.task_storage[task_id]["user_id"] = user_id
|
||||
# Start the research operation in the background
|
||||
asyncio.create_task(self._run_research_task(task_id, request))
|
||||
asyncio.create_task(self._run_research_task(task_id, request, user_id))
|
||||
return task_id
|
||||
|
||||
def start_outline_task(self, request: BlogOutlineRequest) -> str:
|
||||
@@ -144,7 +152,7 @@ class TaskManager:
|
||||
asyncio.create_task(self._run_medium_generation_task(task_id, request))
|
||||
return task_id
|
||||
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest):
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str):
|
||||
"""Background task to run research and update status with progress messages."""
|
||||
try:
|
||||
# Update status to running
|
||||
@@ -157,8 +165,8 @@ class TaskManager:
|
||||
# Check cache first
|
||||
await self.update_progress(task_id, "📋 Checking cache for existing research...")
|
||||
|
||||
# Run the actual research with progress updates
|
||||
result = await self.service.research_with_progress(request, task_id)
|
||||
# Run the actual research with progress updates (pass user_id for subscription checks)
|
||||
result = await self.service.research_with_progress(request, task_id, user_id)
|
||||
|
||||
# Check if research failed gracefully
|
||||
if not result.success:
|
||||
@@ -171,6 +179,16 @@ class TaskManager:
|
||||
self.task_storage[task_id]["status"] = "completed"
|
||||
self.task_storage[task_id]["result"] = result.dict()
|
||||
|
||||
except HTTPException as http_error:
|
||||
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
|
||||
error_detail = http_error.detail
|
||||
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
|
||||
await self.update_progress(task_id, f"❌ {error_message}")
|
||||
self.task_storage[task_id]["status"] = "failed"
|
||||
self.task_storage[task_id]["error"] = error_message
|
||||
# Store HTTP error details for frontend modal
|
||||
self.task_storage[task_id]["error_status"] = http_error.status_code
|
||||
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
|
||||
except Exception as e:
|
||||
await self.update_progress(task_id, f"❌ Research failed with error: {str(e)}")
|
||||
# Update status to failed
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Body
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -64,6 +64,15 @@ async def activate_strategy_with_monitoring(
|
||||
if not monitoring_success:
|
||||
logger.warning(f"Failed to save monitoring data for strategy {strategy_id}")
|
||||
|
||||
# Trigger scheduler interval adjustment (scheduler will check more frequently now)
|
||||
try:
|
||||
from services.scheduler import get_scheduler
|
||||
scheduler = get_scheduler()
|
||||
await scheduler.trigger_interval_adjustment()
|
||||
logger.info(f"Triggered scheduler interval adjustment after strategy {strategy_id} activation")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not trigger scheduler interval adjustment: {e}")
|
||||
|
||||
logger.info(f"Successfully activated strategy {strategy_id} with monitoring")
|
||||
return {
|
||||
"success": True,
|
||||
@@ -396,6 +405,150 @@ async def get_monitoring_tasks(
|
||||
logger.error(f"Error retrieving monitoring tasks: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/user/{user_id}/monitoring-tasks")
|
||||
async def get_user_monitoring_tasks(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
status: Optional[str] = Query(None, description="Filter by task status"),
|
||||
limit: int = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: int = Query(0, description="Number of tasks to skip")
|
||||
):
|
||||
"""
|
||||
Get all monitoring tasks for a specific user with their execution status.
|
||||
|
||||
Uses the scheduler's task loader to get tasks filtered by user_id for proper user isolation.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Getting monitoring tasks for user {user_id}")
|
||||
|
||||
# Use scheduler task loader for user-specific tasks
|
||||
from services.scheduler.utils.task_loader import load_due_monitoring_tasks
|
||||
|
||||
# Load all tasks for user (not just due tasks - we want all user tasks)
|
||||
# Join with strategy to filter by user
|
||||
tasks_query = db.query(MonitoringTask).join(
|
||||
EnhancedContentStrategy,
|
||||
MonitoringTask.strategy_id == EnhancedContentStrategy.id
|
||||
).filter(
|
||||
EnhancedContentStrategy.user_id == user_id
|
||||
)
|
||||
|
||||
# Apply status filter if provided
|
||||
if status:
|
||||
tasks_query = tasks_query.filter(MonitoringTask.status == status)
|
||||
|
||||
# Get tasks with pagination
|
||||
tasks = tasks_query.order_by(desc(MonitoringTask.created_at)).offset(offset).limit(limit).all()
|
||||
|
||||
tasks_data = []
|
||||
for task in tasks:
|
||||
# Get latest execution log
|
||||
latest_log = db.query(TaskExecutionLog).filter(
|
||||
TaskExecutionLog.task_id == task.id
|
||||
).order_by(desc(TaskExecutionLog.execution_date)).first()
|
||||
|
||||
# Get strategy info
|
||||
strategy = db.query(EnhancedContentStrategy).filter(
|
||||
EnhancedContentStrategy.id == task.strategy_id
|
||||
).first()
|
||||
|
||||
task_data = {
|
||||
"id": task.id,
|
||||
"strategy_id": task.strategy_id,
|
||||
"strategy_name": strategy.name if strategy else None,
|
||||
"title": task.task_title,
|
||||
"description": task.task_description,
|
||||
"assignee": task.assignee,
|
||||
"frequency": task.frequency,
|
||||
"metric": task.metric,
|
||||
"measurementMethod": task.measurement_method,
|
||||
"successCriteria": task.success_criteria,
|
||||
"alertThreshold": task.alert_threshold,
|
||||
"status": task.status,
|
||||
"lastExecuted": latest_log.execution_date.isoformat() if latest_log else None,
|
||||
"nextExecution": task.next_execution.isoformat() if task.next_execution else None,
|
||||
"executionCount": db.query(TaskExecutionLog).filter(
|
||||
TaskExecutionLog.task_id == task.id
|
||||
).count(),
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None
|
||||
}
|
||||
tasks_data.append(task_data)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = db.query(MonitoringTask).join(
|
||||
EnhancedContentStrategy,
|
||||
MonitoringTask.strategy_id == EnhancedContentStrategy.id
|
||||
).filter(
|
||||
EnhancedContentStrategy.user_id == user_id
|
||||
)
|
||||
if status:
|
||||
total_count = total_count.filter(MonitoringTask.status == status)
|
||||
total_count = total_count.count()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": tasks_data,
|
||||
"pagination": {
|
||||
"total": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + len(tasks_data)) < total_count
|
||||
},
|
||||
"message": f"Retrieved {len(tasks_data)} monitoring tasks for user {user_id}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving user monitoring tasks: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retrieve monitoring tasks: {str(e)}")
|
||||
|
||||
@router.get("/user/{user_id}/execution-logs")
|
||||
async def get_user_execution_logs(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
status: Optional[str] = Query(None, description="Filter by execution status"),
|
||||
limit: int = Query(50, description="Maximum number of logs to return"),
|
||||
offset: int = Query(0, description="Number of logs to skip")
|
||||
):
|
||||
"""
|
||||
Get execution logs for a specific user.
|
||||
|
||||
Provides user isolation by filtering execution logs by user_id.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Getting execution logs for user {user_id}")
|
||||
|
||||
monitoring_service = MonitoringDataService(db)
|
||||
logs_data = monitoring_service.get_user_execution_logs(
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
status_filter=status
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
count_query = db.query(TaskExecutionLog).filter(
|
||||
TaskExecutionLog.user_id == user_id
|
||||
)
|
||||
if status:
|
||||
count_query = count_query.filter(TaskExecutionLog.status == status)
|
||||
total_count = count_query.count()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": logs_data,
|
||||
"pagination": {
|
||||
"total": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + len(logs_data)) < total_count
|
||||
},
|
||||
"message": f"Retrieved {len(logs_data)} execution logs for user {user_id}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving execution logs for user {user_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retrieve execution logs: {str(e)}")
|
||||
|
||||
@router.get("/{strategy_id}/data-freshness")
|
||||
async def get_data_freshness(
|
||||
strategy_id: int,
|
||||
|
||||
@@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
@@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
def generate(
|
||||
req: ImageGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImageGenerateResponse:
|
||||
"""Generate image with subscription checking."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Validation is now handled inside generate_image function
|
||||
last_error: Optional[Exception] = None
|
||||
result = None
|
||||
for attempt in range(2): # simple single retry
|
||||
try:
|
||||
result = generate_image(
|
||||
@@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
"steps": req.steps,
|
||||
"seed": req.seed,
|
||||
},
|
||||
user_id=user_id, # Pass user_id for validation inside generate_image
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# TRACK USAGE after successful image generation
|
||||
if result:
|
||||
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Get or create usage summary
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (stability for image generation)
|
||||
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, "stability_calls", new_calls)
|
||||
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: stability
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
width=result.width,
|
||||
@@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
|
||||
def suggest_prompts(
|
||||
req: ImagePromptSuggestRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImagePromptSuggestResponse:
|
||||
try:
|
||||
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
|
||||
section = req.section or {}
|
||||
@@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons
|
||||
If including on-image text, return it in overlay_text (short: <= 8 words).
|
||||
"""
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
|
||||
# Get user_id for llm_text_gen subscription check (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id_for_llm = str(current_user.get('id', ''))
|
||||
if not user_id_for_llm:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
|
||||
data = raw if isinstance(raw, dict) else {}
|
||||
suggestions = data.get("suggestions") or []
|
||||
# basic fallback if provider returns string
|
||||
|
||||
@@ -94,6 +94,7 @@ async def get_subscription_plans(
|
||||
"description": plan.description,
|
||||
"features": plan.features or [],
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
@@ -162,6 +163,7 @@ async def get_user_subscription(
|
||||
},
|
||||
"status": "free",
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": free_plan.gemini_calls_limit,
|
||||
"openai_calls": free_plan.openai_calls_limit,
|
||||
"anthropic_calls": free_plan.anthropic_calls_limit,
|
||||
@@ -200,6 +202,7 @@ async def get_user_subscription(
|
||||
"is_free": False
|
||||
},
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
@@ -252,6 +255,7 @@ async def get_subscription_status(
|
||||
"tier": "free",
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": free_plan.gemini_calls_limit,
|
||||
"openai_calls": free_plan.openai_calls_limit,
|
||||
"anthropic_calls": free_plan.anthropic_calls_limit,
|
||||
@@ -309,6 +313,7 @@ async def get_subscription_status(
|
||||
"tier": subscription.plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
@@ -331,9 +336,14 @@ async def get_subscription_status(
|
||||
async def subscribe_to_plan(
|
||||
user_id: str,
|
||||
subscription_data: dict,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create or update a user's subscription."""
|
||||
"""Create or update a user's subscription (renewal)."""
|
||||
|
||||
# Verify user can only subscribe/renew their own subscription
|
||||
if current_user.get('id') != user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
try:
|
||||
plan_id = subscription_data.get('plan_id')
|
||||
@@ -388,12 +398,75 @@ async def subscribe_to_plan(
|
||||
|
||||
db.commit()
|
||||
|
||||
# Get current usage BEFORE reset for logging
|
||||
current_period = datetime.utcnow().strftime("%Y-%m")
|
||||
usage_before = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# Log renewal request details
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})")
|
||||
logger.info(f" ├─ Billing Cycle: {billing_cycle}")
|
||||
logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
if usage_before:
|
||||
logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):")
|
||||
logger.info(f" ├─ Gemini: {usage_before.gemini_tokens or 0} tokens / {usage_before.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_before.mistral_tokens or 0} tokens / {usage_before.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_before.openai_tokens or 0} tokens / {usage_before.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_before.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_before.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_before.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_before.usage_status.value}")
|
||||
else:
|
||||
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
|
||||
|
||||
# Clear subscription limits cache to force refresh on next check
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
# Clear cache for this specific user (class-level cache shared across all instances)
|
||||
cleared_count = PricingService.clear_user_cache(user_id)
|
||||
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
|
||||
except Exception as cache_err:
|
||||
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
|
||||
|
||||
# Reset usage status for current billing period so new plan takes effect immediately
|
||||
reset_result = None
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
await usage_service.reset_current_billing_period(user_id)
|
||||
reset_result = await usage_service.reset_current_billing_period(user_id)
|
||||
|
||||
# Re-query usage summary from DB after reset to get fresh data
|
||||
usage_after = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if reset_result.get('reset'):
|
||||
logger.info(f" ✅ Usage counters RESET successfully")
|
||||
if usage_after:
|
||||
logger.info(f" 📊 New Usage AFTER Reset:")
|
||||
logger.info(f" ├─ Gemini: {usage_after.gemini_tokens or 0} tokens / {usage_after.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_after.mistral_tokens or 0} tokens / {usage_after.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_after.openai_tokens or 0} tokens / {usage_after.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_after.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_after.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_after.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_after.usage_status.value}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Reset returned: {reset_result.get('reason', 'unknown')}")
|
||||
except Exception as reset_err:
|
||||
logger.error(f"Failed to reset usage after subscribe: {reset_err}")
|
||||
logger.error(f" ❌ Failed to reset usage after subscribe: {reset_err}", exc_info=True)
|
||||
|
||||
logger.info(f" ✅ Renewal completed: User {user_id} → {plan.name} ({billing_cycle})")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -404,7 +477,20 @@ async def subscribe_to_plan(
|
||||
"billing_cycle": billing_cycle,
|
||||
"current_period_start": subscription.current_period_start.isoformat(),
|
||||
"current_period_end": subscription.current_period_end.isoformat(),
|
||||
"status": subscription.status.value
|
||||
"status": subscription.status.value,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
"mistral_calls": plan.mistral_calls_limit,
|
||||
"tavily_calls": plan.tavily_calls_limit,
|
||||
"serper_calls": plan.serper_calls_limit,
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -477,6 +477,39 @@ async def test_publish_to_wix(request: WixPublishRequest) -> Dict[str, Any]:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/refresh-token")
|
||||
async def refresh_wix_token(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Refresh Wix access token using refresh token
|
||||
|
||||
Args:
|
||||
request: Dict containing refresh_token
|
||||
|
||||
Returns:
|
||||
New token information with access_token, refresh_token, expires_in
|
||||
"""
|
||||
try:
|
||||
refresh_token = request.get("refresh_token")
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=400, detail="Missing refresh_token")
|
||||
|
||||
# Refresh the token
|
||||
new_tokens = wix_service.refresh_access_token(refresh_token)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"access_token": new_tokens.get("access_token"),
|
||||
"refresh_token": new_tokens.get("refresh_token"),
|
||||
"expires_in": new_tokens.get("expires_in"),
|
||||
"token_type": new_tokens.get("token_type", "Bearer")
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh Wix token: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to refresh token: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/test/publish/real")
|
||||
async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -298,6 +298,11 @@ async def startup_event():
|
||||
try:
|
||||
# Initialize database
|
||||
init_database()
|
||||
|
||||
# Start task scheduler
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
|
||||
logger.info("ALwrity backend started successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
@@ -307,6 +312,10 @@ async def startup_event():
|
||||
async def shutdown_event():
|
||||
"""Cleanup on shutdown."""
|
||||
try:
|
||||
# Stop task scheduler
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().stop()
|
||||
|
||||
# Close database connections
|
||||
close_database()
|
||||
logger.info("ALwrity backend shutdown successfully")
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- Migration: Add user_id column to task_execution_logs for user isolation
|
||||
-- Date: 2025-01-XX
|
||||
-- Purpose: Enable user isolation tracking in scheduler task execution logs
|
||||
|
||||
-- Add user_id column (nullable for backward compatibility with existing records)
|
||||
ALTER TABLE task_execution_logs
|
||||
ADD COLUMN user_id INTEGER NULL;
|
||||
|
||||
-- Create index for efficient user filtering and queries
|
||||
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_id
|
||||
ON task_execution_logs(user_id);
|
||||
|
||||
-- Create composite index for common query patterns (user_id + status + execution_date)
|
||||
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_status_date
|
||||
ON task_execution_logs(user_id, status, execution_date);
|
||||
|
||||
-- Note: Backfilling existing records would require joining with monitoring_tasks
|
||||
-- and enhanced_content_strategies tables. This can be done in a separate migration
|
||||
-- or during a maintenance window. For now, existing records will have user_id = NULL.
|
||||
|
||||
@@ -65,7 +65,7 @@ class LinkedInPostRequest(BaseModel):
|
||||
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "AI in healthcare transformation",
|
||||
"industry": "Healthcare",
|
||||
@@ -102,7 +102,7 @@ class LinkedInArticleRequest(BaseModel):
|
||||
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Digital transformation in manufacturing",
|
||||
"industry": "Manufacturing",
|
||||
@@ -135,7 +135,7 @@ class LinkedInCarouselRequest(BaseModel):
|
||||
include_citations: bool = Field(default=True, description="Whether to include inline citations")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Future of remote work",
|
||||
"industry": "Technology",
|
||||
@@ -167,7 +167,7 @@ class LinkedInVideoScriptRequest(BaseModel):
|
||||
include_citations: bool = Field(default=True, description="Whether to include inline citations")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Cybersecurity best practices",
|
||||
"industry": "Technology",
|
||||
@@ -197,7 +197,7 @@ class LinkedInCommentResponseRequest(BaseModel):
|
||||
grounding_level: GroundingLevel = Field(default=GroundingLevel.BASIC, description="Level of content grounding")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"original_comment": "Great insights on AI implementation!",
|
||||
"post_context": "Post about AI transformation in healthcare",
|
||||
@@ -353,7 +353,7 @@ class LinkedInPostResponse(BaseModel):
|
||||
grounding_status: Optional[Dict[str, Any]] = Field(None, description="Grounding operation status")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"data": {
|
||||
|
||||
@@ -48,8 +48,9 @@ class TaskExecutionLog(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_id = Column(Integer, ForeignKey("monitoring_tasks.id"), nullable=False)
|
||||
user_id = Column(Integer, nullable=True) # User ID for user isolation (nullable for backward compatibility)
|
||||
execution_date = Column(DateTime, default=datetime.utcnow)
|
||||
status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped'
|
||||
status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped', 'running'
|
||||
result_data = Column(JSON, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
execution_time_ms = Column(Integer, nullable=True)
|
||||
|
||||
@@ -50,16 +50,22 @@ class SubscriptionPlan(Base):
|
||||
price_monthly = Column(Float, nullable=False, default=0.0)
|
||||
price_yearly = Column(Float, nullable=False, default=0.0)
|
||||
|
||||
# API Call Limits
|
||||
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited
|
||||
openai_calls_limit = Column(Integer, default=0)
|
||||
anthropic_calls_limit = Column(Integer, default=0)
|
||||
mistral_calls_limit = Column(Integer, default=0)
|
||||
# Unified AI Text Generation Call Limit (applies to all LLM providers: gemini, openai, anthropic, mistral)
|
||||
# Note: This column may not exist in older databases - use getattr() when accessing
|
||||
ai_text_generation_calls_limit = Column(Integer, default=0, nullable=True) # 0 = unlimited, None if column doesn't exist
|
||||
|
||||
# Legacy per-provider limits (kept for backwards compatibility and analytics)
|
||||
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited (deprecated, use ai_text_generation_calls_limit)
|
||||
openai_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
anthropic_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
mistral_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
|
||||
# Other API Call Limits (non-LLM)
|
||||
tavily_calls_limit = Column(Integer, default=0)
|
||||
serper_calls_limit = Column(Integer, default=0)
|
||||
metaphor_calls_limit = Column(Integer, default=0)
|
||||
firecrawl_calls_limit = Column(Integer, default=0)
|
||||
stability_calls_limit = Column(Integer, default=0)
|
||||
stability_calls_limit = Column(Integer, default=0) # Image generation
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
|
||||
@@ -63,6 +63,9 @@ pytest-asyncio>=0.21.0
|
||||
pydantic>=2.5.2,<3.0.0
|
||||
typing-extensions>=4.8.0
|
||||
|
||||
# Task scheduling
|
||||
apscheduler>=3.10.0
|
||||
|
||||
# Optional dependencies (for enhanced features)
|
||||
redis>=5.0.0
|
||||
schedule>=1.2.0
|
||||
146
backend/scripts/add_ai_text_generation_limit_column.py
Normal file
146
backend/scripts/add_ai_text_generation_limit_column.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Migration Script: Add ai_text_generation_calls_limit column to subscription_plans table.
|
||||
|
||||
This adds the unified AI text generation limit column that applies to all LLM providers
|
||||
(gemini, openai, anthropic, mistral) instead of per-provider limits.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine, text, inspect
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def add_ai_text_generation_limit_column():
|
||||
"""Add ai_text_generation_calls_limit column to subscription_plans table."""
|
||||
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Check if column already exists
|
||||
inspector = inspect(engine)
|
||||
columns = [col['name'] for col in inspector.get_columns('subscription_plans')]
|
||||
|
||||
if 'ai_text_generation_calls_limit' in columns:
|
||||
logger.info("✅ Column 'ai_text_generation_calls_limit' already exists. Skipping migration.")
|
||||
return True
|
||||
|
||||
logger.info("📋 Adding 'ai_text_generation_calls_limit' column to subscription_plans table...")
|
||||
|
||||
# Add the column (SQLite compatible)
|
||||
alter_query = text("""
|
||||
ALTER TABLE subscription_plans
|
||||
ADD COLUMN ai_text_generation_calls_limit INTEGER DEFAULT 0
|
||||
""")
|
||||
|
||||
db.execute(alter_query)
|
||||
db.commit()
|
||||
|
||||
logger.info("✅ Column added successfully!")
|
||||
|
||||
# Update existing plans with unified limits based on their current limits
|
||||
logger.info("\n🔄 Updating existing subscription plans with unified limits...")
|
||||
|
||||
plans = db.query(SubscriptionPlan).all()
|
||||
updated_count = 0
|
||||
|
||||
for plan in plans:
|
||||
# Use the first non-zero LLM provider limit as the unified limit
|
||||
# Or use gemini_calls_limit as default
|
||||
unified_limit = (
|
||||
plan.ai_text_generation_calls_limit or
|
||||
plan.gemini_calls_limit or
|
||||
plan.openai_calls_limit or
|
||||
plan.anthropic_calls_limit or
|
||||
plan.mistral_calls_limit or
|
||||
0
|
||||
)
|
||||
|
||||
# For Basic plan, ensure it's set to 10 (from our recent update)
|
||||
if plan.tier == SubscriptionTier.BASIC:
|
||||
unified_limit = 10
|
||||
|
||||
if plan.ai_text_generation_calls_limit != unified_limit:
|
||||
plan.ai_text_generation_calls_limit = unified_limit
|
||||
plan.updated_at = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
|
||||
logger.info(f" ✅ Updated {plan.name} ({plan.tier.value}): ai_text_generation_calls_limit = {unified_limit}")
|
||||
else:
|
||||
logger.info(f" ℹ️ {plan.name} ({plan.tier.value}): already set to {unified_limit}")
|
||||
|
||||
if updated_count > 0:
|
||||
db.commit()
|
||||
logger.info(f"\n✅ Updated {updated_count} subscription plan(s)")
|
||||
else:
|
||||
logger.info("\nℹ️ No plans needed updating")
|
||||
|
||||
# Display summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("MIGRATION SUMMARY")
|
||||
logger.info("="*60)
|
||||
|
||||
all_plans = db.query(SubscriptionPlan).all()
|
||||
for plan in all_plans:
|
||||
logger.info(f"\n{plan.name} ({plan.tier.value}):")
|
||||
logger.info(f" Unified AI Text Gen Limit: {plan.ai_text_generation_calls_limit if plan.ai_text_generation_calls_limit else 'Not set'}")
|
||||
logger.info(f" Legacy Limits: gemini={plan.gemini_calls_limit}, mistral={plan.mistral_calls_limit}")
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("✅ Migration completed successfully!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"❌ Error during migration: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to connect to database: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Starting ai_text_generation_calls_limit column migration...")
|
||||
logger.info("="*60)
|
||||
logger.info("This will add the unified AI text generation limit column")
|
||||
logger.info("and update existing plans with appropriate values.")
|
||||
logger.info("="*60)
|
||||
|
||||
try:
|
||||
success = add_ai_text_generation_limit_column()
|
||||
|
||||
if success:
|
||||
logger.info("\n✅ Script completed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("\n❌ Script failed!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n⚠️ Script cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"\n❌ Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
210
backend/scripts/cap_basic_plan_usage.py
Normal file
210
backend/scripts/cap_basic_plan_usage.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Standalone script to cap usage counters at new Basic plan limits.
|
||||
|
||||
This preserves historical usage data but caps it at the new limits so users
|
||||
can continue making new calls within their limits.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
from services.subscription import PricingService
|
||||
|
||||
def cap_basic_plan_usage():
|
||||
"""Cap usage counters at new Basic plan limits."""
|
||||
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Find Basic plan
|
||||
basic_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.BASIC
|
||||
).first()
|
||||
|
||||
if not basic_plan:
|
||||
logger.error("❌ Basic plan not found in database!")
|
||||
return False
|
||||
|
||||
# New limits
|
||||
new_call_limit = basic_plan.gemini_calls_limit # Should be 10
|
||||
new_token_limit = basic_plan.gemini_tokens_limit # Should be 2000
|
||||
new_image_limit = basic_plan.stability_calls_limit # Should be 5
|
||||
|
||||
logger.info(f"📋 Basic Plan Limits:")
|
||||
logger.info(f" Calls: {new_call_limit}")
|
||||
logger.info(f" Tokens: {new_token_limit}")
|
||||
logger.info(f" Images: {new_image_limit}")
|
||||
|
||||
# Get all Basic plan users
|
||||
user_subscriptions = db.query(UserSubscription).filter(
|
||||
UserSubscription.plan_id == basic_plan.id,
|
||||
UserSubscription.is_active == True
|
||||
).all()
|
||||
|
||||
logger.info(f"\n👥 Found {len(user_subscriptions)} Basic plan user(s)")
|
||||
|
||||
pricing_service = PricingService(db)
|
||||
capped_count = 0
|
||||
|
||||
for sub in user_subscriptions:
|
||||
try:
|
||||
# Get current billing period for this user
|
||||
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
# Find usage summary for current period
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == sub.user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if usage_summary:
|
||||
# Store old values for logging
|
||||
old_gemini = usage_summary.gemini_calls or 0
|
||||
old_mistral = usage_summary.mistral_calls or 0
|
||||
old_openai = usage_summary.openai_calls or 0
|
||||
old_anthropic = usage_summary.anthropic_calls or 0
|
||||
old_tokens = max(
|
||||
usage_summary.gemini_tokens or 0,
|
||||
usage_summary.openai_tokens or 0,
|
||||
usage_summary.anthropic_tokens or 0,
|
||||
usage_summary.mistral_tokens or 0
|
||||
)
|
||||
old_images = usage_summary.stability_calls or 0
|
||||
|
||||
# Check if capping is needed
|
||||
needs_cap = (
|
||||
old_gemini > new_call_limit or
|
||||
old_mistral > new_call_limit or
|
||||
old_openai > new_call_limit or
|
||||
old_anthropic > new_call_limit or
|
||||
old_images > new_image_limit or
|
||||
old_tokens > new_token_limit
|
||||
)
|
||||
|
||||
if needs_cap:
|
||||
# Cap LLM provider counters at new limits
|
||||
usage_summary.gemini_calls = min(old_gemini, new_call_limit)
|
||||
usage_summary.mistral_calls = min(old_mistral, new_call_limit)
|
||||
usage_summary.openai_calls = min(old_openai, new_call_limit)
|
||||
usage_summary.anthropic_calls = min(old_anthropic, new_call_limit)
|
||||
|
||||
# Cap token counters at new limits
|
||||
usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit)
|
||||
usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit)
|
||||
usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit)
|
||||
usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit)
|
||||
|
||||
# Cap image counter at new limit
|
||||
usage_summary.stability_calls = min(old_images, new_image_limit)
|
||||
|
||||
# Recalculate totals based on capped values
|
||||
total_capped_calls = (
|
||||
usage_summary.gemini_calls +
|
||||
usage_summary.mistral_calls +
|
||||
usage_summary.openai_calls +
|
||||
usage_summary.anthropic_calls +
|
||||
usage_summary.stability_calls
|
||||
)
|
||||
total_capped_tokens = (
|
||||
usage_summary.gemini_tokens +
|
||||
usage_summary.mistral_tokens +
|
||||
usage_summary.openai_tokens +
|
||||
usage_summary.anthropic_tokens
|
||||
)
|
||||
|
||||
usage_summary.total_calls = total_capped_calls
|
||||
usage_summary.total_tokens = total_capped_tokens
|
||||
|
||||
# Reset status to active to allow new calls
|
||||
usage_summary.usage_status = UsageStatus.ACTIVE
|
||||
usage_summary.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
capped_count += 1
|
||||
|
||||
logger.info(f"\n✅ Capped usage for user {sub.user_id} (period {current_period}):")
|
||||
logger.info(f" Gemini Calls: {old_gemini} → {usage_summary.gemini_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" Mistral Calls: {old_mistral} → {usage_summary.mistral_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" OpenAI Calls: {old_openai} → {usage_summary.openai_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" Anthropic Calls: {old_anthropic} → {usage_summary.anthropic_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" Tokens: {old_tokens} → {max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})")
|
||||
logger.info(f" Images: {old_images} → {usage_summary.stability_calls} (limit: {new_image_limit})")
|
||||
else:
|
||||
logger.info(f" ℹ️ User {sub.user_id} usage is within limits - no capping needed")
|
||||
else:
|
||||
logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period})")
|
||||
|
||||
except Exception as cap_error:
|
||||
logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
db.rollback()
|
||||
|
||||
if capped_count > 0:
|
||||
logger.info(f"\n✅ Successfully capped usage for {capped_count} user(s)")
|
||||
logger.info(" Historical usage preserved, but capped at new limits")
|
||||
logger.info(" Users can now make new calls within their limits")
|
||||
else:
|
||||
logger.info("\nℹ️ No usage counters needed capping")
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("CAPPING COMPLETE")
|
||||
logger.info("="*60)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"❌ Error capping usage: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to connect to database: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Starting Basic plan usage capping...")
|
||||
logger.info("="*60)
|
||||
logger.info("This will cap usage counters at new Basic plan limits")
|
||||
logger.info("while preserving historical usage data.")
|
||||
logger.info("="*60)
|
||||
|
||||
try:
|
||||
success = cap_basic_plan_usage()
|
||||
|
||||
if success:
|
||||
logger.info("\n✅ Script completed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("\n❌ Script failed!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n⚠️ Script cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"\n❌ Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
168
backend/scripts/reset_basic_plan_usage.py
Normal file
168
backend/scripts/reset_basic_plan_usage.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Quick script to reset usage counters for Basic plan users.
|
||||
|
||||
This fixes the issue where plan limits were updated but old usage data remained.
|
||||
Resets all usage counters (calls, tokens, images) to 0 for the current billing period.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
from services.subscription import PricingService
|
||||
|
||||
def reset_basic_plan_usage():
|
||||
"""Reset usage counters for all Basic plan users."""
|
||||
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Find Basic plan
|
||||
basic_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.BASIC
|
||||
).first()
|
||||
|
||||
if not basic_plan:
|
||||
logger.error("❌ Basic plan not found in database!")
|
||||
return False
|
||||
|
||||
# Get all Basic plan users
|
||||
user_subscriptions = db.query(UserSubscription).filter(
|
||||
UserSubscription.plan_id == basic_plan.id,
|
||||
UserSubscription.is_active == True
|
||||
).all()
|
||||
|
||||
logger.info(f"Found {len(user_subscriptions)} Basic plan user(s)")
|
||||
|
||||
pricing_service = PricingService(db)
|
||||
reset_count = 0
|
||||
|
||||
for sub in user_subscriptions:
|
||||
try:
|
||||
# Get current billing period for this user
|
||||
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
# Find usage summary for current period
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == sub.user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if usage_summary:
|
||||
# Store old values for logging
|
||||
old_gemini = usage_summary.gemini_calls or 0
|
||||
old_mistral = usage_summary.mistral_calls or 0
|
||||
old_tokens = (usage_summary.mistral_tokens or 0) + (usage_summary.gemini_tokens or 0)
|
||||
old_images = usage_summary.stability_calls or 0
|
||||
old_total_calls = usage_summary.total_calls or 0
|
||||
old_total_tokens = usage_summary.total_tokens or 0
|
||||
|
||||
# Reset all LLM provider counters
|
||||
usage_summary.gemini_calls = 0
|
||||
usage_summary.openai_calls = 0
|
||||
usage_summary.anthropic_calls = 0
|
||||
usage_summary.mistral_calls = 0
|
||||
|
||||
# Reset all token counters
|
||||
usage_summary.gemini_tokens = 0
|
||||
usage_summary.openai_tokens = 0
|
||||
usage_summary.anthropic_tokens = 0
|
||||
usage_summary.mistral_tokens = 0
|
||||
|
||||
# Reset image counter
|
||||
usage_summary.stability_calls = 0
|
||||
|
||||
# Reset totals
|
||||
usage_summary.total_calls = 0
|
||||
usage_summary.total_tokens = 0
|
||||
usage_summary.total_cost = 0.0
|
||||
|
||||
# Reset status to active
|
||||
usage_summary.usage_status = UsageStatus.ACTIVE
|
||||
usage_summary.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
reset_count += 1
|
||||
|
||||
logger.info(f"\n✅ Reset usage for user {sub.user_id} (period {current_period}):")
|
||||
logger.info(f" Calls: {old_gemini + old_mistral} (gemini: {old_gemini}, mistral: {old_mistral}) → 0")
|
||||
logger.info(f" Tokens: {old_tokens} → 0")
|
||||
logger.info(f" Images: {old_images} → 0")
|
||||
logger.info(f" Total Calls: {old_total_calls} → 0")
|
||||
logger.info(f" Total Tokens: {old_total_tokens} → 0")
|
||||
else:
|
||||
logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period}) - nothing to reset")
|
||||
|
||||
except Exception as reset_error:
|
||||
logger.error(f" ❌ Error resetting usage for user {sub.user_id}: {reset_error}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
db.rollback()
|
||||
|
||||
if reset_count > 0:
|
||||
logger.info(f"\n✅ Successfully reset usage counters for {reset_count} user(s)")
|
||||
else:
|
||||
logger.info("\nℹ️ No usage counters to reset")
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("RESET COMPLETE")
|
||||
logger.info("="*60)
|
||||
logger.info("\n💡 Usage counters have been reset. Users can now use their new limits.")
|
||||
logger.info(" Next API call will start counting from 0.")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"❌ Error resetting usage: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to connect to database: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Starting Basic plan usage counter reset...")
|
||||
logger.info("="*60)
|
||||
logger.info("This will reset all usage counters (calls, tokens, images) to 0")
|
||||
logger.info("for all Basic plan users in their current billing period.")
|
||||
logger.info("="*60)
|
||||
|
||||
try:
|
||||
success = reset_basic_plan_usage()
|
||||
|
||||
if success:
|
||||
logger.info("\n✅ Script completed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("\n❌ Script failed!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n⚠️ Script cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"\n❌ Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
279
backend/scripts/update_basic_plan_limits.py
Normal file
279
backend/scripts/update_basic_plan_limits.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Script to update Basic plan subscription limits for testing rate limits and renewal flows.
|
||||
|
||||
Updates:
|
||||
- LLM Calls (all providers): 10 calls (was 500-1000)
|
||||
- LLM Tokens (all providers): 2000 tokens (was 200k-1M)
|
||||
- Images: 5 images (was 50)
|
||||
|
||||
This script updates the SubscriptionPlan table, which automatically applies to all users
|
||||
who have a Basic plan subscription via the plan_id foreign key.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def update_basic_plan_limits():
|
||||
"""Update Basic plan limits for testing rate limits and renewal."""
|
||||
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Find Basic plan
|
||||
basic_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.BASIC
|
||||
).first()
|
||||
|
||||
if not basic_plan:
|
||||
logger.error("❌ Basic plan not found in database!")
|
||||
return False
|
||||
|
||||
# Store old values for logging
|
||||
old_limits = {
|
||||
'gemini_calls': basic_plan.gemini_calls_limit,
|
||||
'mistral_calls': basic_plan.mistral_calls_limit,
|
||||
'gemini_tokens': basic_plan.gemini_tokens_limit,
|
||||
'mistral_tokens': basic_plan.mistral_tokens_limit,
|
||||
'stability_calls': basic_plan.stability_calls_limit,
|
||||
}
|
||||
|
||||
logger.info(f"📋 Current Basic plan limits:")
|
||||
logger.info(f" Gemini Calls: {old_limits['gemini_calls']}")
|
||||
logger.info(f" Mistral Calls: {old_limits['mistral_calls']}")
|
||||
logger.info(f" Gemini Tokens: {old_limits['gemini_tokens']}")
|
||||
logger.info(f" Mistral Tokens: {old_limits['mistral_tokens']}")
|
||||
logger.info(f" Images (Stability): {old_limits['stability_calls']}")
|
||||
|
||||
# Update unified AI text generation limit to 10
|
||||
basic_plan.ai_text_generation_calls_limit = 10
|
||||
|
||||
# Legacy per-provider limits (kept for backwards compatibility, but not used for enforcement)
|
||||
basic_plan.gemini_calls_limit = 1000
|
||||
basic_plan.openai_calls_limit = 500
|
||||
basic_plan.anthropic_calls_limit = 200
|
||||
basic_plan.mistral_calls_limit = 500
|
||||
|
||||
# Update all LLM provider token limits to 2000
|
||||
basic_plan.gemini_tokens_limit = 2000
|
||||
basic_plan.openai_tokens_limit = 2000
|
||||
basic_plan.anthropic_tokens_limit = 2000
|
||||
basic_plan.mistral_tokens_limit = 2000
|
||||
|
||||
# Update image generation limit to 5
|
||||
basic_plan.stability_calls_limit = 5
|
||||
|
||||
# Update timestamp
|
||||
basic_plan.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.info("\n📝 New Basic plan limits:")
|
||||
logger.info(f" LLM Calls (all providers): 10")
|
||||
logger.info(f" LLM Tokens (all providers): 2000")
|
||||
logger.info(f" Images: 5")
|
||||
|
||||
# Count and get affected users
|
||||
user_subscriptions = db.query(UserSubscription).filter(
|
||||
UserSubscription.plan_id == basic_plan.id,
|
||||
UserSubscription.is_active == True
|
||||
).all()
|
||||
|
||||
affected_users = len(user_subscriptions)
|
||||
|
||||
logger.info(f"\n👥 Users affected: {affected_users}")
|
||||
|
||||
if affected_users > 0:
|
||||
logger.info("\n📋 Affected user IDs:")
|
||||
for sub in user_subscriptions:
|
||||
logger.info(f" - {sub.user_id}")
|
||||
else:
|
||||
logger.info(" (No active Basic plan subscriptions found)")
|
||||
|
||||
# Commit plan limit changes first
|
||||
db.commit()
|
||||
logger.info("\n✅ Basic plan limits updated successfully!")
|
||||
|
||||
# Cap usage at new limits for all affected users (preserve historical data, but cap enforcement)
|
||||
logger.info("\n🔄 Capping usage counters at new limits for Basic plan users...")
|
||||
logger.info(" (Historical usage preserved, but capped to allow new calls within limits)")
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing_service = PricingService(db)
|
||||
capped_count = 0
|
||||
|
||||
# New limits - use unified AI text generation limit if available
|
||||
new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit
|
||||
new_token_limit = basic_plan.gemini_tokens_limit # 2000
|
||||
new_image_limit = basic_plan.stability_calls_limit # 5
|
||||
|
||||
for sub in user_subscriptions:
|
||||
try:
|
||||
# Get current billing period for this user
|
||||
current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
# Find usage summary for current period
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == sub.user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if usage_summary:
|
||||
# Store old values for logging
|
||||
old_gemini = usage_summary.gemini_calls or 0
|
||||
old_mistral = usage_summary.mistral_calls or 0
|
||||
old_openai = usage_summary.openai_calls or 0
|
||||
old_anthropic = usage_summary.anthropic_calls or 0
|
||||
old_tokens = max(
|
||||
usage_summary.gemini_tokens or 0,
|
||||
usage_summary.openai_tokens or 0,
|
||||
usage_summary.anthropic_tokens or 0,
|
||||
usage_summary.mistral_tokens or 0
|
||||
)
|
||||
old_images = usage_summary.stability_calls or 0
|
||||
|
||||
# Cap LLM provider counters at new limits (don't reset, just cap)
|
||||
# This allows historical data to remain but prevents blocking from old usage
|
||||
usage_summary.gemini_calls = min(old_gemini, new_call_limit)
|
||||
usage_summary.mistral_calls = min(old_mistral, new_call_limit)
|
||||
usage_summary.openai_calls = min(old_openai, new_call_limit)
|
||||
usage_summary.anthropic_calls = min(old_anthropic, new_call_limit)
|
||||
|
||||
# Cap token counters at new limits
|
||||
usage_summary.gemini_tokens = min(usage_summary.gemini_tokens or 0, new_token_limit)
|
||||
usage_summary.openai_tokens = min(usage_summary.openai_tokens or 0, new_token_limit)
|
||||
usage_summary.anthropic_tokens = min(usage_summary.anthropic_tokens or 0, new_token_limit)
|
||||
usage_summary.mistral_tokens = min(usage_summary.mistral_tokens or 0, new_token_limit)
|
||||
|
||||
# Cap image counter at new limit
|
||||
usage_summary.stability_calls = min(old_images, new_image_limit)
|
||||
|
||||
# Update totals based on capped values (approximate)
|
||||
# Recalculate total_calls and total_tokens based on capped provider values
|
||||
total_capped_calls = (
|
||||
usage_summary.gemini_calls +
|
||||
usage_summary.mistral_calls +
|
||||
usage_summary.openai_calls +
|
||||
usage_summary.anthropic_calls +
|
||||
usage_summary.stability_calls
|
||||
)
|
||||
total_capped_tokens = (
|
||||
usage_summary.gemini_tokens +
|
||||
usage_summary.mistral_tokens +
|
||||
usage_summary.openai_tokens +
|
||||
usage_summary.anthropic_tokens
|
||||
)
|
||||
|
||||
usage_summary.total_calls = total_capped_calls
|
||||
usage_summary.total_tokens = total_capped_tokens
|
||||
|
||||
# Reset status to active to allow new calls
|
||||
usage_summary.usage_status = UsageStatus.ACTIVE
|
||||
usage_summary.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
capped_count += 1
|
||||
|
||||
logger.info(f" ✅ Capped usage for user {sub.user_id}:")
|
||||
logger.info(f" Gemini Calls: {old_gemini} → {usage_summary.gemini_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" Mistral Calls: {old_mistral} → {usage_summary.mistral_calls} (limit: {new_call_limit})")
|
||||
logger.info(f" Tokens: {old_tokens} → {max(usage_summary.gemini_tokens, usage_summary.mistral_tokens)} (limit: {new_token_limit})")
|
||||
logger.info(f" Images: {old_images} → {usage_summary.stability_calls} (limit: {new_image_limit})")
|
||||
else:
|
||||
logger.info(f" ℹ️ No usage summary found for user {sub.user_id} (period {current_period})")
|
||||
|
||||
except Exception as cap_error:
|
||||
logger.error(f" ❌ Error capping usage for user {sub.user_id}: {cap_error}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
db.rollback()
|
||||
|
||||
if capped_count > 0:
|
||||
logger.info(f"\n✅ Capped usage counters for {capped_count} user(s)")
|
||||
logger.info(" Historical usage preserved, but capped at new limits to allow new calls")
|
||||
else:
|
||||
logger.info("\nℹ️ No usage counters to cap")
|
||||
|
||||
# Note about cache clearing
|
||||
logger.info("\n🔄 Cache Information:")
|
||||
logger.info(" The subscription limits cache is per-instance and will refresh on next request.")
|
||||
logger.info(" No manual cache clearing needed - limits will be read from database on next check.")
|
||||
|
||||
# Display final summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("BASIC PLAN UPDATE SUMMARY")
|
||||
logger.info("="*60)
|
||||
logger.info(f"\nPlan: {basic_plan.name} ({basic_plan.tier.value})")
|
||||
logger.info(f"Price: ${basic_plan.price_monthly}/mo, ${basic_plan.price_yearly}/yr")
|
||||
logger.info(f"\nUpdated Limits:")
|
||||
logger.info(f" LLM Calls (gemini/openai/anthropic/mistral): {basic_plan.gemini_calls_limit}")
|
||||
logger.info(f" LLM Tokens (gemini/openai/anthropic/mistral): {basic_plan.gemini_tokens_limit}")
|
||||
logger.info(f" Images (stability): {basic_plan.stability_calls_limit}")
|
||||
logger.info(f"\nUsers Affected: {affected_users}")
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("\n💡 Note: These limits apply immediately to all Basic plan users.")
|
||||
logger.info(" Historical usage has been preserved but capped at new limits.")
|
||||
logger.info(" Users can continue making new calls up to the new limits.")
|
||||
logger.info(" Users will hit rate limits faster for testing purposes.")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"❌ Error updating Basic plan: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to connect to database: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Starting Basic plan limits update...")
|
||||
logger.info("="*60)
|
||||
logger.info("This will update Basic plan limits for testing rate limits:")
|
||||
logger.info(" - LLM Calls: 10 (all providers)")
|
||||
logger.info(" - LLM Tokens: 2000 (all providers)")
|
||||
logger.info(" - Images: 5")
|
||||
logger.info("="*60)
|
||||
|
||||
# Ask for confirmation in non-interactive mode, proceed directly
|
||||
# In interactive mode, you can add: input("\nPress Enter to continue or Ctrl+C to cancel...")
|
||||
|
||||
try:
|
||||
success = update_basic_plan_limits()
|
||||
|
||||
if success:
|
||||
logger.info("\n✅ Script completed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("\n❌ Script failed!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n⚠️ Script cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"\n❌ Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -295,3 +295,55 @@ class ActiveStrategyService:
|
||||
'cached_users': list(self._memory_cache.keys()),
|
||||
'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()}
|
||||
}
|
||||
|
||||
def count_active_strategies_with_tasks(self) -> int:
|
||||
"""
|
||||
Count how many active strategies have monitoring tasks.
|
||||
|
||||
This is used for intelligent scheduling - if there are no active strategies
|
||||
with tasks, the scheduler can check less frequently.
|
||||
|
||||
Returns:
|
||||
Number of active strategies that have at least one active monitoring task
|
||||
"""
|
||||
try:
|
||||
if not self.db_session:
|
||||
logger.warning("Database session not available")
|
||||
return 0
|
||||
|
||||
from sqlalchemy import func, and_
|
||||
from models.monitoring_models import MonitoringTask
|
||||
|
||||
# Count distinct strategies that:
|
||||
# 1. Have activation status = 'active'
|
||||
# 2. Have at least one active monitoring task
|
||||
count = self.db_session.query(
|
||||
func.count(func.distinct(EnhancedContentStrategy.id))
|
||||
).join(
|
||||
StrategyActivationStatus,
|
||||
EnhancedContentStrategy.id == StrategyActivationStatus.strategy_id
|
||||
).join(
|
||||
MonitoringTask,
|
||||
EnhancedContentStrategy.id == MonitoringTask.strategy_id
|
||||
).filter(
|
||||
and_(
|
||||
StrategyActivationStatus.status == 'active',
|
||||
MonitoringTask.status == 'active'
|
||||
)
|
||||
).scalar()
|
||||
|
||||
return count or 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting active strategies with tasks: {e}")
|
||||
# On error, assume there are active strategies (safer to check more frequently)
|
||||
return 1
|
||||
|
||||
def has_active_strategies_with_tasks(self) -> bool:
|
||||
"""
|
||||
Check if there are any active strategies with monitoring tasks.
|
||||
|
||||
Returns:
|
||||
True if there are active strategies with tasks, False otherwise
|
||||
"""
|
||||
return self.count_active_strategies_with_tasks() > 0
|
||||
@@ -96,13 +96,13 @@ class BlogWriterService:
|
||||
self.blog_rewriter = BlogRewriter(self.task_manager)
|
||||
|
||||
# Research Methods
|
||||
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
|
||||
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
|
||||
"""Conduct comprehensive research using Google Search grounding."""
|
||||
return await self.research_service.research(request)
|
||||
return await self.research_service.research(request, user_id)
|
||||
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
|
||||
"""Conduct research with real-time progress updates."""
|
||||
return await self.research_service.research_with_progress(request, task_id)
|
||||
return await self.research_service.research_with_progress(request, task_id, user_id)
|
||||
|
||||
# Outline Methods
|
||||
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
|
||||
@@ -204,11 +204,14 @@ class BlogWriterService:
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def seo_analyze(self, request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse:
|
||||
async def seo_analyze(self, request: BlogSEOAnalyzeRequest, user_id: str = None) -> BlogSEOAnalyzeResponse:
|
||||
"""Analyze content for SEO optimization using comprehensive blog-specific analyzer."""
|
||||
try:
|
||||
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
content = request.content or ""
|
||||
target_keywords = request.keywords or []
|
||||
|
||||
@@ -231,7 +234,7 @@ class BlogWriterService:
|
||||
|
||||
# Use our comprehensive SEO analyzer
|
||||
analyzer = BlogContentSEOAnalyzer()
|
||||
analysis_results = await analyzer.analyze_blog_content(content, research_data)
|
||||
analysis_results = await analyzer.analyze_blog_content(content, research_data, user_id=user_id)
|
||||
|
||||
# Convert results to response format
|
||||
recommendations = analysis_results.get('actionable_recommendations', [])
|
||||
@@ -267,11 +270,14 @@ class BlogWriterService:
|
||||
recommendations=[f"SEO analysis failed: {str(e)}"]
|
||||
)
|
||||
|
||||
async def seo_metadata(self, request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse:
|
||||
async def seo_metadata(self, request: BlogSEOMetadataRequest, user_id: str = None) -> BlogSEOMetadataResponse:
|
||||
"""Generate comprehensive SEO metadata for content."""
|
||||
try:
|
||||
from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
# Initialize metadata generator
|
||||
metadata_generator = BlogSEOMetadataGenerator()
|
||||
|
||||
@@ -285,7 +291,8 @@ class BlogWriterService:
|
||||
blog_title=request.title or "Untitled Blog Post",
|
||||
research_data=request.research_data or {},
|
||||
outline=outline,
|
||||
seo_analysis=seo_analysis
|
||||
seo_analysis=seo_analysis,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Convert to BlogSEOMetadataResponse format
|
||||
|
||||
@@ -163,13 +163,18 @@ class BlogWriterLogger:
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log error with full context."""
|
||||
# Safely format error message to avoid KeyError on format strings in error messages
|
||||
error_str = str(error)
|
||||
# Replace any curly braces that might be in the error message to avoid format string issues
|
||||
safe_error_str = error_str.replace('{', '{{').replace('}', '}}')
|
||||
|
||||
logger.error(
|
||||
f"Error in {operation}: {str(error)}",
|
||||
f"Error in {operation}: {safe_error_str}",
|
||||
extra={
|
||||
"event_type": "error",
|
||||
"operation": operation,
|
||||
"error_type": type(error).__name__,
|
||||
"error_message": str(error),
|
||||
"error_message": error_str, # Keep original in extra, but use safe version in format string
|
||||
"context": context or {}
|
||||
},
|
||||
exc_info=True
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class CompetitorAnalyzer:
|
||||
"""Analyzes competitors and market intelligence from research content."""
|
||||
|
||||
def analyze(self, content: str) -> Dict[str, Any]:
|
||||
def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]:
|
||||
"""Parse comprehensive competitor analysis from the research content using AI."""
|
||||
competitor_prompt = f"""
|
||||
Analyze the following research content and extract competitor insights:
|
||||
@@ -57,7 +57,8 @@ class CompetitorAnalyzer:
|
||||
|
||||
competitor_analysis = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class ContentAngleGenerator:
|
||||
"""Generates strategic content angles from research content."""
|
||||
|
||||
def generate(self, content: str, topic: str, industry: str) -> List[str]:
|
||||
def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]:
|
||||
"""Parse strategic content angles from the research content using AI."""
|
||||
angles_prompt = f"""
|
||||
Analyze the following research content and create strategic content angles for: {topic} in {industry}
|
||||
@@ -65,7 +65,8 @@ class ContentAngleGenerator:
|
||||
|
||||
angles_result = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class KeywordAnalyzer:
|
||||
"""Analyzes keywords from research content using AI-powered extraction."""
|
||||
|
||||
def analyze(self, content: str, original_keywords: List[str]) -> Dict[str, Any]:
|
||||
def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Parse comprehensive keyword analysis from the research content using AI."""
|
||||
# Use AI to extract and analyze keywords from the rich research content
|
||||
keyword_prompt = f"""
|
||||
@@ -64,7 +64,8 @@ class KeywordAnalyzer:
|
||||
|
||||
keyword_analysis = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
|
||||
@@ -4,7 +4,8 @@ Research Service - Core research functionality for AI Blog Writer.
|
||||
Handles Google Search grounding, caching, and research orchestration.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -17,6 +18,7 @@ from models.blog_models import (
|
||||
Citation,
|
||||
)
|
||||
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
@@ -34,7 +36,7 @@ class ResearchService:
|
||||
self.data_filter = ResearchDataFilter()
|
||||
|
||||
@log_function_call("research_operation")
|
||||
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
|
||||
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
|
||||
"""
|
||||
Stage 1: Research & Strategy (AI Orchestration)
|
||||
Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
|
||||
@@ -71,6 +73,10 @@ class ResearchService:
|
||||
blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True)
|
||||
return BlogResearchResponse(**cached_result)
|
||||
|
||||
# User ID validation (validation logic is now in Google Grounding provider)
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
|
||||
@@ -96,12 +102,15 @@ class ResearchService:
|
||||
"""
|
||||
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
import time
|
||||
api_start_time = time.time()
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
|
||||
@@ -126,9 +135,9 @@ class ResearchService:
|
||||
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
|
||||
|
||||
@@ -179,6 +188,9 @@ class ResearchService:
|
||||
|
||||
return filtered_response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (subscription errors) - let task manager handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Research failed: {error_message}")
|
||||
@@ -244,7 +256,7 @@ class ResearchService:
|
||||
)
|
||||
|
||||
@log_function_call("research_with_progress")
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
|
||||
"""
|
||||
Research method with progress updates for real-time feedback.
|
||||
"""
|
||||
@@ -281,6 +293,11 @@ class ResearchService:
|
||||
logger.info(f"Returning cached research result for keywords: {request.keywords}")
|
||||
return BlogResearchResponse(**cached_result)
|
||||
|
||||
# User ID validation (validation logic is now in Google Grounding provider)
|
||||
if not user_id:
|
||||
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
@@ -307,11 +324,20 @@ class ResearchService:
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000
|
||||
)
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
try:
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
# Re-raise HTTPException so it can be properly handled by task manager
|
||||
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise # Re-raise HTTPException to preserve status code and error details
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources from grounding metadata
|
||||
@@ -327,9 +353,9 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
await task_manager.update_progress(task_id, "💾 Caching results for future use...")
|
||||
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
|
||||
@@ -373,6 +399,9 @@ class ResearchService:
|
||||
|
||||
return filtered_response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (subscription errors) - let task manager handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Research failed: {error_message}")
|
||||
|
||||
@@ -34,17 +34,21 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
logger.info("BlogContentSEOAnalyzer initialized")
|
||||
|
||||
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None, user_id: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Main analysis method with parallel processing
|
||||
|
||||
Args:
|
||||
blog_content: The blog content to analyze
|
||||
research_data: Research data containing keywords and other insights
|
||||
blog_title: Optional blog title
|
||||
user_id: Clerk user ID for subscription checking (required)
|
||||
|
||||
Returns:
|
||||
Comprehensive SEO analysis results
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
logger.info("Starting blog content SEO analysis")
|
||||
|
||||
@@ -58,7 +62,7 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
# Phase 2: Single AI analysis for structured insights
|
||||
logger.info("Running AI analysis")
|
||||
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results)
|
||||
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results, user_id=user_id)
|
||||
|
||||
# Phase 3: Compile and format results
|
||||
logger.info("Compiling results")
|
||||
@@ -599,8 +603,10 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
return recommendations
|
||||
|
||||
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Run single AI analysis for structured insights (provider-agnostic)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Prepare context for AI analysis
|
||||
context = {
|
||||
@@ -658,7 +664,8 @@ class BlogContentSEOAnalyzer:
|
||||
ai_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
return ai_response
|
||||
|
||||
@@ -28,7 +28,8 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
research_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate comprehensive SEO metadata using maximum 2 AI calls
|
||||
@@ -39,10 +40,13 @@ class BlogSEOMetadataGenerator:
|
||||
research_data: Research data containing keywords and insights
|
||||
outline: Outline structure with sections and headings
|
||||
seo_analysis: SEO analysis results from previous phase
|
||||
user_id: Clerk user ID for subscription checking (required)
|
||||
|
||||
Returns:
|
||||
Comprehensive metadata including all SEO elements
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
logger.info("Starting comprehensive SEO metadata generation")
|
||||
|
||||
@@ -53,13 +57,13 @@ class BlogSEOMetadataGenerator:
|
||||
# Call 1: Generate core SEO metadata (parallel with Call 2)
|
||||
logger.info("Generating core SEO metadata")
|
||||
core_metadata_task = self._generate_core_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
|
||||
)
|
||||
|
||||
# Call 2: Generate social media and structured data (parallel with Call 1)
|
||||
logger.info("Generating social media and structured data")
|
||||
social_metadata_task = self._generate_social_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
|
||||
)
|
||||
|
||||
# Wait for both calls to complete
|
||||
@@ -114,9 +118,12 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate core SEO metadata (Call 1)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Create comprehensive prompt for core metadata
|
||||
prompt = self._create_core_metadata_prompt(
|
||||
@@ -170,7 +177,8 @@ class BlogSEOMetadataGenerator:
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
@@ -215,9 +223,12 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate social media and structured data (Call 2)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Create comprehensive prompt for social metadata
|
||||
prompt = self._create_social_metadata_prompt(
|
||||
@@ -274,7 +285,8 @@ class BlogSEOMetadataGenerator:
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
|
||||
@@ -20,8 +20,11 @@ class BlogSEORecommendationApplier:
|
||||
def __init__(self):
|
||||
logger.debug("Initialized BlogSEORecommendationApplier")
|
||||
|
||||
async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def apply_recommendations(self, payload: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Apply recommendations and return updated content."""
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
title = payload.get("title", "Untitled Blog")
|
||||
sections: List[Dict[str, Any]] = payload.get("sections", [])
|
||||
@@ -88,6 +91,7 @@ class BlogSEORecommendationApplier:
|
||||
prompt,
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -56,7 +56,9 @@ class GeminiGroundedProvider:
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
urls: Optional[List[str]] = None,
|
||||
mode: str = "polished"
|
||||
mode: str = "polished",
|
||||
user_id: Optional[str] = None,
|
||||
validate_subsequent_operations: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate grounded content using native Google Search grounding.
|
||||
@@ -66,12 +68,49 @@ class GeminiGroundedProvider:
|
||||
content_type: Type of content to generate
|
||||
temperature: Creativity level (0.0-1.0)
|
||||
max_tokens: Maximum tokens in response
|
||||
urls: Optional list of URLs for URL Context tool
|
||||
mode: Content mode ("draft" or "polished")
|
||||
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
|
||||
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
|
||||
|
||||
Returns:
|
||||
Dictionary containing generated content and grounding metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Generating grounded content for {content_type} using native Google Search")
|
||||
# PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if validate_subsequent_operations:
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required when validate_subsequent_operations=True")
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_research_operations
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
|
||||
# Validate ALL research operations before making ANY API calls
|
||||
# This prevents wasteful external API calls if subsequent LLM calls would fail
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_research_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
gpt_provider=gpt_provider
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
|
||||
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
|
||||
|
||||
# Build the grounded prompt
|
||||
grounded_prompt = self._build_grounded_prompt(prompt, content_type)
|
||||
|
||||
@@ -40,7 +40,38 @@ def _get_provider(provider_name: str):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult:
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
prompt: Image generation prompt
|
||||
options: Image generation options (provider, model, width, height, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Generation] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Image Generation] ✅ Pre-flight validation passed - proceeding with image generation")
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
@@ -14,7 +15,7 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
|
||||
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str:
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
|
||||
@@ -22,9 +23,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
prompt (str): The prompt to generate text from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
user_id (str): Clerk user ID for subscription checking (required).
|
||||
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
@@ -93,6 +98,75 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
# Map provider name to APIProvider enum (define at function scope for usage tracking)
|
||||
from models.subscription_models import APIProvider
|
||||
provider_enum = None
|
||||
# Store actual provider name for logging (e.g., "huggingface", "gemini")
|
||||
actual_provider_name = None
|
||||
if gpt_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
|
||||
elif gpt_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
|
||||
if not provider_enum:
|
||||
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
|
||||
|
||||
# SUBSCRIPTION CHECK - Required and strict enforcement
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Estimate tokens from prompt (input tokens)
|
||||
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
|
||||
# This prevents underestimating total token usage
|
||||
input_tokens = int(len(prompt.split()) * 1.3)
|
||||
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
|
||||
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
|
||||
estimated_total_tokens = input_tokens + estimated_output_tokens
|
||||
|
||||
# Check limits using sync method from pricing service (strict enforcement)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider_enum,
|
||||
tokens_requested=estimated_total_tokens,
|
||||
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
raise RuntimeError(f"Subscription limit exceeded: {message}")
|
||||
|
||||
# Get current usage for limit checking only
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# No separate log here - we'll create unified log after API call and usage tracking
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
except RuntimeError:
|
||||
# Re-raise subscription limit errors
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
# STRICT: Fail on subscription check errors
|
||||
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
if system_prompt is None:
|
||||
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
|
||||
@@ -117,10 +191,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_instructions = system_prompt
|
||||
|
||||
# Generate response based on provider
|
||||
response_text = None
|
||||
actual_provider_used = gpt_provider
|
||||
try:
|
||||
if gpt_provider == "google":
|
||||
if json_struct:
|
||||
return gemini_structured_json_response(
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
@@ -130,7 +206,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return gemini_text_response(
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -140,7 +216,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
)
|
||||
elif gpt_provider == "huggingface":
|
||||
if json_struct:
|
||||
return huggingface_structured_json_response(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model,
|
||||
@@ -149,7 +225,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return huggingface_text_response(
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
@@ -160,6 +236,107 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens # Already calculated above
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
summary.total_tokens = old_total_tokens + tokens_total
|
||||
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return response_text
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
|
||||
@@ -171,9 +348,21 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
fallback_provider = fallback_providers[0] # Only try the first available
|
||||
try:
|
||||
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
|
||||
actual_provider_used = fallback_provider
|
||||
|
||||
# Update provider enum for fallback
|
||||
if fallback_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini"
|
||||
fallback_model = "gemini-2.0-flash-lite"
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "openai/gpt-oss-120b:groq"
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
return gemini_structured_json_response(
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
@@ -183,7 +372,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return gemini_text_response(
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -193,7 +382,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
)
|
||||
elif fallback_provider == "huggingface":
|
||||
if json_struct:
|
||||
return huggingface_structured_json_response(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
@@ -202,7 +391,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return huggingface_text_response(
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
temperature=temperature,
|
||||
@@ -210,6 +399,96 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful fallback call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3)
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG for fallback
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation (Fallback)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {fallback_model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
|
||||
|
||||
return response_text
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
|
||||
|
||||
|
||||
@@ -55,6 +55,14 @@ class MonitoringDataService:
|
||||
alert_threshold=task_data.get('alertThreshold', ''),
|
||||
status='active'
|
||||
)
|
||||
|
||||
# Initialize next_execution based on frequency
|
||||
from services.scheduler.utils.frequency_calculator import calculate_next_execution
|
||||
task.next_execution = calculate_next_execution(
|
||||
frequency=task.frequency,
|
||||
base_time=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(task)
|
||||
|
||||
# Save activation status
|
||||
@@ -357,3 +365,80 @@ class MonitoringDataService:
|
||||
logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}")
|
||||
self.db.rollback()
|
||||
return False
|
||||
|
||||
def get_user_execution_logs(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: Optional[int] = 50,
|
||||
offset: Optional[int] = 0,
|
||||
status_filter: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get execution logs for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter execution logs
|
||||
limit: Maximum number of logs to return
|
||||
offset: Number of logs to skip (for pagination)
|
||||
status_filter: Optional status filter ('success', 'failed', 'running', 'skipped')
|
||||
|
||||
Returns:
|
||||
List of execution log dictionaries with task details
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Getting execution logs for user {user_id}")
|
||||
|
||||
# Build query for execution logs filtered by user_id
|
||||
query = self.db.query(TaskExecutionLog).filter(
|
||||
TaskExecutionLog.user_id == user_id
|
||||
)
|
||||
|
||||
# Apply status filter if provided
|
||||
if status_filter:
|
||||
query = query.filter(TaskExecutionLog.status == status_filter)
|
||||
|
||||
# Order by execution date (most recent first)
|
||||
query = query.order_by(desc(TaskExecutionLog.execution_date))
|
||||
|
||||
# Apply pagination
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
if offset:
|
||||
query = query.offset(offset)
|
||||
|
||||
logs = query.all()
|
||||
|
||||
# Convert to dictionaries with task details
|
||||
logs_data = []
|
||||
for log in logs:
|
||||
# Get task details if available
|
||||
task = self.db.query(MonitoringTask).filter(
|
||||
MonitoringTask.id == log.task_id
|
||||
).first()
|
||||
|
||||
log_data = {
|
||||
"id": log.id,
|
||||
"task_id": log.task_id,
|
||||
"user_id": log.user_id,
|
||||
"execution_date": log.execution_date.isoformat() if log.execution_date else None,
|
||||
"status": log.status,
|
||||
"result_data": log.result_data,
|
||||
"error_message": log.error_message,
|
||||
"execution_time_ms": log.execution_time_ms,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"task": {
|
||||
"title": task.task_title if task else None,
|
||||
"description": task.task_description if task else None,
|
||||
"assignee": task.assignee if task else None,
|
||||
"frequency": task.frequency if task else None,
|
||||
"strategy_id": task.strategy_id if task else None
|
||||
} if task else None
|
||||
}
|
||||
logs_data.append(log_data)
|
||||
|
||||
logger.info(f"Retrieved {len(logs_data)} execution logs for user {user_id}")
|
||||
return logs_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting execution logs for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
59
backend/services/scheduler/__init__.py
Normal file
59
backend/services/scheduler/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Task Scheduler Package
|
||||
Modular, pluggable scheduler for ALwrity tasks.
|
||||
"""
|
||||
|
||||
from .core.scheduler import TaskScheduler
|
||||
from .core.executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .core.exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, SchedulerErrorType, SchedulerErrorSeverity,
|
||||
TaskExecutionError, DatabaseError, TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
from .executors.monitoring_task_executor import MonitoringTaskExecutor
|
||||
from .utils.task_loader import load_due_monitoring_tasks
|
||||
|
||||
# Global scheduler instance (initialized on first access)
|
||||
_scheduler_instance: TaskScheduler = None
|
||||
|
||||
|
||||
def get_scheduler() -> TaskScheduler:
|
||||
"""
|
||||
Get global scheduler instance (singleton pattern).
|
||||
|
||||
Returns:
|
||||
TaskScheduler instance
|
||||
"""
|
||||
global _scheduler_instance
|
||||
if _scheduler_instance is None:
|
||||
_scheduler_instance = TaskScheduler(
|
||||
check_interval_minutes=15,
|
||||
max_concurrent_executions=10
|
||||
)
|
||||
|
||||
# Register monitoring task executor
|
||||
monitoring_executor = MonitoringTaskExecutor()
|
||||
_scheduler_instance.register_executor(
|
||||
'monitoring_task',
|
||||
monitoring_executor,
|
||||
load_due_monitoring_tasks
|
||||
)
|
||||
|
||||
return _scheduler_instance
|
||||
|
||||
|
||||
__all__ = [
|
||||
'TaskScheduler',
|
||||
'TaskExecutor',
|
||||
'TaskExecutionResult',
|
||||
'MonitoringTaskExecutor',
|
||||
'get_scheduler',
|
||||
# Exception handling
|
||||
'SchedulerExceptionHandler',
|
||||
'SchedulerException',
|
||||
'SchedulerErrorType',
|
||||
'SchedulerErrorSeverity',
|
||||
'TaskExecutionError',
|
||||
'DatabaseError',
|
||||
'TaskLoaderError',
|
||||
'SchedulerConfigError'
|
||||
]
|
||||
4
backend/services/scheduler/core/__init__.py
Normal file
4
backend/services/scheduler/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Core scheduler components.
|
||||
"""
|
||||
|
||||
395
backend/services/scheduler/core/exception_handler.py
Normal file
395
backend/services/scheduler/core/exception_handler.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Comprehensive Exception Handling and Logging for Task Scheduler
|
||||
Provides robust error handling, logging, and monitoring for the scheduler system.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from enum import Enum
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("scheduler_exception_handler")
|
||||
|
||||
|
||||
class SchedulerErrorType(Enum):
|
||||
"""Error types for scheduler system."""
|
||||
DATABASE_ERROR = "database_error"
|
||||
TASK_EXECUTION_ERROR = "task_execution_error"
|
||||
TASK_LOADER_ERROR = "task_loader_error"
|
||||
SCHEDULER_CONFIG_ERROR = "scheduler_config_error"
|
||||
RETRY_ERROR = "retry_error"
|
||||
CONCURRENCY_ERROR = "concurrency_error"
|
||||
TIMEOUT_ERROR = "timeout_error"
|
||||
|
||||
|
||||
class SchedulerErrorSeverity(Enum):
|
||||
"""Severity levels for scheduler errors."""
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class SchedulerException(Exception):
|
||||
"""Base exception for scheduler system errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
error_type: SchedulerErrorType,
|
||||
severity: SchedulerErrorSeverity = SchedulerErrorSeverity.MEDIUM,
|
||||
user_id: Optional[int] = None,
|
||||
task_id: Optional[int] = None,
|
||||
task_type: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
original_error: Optional[Exception] = None
|
||||
):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.severity = severity
|
||||
self.user_id = user_id
|
||||
self.task_id = task_id
|
||||
self.task_type = task_type
|
||||
self.context = context or {}
|
||||
self.original_error = original_error
|
||||
self.timestamp = datetime.utcnow()
|
||||
|
||||
# Capture stack trace if original error provided
|
||||
self.stack_trace = None
|
||||
if self.original_error:
|
||||
try:
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
if exc_traceback:
|
||||
self.stack_trace = ''.join(traceback.format_exception(
|
||||
exc_type, exc_value, exc_traceback
|
||||
))
|
||||
else:
|
||||
self.stack_trace = traceback.format_exception(
|
||||
type(self.original_error),
|
||||
self.original_error,
|
||||
self.original_error.__traceback__
|
||||
)
|
||||
except Exception:
|
||||
self.stack_trace = str(self.original_error)
|
||||
|
||||
super().__init__(message)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert exception to dictionary for logging/storage."""
|
||||
return {
|
||||
"message": self.message,
|
||||
"error_type": self.error_type.value,
|
||||
"severity": self.severity.value,
|
||||
"user_id": self.user_id,
|
||||
"task_id": self.task_id,
|
||||
"task_type": self.task_type,
|
||||
"context": self.context,
|
||||
"timestamp": self.timestamp.isoformat() if isinstance(self.timestamp, datetime) else self.timestamp,
|
||||
"original_error": str(self.original_error) if self.original_error else None,
|
||||
"stack_trace": self.stack_trace
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
return f"[{self.error_type.value}] {self.message}"
|
||||
|
||||
|
||||
class DatabaseError(SchedulerException):
|
||||
"""Exception raised for database-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
user_id: Optional[int] = None,
|
||||
task_id: Optional[int] = None,
|
||||
context: Dict[str, Any] = None,
|
||||
original_error: Exception = None
|
||||
):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_type=SchedulerErrorType.DATABASE_ERROR,
|
||||
severity=SchedulerErrorSeverity.CRITICAL,
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
context=context or {},
|
||||
original_error=original_error
|
||||
)
|
||||
|
||||
|
||||
class TaskExecutionError(SchedulerException):
|
||||
"""Exception raised for task execution failures."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
user_id: Optional[int] = None,
|
||||
task_id: Optional[int] = None,
|
||||
task_type: Optional[str] = None,
|
||||
retry_count: int = 0,
|
||||
execution_time_ms: Optional[int] = None,
|
||||
context: Dict[str, Any] = None,
|
||||
original_error: Exception = None
|
||||
):
|
||||
context = context or {}
|
||||
context.update({
|
||||
"retry_count": retry_count,
|
||||
"execution_time_ms": execution_time_ms
|
||||
})
|
||||
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_type=SchedulerErrorType.TASK_EXECUTION_ERROR,
|
||||
severity=SchedulerErrorSeverity.HIGH,
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
task_type=task_type,
|
||||
context=context,
|
||||
original_error=original_error
|
||||
)
|
||||
|
||||
|
||||
class TaskLoaderError(SchedulerException):
|
||||
"""Exception raised for task loading failures."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
task_type: Optional[str] = None,
|
||||
user_id: Optional[int] = None,
|
||||
context: Dict[str, Any] = None,
|
||||
original_error: Exception = None
|
||||
):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_type=SchedulerErrorType.TASK_LOADER_ERROR,
|
||||
severity=SchedulerErrorSeverity.HIGH,
|
||||
user_id=user_id,
|
||||
task_type=task_type,
|
||||
context=context or {},
|
||||
original_error=original_error
|
||||
)
|
||||
|
||||
|
||||
class SchedulerConfigError(SchedulerException):
|
||||
"""Exception raised for scheduler configuration errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
context: Dict[str, Any] = None,
|
||||
original_error: Exception = None
|
||||
):
|
||||
super().__init__(
|
||||
message=message,
|
||||
error_type=SchedulerErrorType.SCHEDULER_CONFIG_ERROR,
|
||||
severity=SchedulerErrorSeverity.CRITICAL,
|
||||
context=context or {},
|
||||
original_error=original_error
|
||||
)
|
||||
|
||||
|
||||
class SchedulerExceptionHandler:
|
||||
"""Comprehensive exception handler for the scheduler system."""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
self.db = db
|
||||
self.logger = logger
|
||||
|
||||
def handle_exception(
|
||||
self,
|
||||
error: Union[Exception, SchedulerException],
|
||||
context: Dict[str, Any] = None,
|
||||
log_level: str = "error"
|
||||
) -> Dict[str, Any]:
|
||||
"""Handle and log scheduler exceptions."""
|
||||
|
||||
context = context or {}
|
||||
|
||||
# Convert regular exceptions to SchedulerException
|
||||
if not isinstance(error, SchedulerException):
|
||||
error = SchedulerException(
|
||||
message=str(error),
|
||||
error_type=self._classify_error(error),
|
||||
severity=self._determine_severity(error),
|
||||
context=context,
|
||||
original_error=error
|
||||
)
|
||||
|
||||
# Log the error
|
||||
error_data = error.to_dict()
|
||||
error_data.update(context)
|
||||
|
||||
log_message = f"Scheduler Error: {error.message}"
|
||||
|
||||
if log_level == "critical" or error.severity == SchedulerErrorSeverity.CRITICAL:
|
||||
self.logger.critical(log_message, extra={"error_data": error_data})
|
||||
elif log_level == "error" or error.severity == SchedulerErrorSeverity.HIGH:
|
||||
self.logger.error(log_message, extra={"error_data": error_data})
|
||||
elif log_level == "warning" or error.severity == SchedulerErrorSeverity.MEDIUM:
|
||||
self.logger.warning(log_message, extra={"error_data": error_data})
|
||||
else:
|
||||
self.logger.info(log_message, extra={"error_data": error_data})
|
||||
|
||||
# Store critical errors in database for alerting
|
||||
if error.severity in [SchedulerErrorSeverity.HIGH, SchedulerErrorSeverity.CRITICAL]:
|
||||
self._store_error_alert(error)
|
||||
|
||||
# Return formatted error response
|
||||
return self._format_error_response(error)
|
||||
|
||||
def _classify_error(self, error: Exception) -> SchedulerErrorType:
|
||||
"""Classify an exception into a scheduler error type."""
|
||||
|
||||
error_str = str(error).lower()
|
||||
error_type_name = type(error).__name__.lower()
|
||||
|
||||
# Database errors
|
||||
if isinstance(error, (SQLAlchemyError, OperationalError, IntegrityError)):
|
||||
return SchedulerErrorType.DATABASE_ERROR
|
||||
if "database" in error_str or "sql" in error_type_name or "connection" in error_str:
|
||||
return SchedulerErrorType.DATABASE_ERROR
|
||||
|
||||
# Timeout errors
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return SchedulerErrorType.TIMEOUT_ERROR
|
||||
|
||||
# Concurrency errors
|
||||
if "concurrent" in error_str or "race" in error_str or "lock" in error_str:
|
||||
return SchedulerErrorType.CONCURRENCY_ERROR
|
||||
|
||||
# Task execution errors
|
||||
if "task" in error_str and "execut" in error_str:
|
||||
return SchedulerErrorType.TASK_EXECUTION_ERROR
|
||||
|
||||
# Task loader errors
|
||||
if "load" in error_str and "task" in error_str:
|
||||
return SchedulerErrorType.TASK_LOADER_ERROR
|
||||
|
||||
# Retry errors
|
||||
if "retry" in error_str:
|
||||
return SchedulerErrorType.RETRY_ERROR
|
||||
|
||||
# Config errors
|
||||
if "config" in error_str or "scheduler" in error_str and "init" in error_str:
|
||||
return SchedulerErrorType.SCHEDULER_CONFIG_ERROR
|
||||
|
||||
# Default to task execution error for unknown errors
|
||||
return SchedulerErrorType.TASK_EXECUTION_ERROR
|
||||
|
||||
def _determine_severity(self, error: Exception) -> SchedulerErrorSeverity:
|
||||
"""Determine the severity of an error."""
|
||||
|
||||
error_str = str(error).lower()
|
||||
error_type = type(error)
|
||||
|
||||
# Critical errors
|
||||
if isinstance(error, (SQLAlchemyError, OperationalError, ConnectionError)):
|
||||
return SchedulerErrorSeverity.CRITICAL
|
||||
if "database" in error_str or "connection" in error_str:
|
||||
return SchedulerErrorSeverity.CRITICAL
|
||||
|
||||
# High severity errors
|
||||
if "timeout" in error_str or "concurrent" in error_str:
|
||||
return SchedulerErrorSeverity.HIGH
|
||||
if isinstance(error, (KeyError, AttributeError)) and "config" in error_str:
|
||||
return SchedulerErrorSeverity.HIGH
|
||||
|
||||
# Medium severity errors
|
||||
if "task" in error_str or "execution" in error_str:
|
||||
return SchedulerErrorSeverity.MEDIUM
|
||||
|
||||
# Default to low
|
||||
return SchedulerErrorSeverity.LOW
|
||||
|
||||
def _store_error_alert(self, error: SchedulerException):
|
||||
"""Store critical errors in database for alerting."""
|
||||
|
||||
if not self.db:
|
||||
return
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependencies
|
||||
from models.monitoring_models import TaskExecutionLog
|
||||
|
||||
# Store as failed execution log if we have task_id (even without user_id for system errors)
|
||||
if error.task_id:
|
||||
try:
|
||||
execution_log = TaskExecutionLog(
|
||||
task_id=error.task_id,
|
||||
user_id=error.user_id, # Can be None for system-level errors
|
||||
execution_date=error.timestamp,
|
||||
status='failed',
|
||||
error_message=error.message,
|
||||
result_data={
|
||||
"error_type": error.error_type.value,
|
||||
"severity": error.severity.value,
|
||||
"context": error.context,
|
||||
"stack_trace": error.stack_trace,
|
||||
"task_type": error.task_type
|
||||
}
|
||||
)
|
||||
self.db.add(execution_log)
|
||||
self.db.commit()
|
||||
self.logger.info(f"Stored error alert in execution log for task {error.task_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to store error in execution log: {e}")
|
||||
self.db.rollback()
|
||||
# Note: For errors without task_id, we rely on structured logging only
|
||||
# Future: Could create a separate scheduler_error_logs table for system-level errors
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to store error alert: {e}")
|
||||
|
||||
def _format_error_response(self, error: SchedulerException) -> Dict[str, Any]:
|
||||
"""Format error for API response or logging."""
|
||||
|
||||
response = {
|
||||
"success": False,
|
||||
"error": {
|
||||
"type": error.error_type.value,
|
||||
"message": error.message,
|
||||
"severity": error.severity.value,
|
||||
"timestamp": error.timestamp.isoformat() if isinstance(error.timestamp, datetime) else str(error.timestamp),
|
||||
"user_id": error.user_id,
|
||||
"task_id": error.task_id,
|
||||
"task_type": error.task_type
|
||||
}
|
||||
}
|
||||
|
||||
# Add context for debugging (non-sensitive info only)
|
||||
if error.context:
|
||||
safe_context = {
|
||||
k: v for k, v in error.context.items()
|
||||
if k not in ["password", "token", "key", "secret", "credential"]
|
||||
}
|
||||
response["error"]["context"] = safe_context
|
||||
|
||||
# Add user-friendly message based on error type
|
||||
user_messages = {
|
||||
SchedulerErrorType.DATABASE_ERROR:
|
||||
"A database error occurred while processing the task. Please try again later.",
|
||||
SchedulerErrorType.TASK_EXECUTION_ERROR:
|
||||
"The task failed to execute. Please check the task configuration and try again.",
|
||||
SchedulerErrorType.TASK_LOADER_ERROR:
|
||||
"Failed to load tasks. The scheduler may be experiencing issues.",
|
||||
SchedulerErrorType.SCHEDULER_CONFIG_ERROR:
|
||||
"The scheduler configuration is invalid. Contact support.",
|
||||
SchedulerErrorType.RETRY_ERROR:
|
||||
"Task retry failed. The task will be rescheduled.",
|
||||
SchedulerErrorType.CONCURRENCY_ERROR:
|
||||
"A concurrency issue occurred. The task will be retried.",
|
||||
SchedulerErrorType.TIMEOUT_ERROR:
|
||||
"The task execution timed out. The task will be retried."
|
||||
}
|
||||
|
||||
response["error"]["user_message"] = user_messages.get(
|
||||
error.error_type,
|
||||
"An error occurred while processing the task."
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
75
backend/services/scheduler/core/executor_interface.py
Normal file
75
backend/services/scheduler/core/executor_interface.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Task Executor Interface
|
||||
Abstract base class for all task executors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskExecutionResult:
|
||||
"""Result of task execution."""
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
result_data: Optional[Dict[str, Any]] = None
|
||||
execution_time_ms: Optional[int] = None
|
||||
retryable: bool = True
|
||||
retry_delay: int = 300 # seconds
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'success': self.success,
|
||||
'error_message': self.error_message,
|
||||
'result_data': self.result_data,
|
||||
'execution_time_ms': self.execution_time_ms,
|
||||
'retryable': self.retryable,
|
||||
'retry_delay': self.retry_delay
|
||||
}
|
||||
|
||||
|
||||
class TaskExecutor(ABC):
|
||||
"""
|
||||
Abstract base class for task executors.
|
||||
|
||||
Each task type must implement this interface to be schedulable.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute_task(self, task: Any, db: Session) -> TaskExecutionResult:
|
||||
"""
|
||||
Execute a task.
|
||||
|
||||
Args:
|
||||
task: Task instance from database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
TaskExecutionResult with execution details
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def calculate_next_execution(
|
||||
self,
|
||||
task: Any,
|
||||
frequency: str,
|
||||
last_execution: Optional[datetime] = None
|
||||
) -> datetime:
|
||||
"""
|
||||
Calculate next execution time based on frequency.
|
||||
|
||||
Args:
|
||||
task: Task instance
|
||||
frequency: Task frequency (e.g., 'Daily', 'Weekly')
|
||||
last_execution: Last execution datetime
|
||||
|
||||
Returns:
|
||||
Next execution datetime
|
||||
"""
|
||||
pass
|
||||
|
||||
628
backend/services/scheduler/core/scheduler.py
Normal file
628
backend/services/scheduler/core/scheduler.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Core Task Scheduler Service
|
||||
Pluggable task scheduler that can work with any task model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable
|
||||
from datetime import datetime
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .task_registry import TaskRegistry
|
||||
from .exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
|
||||
TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
from services.database import get_db_session
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("task_scheduler")
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""
|
||||
Pluggable task scheduler that can work with any task model.
|
||||
|
||||
Features:
|
||||
- Async task execution
|
||||
- Plugin-based executor system
|
||||
- Database-backed task persistence
|
||||
- Configurable check intervals
|
||||
- Automatic retry logic
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_minutes: int = 15,
|
||||
max_concurrent_executions: int = 10,
|
||||
enable_retries: bool = True,
|
||||
max_retries: int = 3
|
||||
):
|
||||
"""
|
||||
Initialize the task scheduler.
|
||||
|
||||
Args:
|
||||
check_interval_minutes: How often to check for due tasks
|
||||
max_concurrent_executions: Maximum concurrent task executions
|
||||
enable_retries: Whether to retry failed tasks
|
||||
max_retries: Maximum retry attempts
|
||||
"""
|
||||
self.check_interval_minutes = check_interval_minutes
|
||||
self.max_concurrent_executions = max_concurrent_executions
|
||||
self.enable_retries = enable_retries
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Initialize APScheduler
|
||||
self.scheduler = AsyncIOScheduler(
|
||||
timezone='UTC',
|
||||
job_defaults={
|
||||
'coalesce': True,
|
||||
'max_instances': 1,
|
||||
'misfire_grace_time': 300 # 5 minutes grace period
|
||||
}
|
||||
)
|
||||
|
||||
# Task executor registry
|
||||
self.registry = TaskRegistry()
|
||||
|
||||
# Track running executions
|
||||
self.active_executions: Dict[str, asyncio.Task] = {}
|
||||
|
||||
# Exception handler for robust error handling
|
||||
self.exception_handler = SchedulerExceptionHandler()
|
||||
|
||||
# Intelligent scheduling configuration
|
||||
self.min_check_interval_minutes = 15 # Check every 15min when active strategies exist
|
||||
self.max_check_interval_minutes = 60 # Check every 60min when no active strategies
|
||||
self.current_check_interval_minutes = check_interval_minutes # Current interval
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'total_checks': 0,
|
||||
'tasks_found': 0,
|
||||
'tasks_executed': 0,
|
||||
'tasks_failed': 0,
|
||||
'tasks_skipped': 0,
|
||||
'last_check': None,
|
||||
'per_user_stats': {}, # Track metrics per user for user isolation
|
||||
'active_strategies_count': 0, # Track active strategies with tasks
|
||||
'last_interval_adjustment': None # Track when interval was last adjusted
|
||||
}
|
||||
|
||||
self._running = False
|
||||
|
||||
def _get_trigger_for_interval(self, interval_minutes: int):
|
||||
"""
|
||||
Get the appropriate trigger for the given interval.
|
||||
|
||||
For intervals >= 60 minutes, use IntervalTrigger.
|
||||
For intervals < 60 minutes, use CronTrigger.
|
||||
|
||||
Args:
|
||||
interval_minutes: Interval in minutes
|
||||
|
||||
Returns:
|
||||
Appropriate APScheduler trigger
|
||||
"""
|
||||
if interval_minutes >= 60:
|
||||
# Use IntervalTrigger for intervals >= 60 minutes
|
||||
return IntervalTrigger(minutes=interval_minutes)
|
||||
else:
|
||||
# Use CronTrigger for intervals < 60 minutes (valid range: 0-59)
|
||||
return CronTrigger(minute=f'*/{interval_minutes}')
|
||||
|
||||
def register_executor(
|
||||
self,
|
||||
task_type: str,
|
||||
executor: TaskExecutor,
|
||||
task_loader: Callable[[Session], List[Any]]
|
||||
):
|
||||
"""
|
||||
Register a task executor for a specific task type.
|
||||
|
||||
Args:
|
||||
task_type: Unique identifier for task type (e.g., 'monitoring_task')
|
||||
executor: TaskExecutor instance that handles execution
|
||||
task_loader: Function that loads due tasks from database
|
||||
"""
|
||||
self.registry.register(task_type, executor, task_loader)
|
||||
logger.info(f"Registered executor for task type: {task_type}")
|
||||
|
||||
async def start(self):
|
||||
"""Start the scheduler with intelligent interval adjustment."""
|
||||
if self._running:
|
||||
logger.warning("Scheduler is already running")
|
||||
return
|
||||
|
||||
try:
|
||||
# Determine initial check interval based on active strategies
|
||||
initial_interval = await self._determine_optimal_interval()
|
||||
self.current_check_interval_minutes = initial_interval
|
||||
|
||||
# Add periodic job to check for due tasks
|
||||
self.scheduler.add_job(
|
||||
self._check_and_execute_due_tasks,
|
||||
trigger=self._get_trigger_for_interval(initial_interval),
|
||||
id='check_due_tasks',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
self._running = True
|
||||
|
||||
logger.info(
|
||||
f"Task scheduler started | "
|
||||
f"check_interval={initial_interval}min | "
|
||||
f"registered_types={self.registry.get_registered_types()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the scheduler gracefully."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Cancel all active executions
|
||||
for task_id, execution_task in self.active_executions.items():
|
||||
execution_task.cancel()
|
||||
|
||||
# Wait for active executions to complete (with timeout)
|
||||
if self.active_executions:
|
||||
await asyncio.wait(
|
||||
self.active_executions.values(),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# Shutdown scheduler
|
||||
self.scheduler.shutdown(wait=True)
|
||||
self._running = False
|
||||
|
||||
logger.info("Task scheduler stopped gracefully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping scheduler: {e}")
|
||||
raise
|
||||
|
||||
async def _check_and_execute_due_tasks(self):
|
||||
"""
|
||||
Main scheduler loop: check for due tasks and execute them.
|
||||
This runs periodically with intelligent interval adjustment based on active strategies.
|
||||
"""
|
||||
self.stats['total_checks'] += 1
|
||||
self.stats['last_check'] = datetime.utcnow().isoformat()
|
||||
|
||||
logger.debug("Checking for due tasks...")
|
||||
|
||||
db = None
|
||||
try:
|
||||
db = get_db_session()
|
||||
if db is None:
|
||||
logger.error("Failed to get database session")
|
||||
return
|
||||
|
||||
# Check for active strategies and adjust interval intelligently
|
||||
await self._adjust_check_interval_if_needed(db)
|
||||
|
||||
# Check each registered task type
|
||||
for task_type in self.registry.get_registered_types():
|
||||
await self._process_task_type(task_type, db)
|
||||
|
||||
except Exception as e:
|
||||
error = DatabaseError(
|
||||
message=f"Error checking for due tasks: {str(e)}",
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def _determine_optimal_interval(self) -> int:
|
||||
"""
|
||||
Determine optimal check interval based on active strategies.
|
||||
|
||||
Returns:
|
||||
Optimal check interval in minutes
|
||||
"""
|
||||
db = None
|
||||
try:
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.active_strategy_service import ActiveStrategyService
|
||||
active_strategy_service = ActiveStrategyService(db_session=db)
|
||||
active_count = active_strategy_service.count_active_strategies_with_tasks()
|
||||
self.stats['active_strategies_count'] = active_count
|
||||
|
||||
if active_count > 0:
|
||||
logger.info(f"Found {active_count} active strategies with tasks - using {self.min_check_interval_minutes}min interval")
|
||||
return self.min_check_interval_minutes
|
||||
else:
|
||||
logger.info(f"No active strategies with tasks - using {self.max_check_interval_minutes}min interval")
|
||||
return self.max_check_interval_minutes
|
||||
except Exception as e:
|
||||
logger.warning(f"Error determining optimal interval: {e}, using default {self.min_check_interval_minutes}min")
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
# Default to shorter interval on error (safer)
|
||||
return self.min_check_interval_minutes
|
||||
|
||||
async def _adjust_check_interval_if_needed(self, db: Session):
|
||||
"""
|
||||
Intelligently adjust check interval based on active strategies.
|
||||
|
||||
If there are active strategies with tasks, check more frequently.
|
||||
If there are no active strategies, check less frequently.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
"""
|
||||
try:
|
||||
from services.active_strategy_service import ActiveStrategyService
|
||||
|
||||
active_strategy_service = ActiveStrategyService(db_session=db)
|
||||
active_count = active_strategy_service.count_active_strategies_with_tasks()
|
||||
self.stats['active_strategies_count'] = active_count
|
||||
|
||||
# Determine optimal interval
|
||||
if active_count > 0:
|
||||
optimal_interval = self.min_check_interval_minutes
|
||||
else:
|
||||
optimal_interval = self.max_check_interval_minutes
|
||||
|
||||
# Only reschedule if interval needs to change
|
||||
if optimal_interval != self.current_check_interval_minutes:
|
||||
logger.info(
|
||||
f"Adjusting scheduler interval: {self.current_check_interval_minutes}min → {optimal_interval}min | "
|
||||
f"active_strategies={active_count}"
|
||||
)
|
||||
|
||||
# Reschedule the job with new interval
|
||||
self.scheduler.modify_job(
|
||||
'check_due_tasks',
|
||||
trigger=self._get_trigger_for_interval(optimal_interval)
|
||||
)
|
||||
|
||||
self.current_check_interval_minutes = optimal_interval
|
||||
self.stats['last_interval_adjustment'] = datetime.utcnow().isoformat()
|
||||
|
||||
logger.info(f"Scheduler interval adjusted to {optimal_interval}min")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error adjusting check interval: {e}")
|
||||
|
||||
async def trigger_interval_adjustment(self):
|
||||
"""
|
||||
Trigger immediate interval adjustment check.
|
||||
|
||||
This should be called when a strategy is activated or deactivated
|
||||
to immediately adjust the scheduler interval based on current active strategies.
|
||||
"""
|
||||
if not self._running:
|
||||
logger.debug("Scheduler not running, skipping interval adjustment")
|
||||
return
|
||||
|
||||
try:
|
||||
db = get_db_session()
|
||||
if db:
|
||||
await self._adjust_check_interval_if_needed(db)
|
||||
else:
|
||||
logger.warning("Could not get database session for interval adjustment")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error triggering interval adjustment: {e}")
|
||||
|
||||
async def _process_task_type(self, task_type: str, db: Session):
|
||||
"""Process due tasks for a specific task type."""
|
||||
try:
|
||||
# Get task loader for this type
|
||||
try:
|
||||
task_loader = self.registry.get_task_loader(task_type)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to get task loader for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return
|
||||
|
||||
# Load due tasks (with error handling)
|
||||
try:
|
||||
due_tasks = task_loader(db)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to load due tasks for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return
|
||||
|
||||
if not due_tasks:
|
||||
return
|
||||
|
||||
self.stats['tasks_found'] += len(due_tasks)
|
||||
logger.info(f"Found {len(due_tasks)} due tasks for type: {task_type}")
|
||||
|
||||
# Execute tasks (with concurrency limit)
|
||||
execution_tasks = []
|
||||
for task in due_tasks:
|
||||
if len(self.active_executions) >= self.max_concurrent_executions:
|
||||
logger.warning(
|
||||
f"Max concurrent executions reached ({self.max_concurrent_executions}), "
|
||||
f"skipping {len(due_tasks) - len(execution_tasks)} tasks"
|
||||
)
|
||||
break
|
||||
|
||||
# Execute task asynchronously
|
||||
# Note: Each task gets its own database session to prevent concurrent access issues
|
||||
execution_task = asyncio.create_task(
|
||||
self._execute_task_async(task_type, task)
|
||||
)
|
||||
|
||||
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
|
||||
self.active_executions[task_id] = execution_task
|
||||
|
||||
execution_tasks.append(execution_task)
|
||||
|
||||
# Wait for executions to complete (with timeout per task)
|
||||
if execution_tasks:
|
||||
await asyncio.wait(execution_tasks, timeout=300)
|
||||
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Error processing task type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
|
||||
async def _execute_task_async(self, task_type: str, task: Any):
|
||||
"""
|
||||
Execute a single task asynchronously with user isolation.
|
||||
|
||||
Each task gets its own database session to prevent concurrent access issues,
|
||||
as SQLAlchemy sessions are not async-safe or concurrent-safe.
|
||||
|
||||
User context is extracted and tracked for user isolation.
|
||||
|
||||
Args:
|
||||
task_type: Type of task
|
||||
task: Task instance from database (detached from original session)
|
||||
"""
|
||||
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
|
||||
db = None
|
||||
user_id = None
|
||||
|
||||
try:
|
||||
# Extract user context if available (for user isolation tracking)
|
||||
try:
|
||||
if hasattr(task, 'strategy') and task.strategy:
|
||||
user_id = getattr(task.strategy, 'user_id', None)
|
||||
elif hasattr(task, 'strategy_id') and task.strategy_id:
|
||||
# Will query user_id after we have db session
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract user_id before execution for task {task_id}: {e}")
|
||||
|
||||
logger.info(f"Executing task: {task_id} | user_id: {user_id}")
|
||||
|
||||
# Create a new database session for this async task
|
||||
# SQLAlchemy sessions are not async-safe and cannot be shared across concurrent tasks
|
||||
db = get_db_session()
|
||||
if db is None:
|
||||
error = DatabaseError(
|
||||
message=f"Failed to get database session for task {task_id}",
|
||||
user_id=user_id,
|
||||
task_id=getattr(task, 'id', None),
|
||||
task_type=task_type
|
||||
)
|
||||
self.exception_handler.handle_exception(error, log_level="error")
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
return
|
||||
|
||||
# Set database session for exception handler
|
||||
self.exception_handler.db = db
|
||||
|
||||
# Merge the detached task object into this session
|
||||
# The task object was loaded in a different session and is now detached
|
||||
from sqlalchemy.orm import object_session
|
||||
if object_session(task) is None:
|
||||
# Task is detached, need to merge it into this session
|
||||
task = db.merge(task)
|
||||
|
||||
# Extract user_id after merge if not already available
|
||||
if user_id is None and hasattr(task, 'strategy'):
|
||||
try:
|
||||
if task.strategy:
|
||||
user_id = getattr(task.strategy, 'user_id', None)
|
||||
elif hasattr(task, 'strategy_id'):
|
||||
# Query strategy if relationship not loaded
|
||||
from models.enhanced_strategy_models import EnhancedContentStrategy
|
||||
strategy = db.query(EnhancedContentStrategy).filter(
|
||||
EnhancedContentStrategy.id == task.strategy_id
|
||||
).first()
|
||||
if strategy:
|
||||
user_id = strategy.user_id
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract user_id after merge for task {task_id}: {e}")
|
||||
|
||||
# Get executor for this task type
|
||||
try:
|
||||
executor = self.registry.get_executor(task_type)
|
||||
except Exception as e:
|
||||
from .exception_handler import SchedulerConfigError
|
||||
error = SchedulerConfigError(
|
||||
message=f"Failed to get executor for task type {task_type}: {str(e)}",
|
||||
user_id=user_id,
|
||||
context={
|
||||
"task_id": getattr(task, 'id', None),
|
||||
"task_type": task_type
|
||||
},
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
return
|
||||
|
||||
# Execute task with its own session (with error handling)
|
||||
try:
|
||||
result = await executor.execute_task(task, db)
|
||||
|
||||
# Handle result and update statistics
|
||||
if result.success:
|
||||
self.stats['tasks_executed'] += 1
|
||||
self._update_user_stats(user_id, success=True)
|
||||
logger.info(f"Task executed successfully: {task_id} | user_id: {user_id}")
|
||||
else:
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
|
||||
# Create structured error for failed execution
|
||||
error = TaskExecutionError(
|
||||
message=result.error_message or "Task execution failed",
|
||||
user_id=user_id,
|
||||
task_id=getattr(task, 'id', None),
|
||||
task_type=task_type,
|
||||
execution_time_ms=result.execution_time_ms,
|
||||
context={"result_data": result.result_data}
|
||||
)
|
||||
self.exception_handler.handle_exception(error, log_level="warning")
|
||||
|
||||
# Retry logic if enabled
|
||||
if self.enable_retries and result.retryable:
|
||||
await self._schedule_retry(task, result.retry_delay)
|
||||
|
||||
except SchedulerException as e:
|
||||
# Re-raise scheduler exceptions (they're already handled)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Wrap unexpected exceptions
|
||||
error = TaskExecutionError(
|
||||
message=f"Unexpected error during task execution: {str(e)}",
|
||||
user_id=user_id,
|
||||
task_id=getattr(task, 'id', None),
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
|
||||
except SchedulerException as e:
|
||||
# Handle scheduler exceptions
|
||||
self.exception_handler.handle_exception(e)
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
except Exception as e:
|
||||
# Handle any other unexpected errors
|
||||
error = TaskExecutionError(
|
||||
message=f"Unexpected error in task execution wrapper: {str(e)}",
|
||||
user_id=user_id,
|
||||
task_id=getattr(task, 'id', None),
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats['tasks_failed'] += 1
|
||||
self._update_user_stats(user_id, success=False)
|
||||
finally:
|
||||
# Clean up database session
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing database session for task {task_id}: {e}")
|
||||
|
||||
# Remove from active executions
|
||||
if task_id in self.active_executions:
|
||||
del self.active_executions[task_id]
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[int], success: bool):
|
||||
"""
|
||||
Update per-user statistics for user isolation tracking.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None if user context not available)
|
||||
success: Whether task execution was successful
|
||||
"""
|
||||
if user_id is None:
|
||||
return
|
||||
|
||||
if user_id not in self.stats['per_user_stats']:
|
||||
self.stats['per_user_stats'][user_id] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
|
||||
user_stats = self.stats['per_user_stats'][user_id]
|
||||
if success:
|
||||
user_stats['executed'] += 1
|
||||
else:
|
||||
user_stats['failed'] += 1
|
||||
|
||||
# Calculate success rate
|
||||
total = user_stats['executed'] + user_stats['failed']
|
||||
if total > 0:
|
||||
user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0
|
||||
|
||||
async def _schedule_retry(self, task: Any, delay_seconds: int):
|
||||
"""Schedule a retry for a failed task."""
|
||||
# This would update the task's next_execution time
|
||||
# For now, just log - could be enhanced to update next_execution
|
||||
logger.debug(f"Scheduling retry for task in {delay_seconds}s")
|
||||
|
||||
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get scheduler statistics with optional user filtering.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter statistics for specific user
|
||||
|
||||
Returns:
|
||||
Dictionary with scheduler statistics
|
||||
"""
|
||||
base_stats = {
|
||||
**{k: v for k, v in self.stats.items() if k not in ['per_user_stats']},
|
||||
'active_executions': len(self.active_executions),
|
||||
'registered_types': self.registry.get_registered_types(),
|
||||
'running': self._running,
|
||||
'check_interval_minutes': self.current_check_interval_minutes,
|
||||
'min_check_interval_minutes': self.min_check_interval_minutes,
|
||||
'max_check_interval_minutes': self.max_check_interval_minutes,
|
||||
'intelligent_scheduling': True
|
||||
}
|
||||
|
||||
# Include per-user stats (all users or filtered)
|
||||
if user_id is not None:
|
||||
if user_id in self.stats['per_user_stats']:
|
||||
base_stats['user_stats'] = self.stats['per_user_stats'][user_id]
|
||||
else:
|
||||
base_stats['user_stats'] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
else:
|
||||
# Include all per-user stats (for admin/debugging)
|
||||
base_stats['per_user_stats'] = self.stats['per_user_stats']
|
||||
|
||||
return base_stats
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
59
backend/services/scheduler/core/task_registry.py
Normal file
59
backend/services/scheduler/core/task_registry.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Task Registry
|
||||
Manages registration of task executors and loaders.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Callable, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .executor_interface import TaskExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRegistry:
|
||||
"""Registry for task executors and loaders."""
|
||||
|
||||
def __init__(self):
|
||||
self.executors: Dict[str, TaskExecutor] = {}
|
||||
self.task_loaders: Dict[str, Callable[[Session], List[Any]]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
task_type: str,
|
||||
executor: TaskExecutor,
|
||||
task_loader: Callable[[Session], List[Any]]
|
||||
):
|
||||
"""
|
||||
Register a task executor and loader.
|
||||
|
||||
Args:
|
||||
task_type: Unique identifier for task type
|
||||
executor: TaskExecutor instance
|
||||
task_loader: Function that loads due tasks from database
|
||||
"""
|
||||
if task_type in self.executors:
|
||||
logger.warning(f"Overwriting existing executor for task type: {task_type}")
|
||||
|
||||
self.executors[task_type] = executor
|
||||
self.task_loaders[task_type] = task_loader
|
||||
|
||||
logger.info(f"Registered task type: {task_type}")
|
||||
|
||||
def get_executor(self, task_type: str) -> TaskExecutor:
|
||||
"""Get executor for task type."""
|
||||
if task_type not in self.executors:
|
||||
raise ValueError(f"No executor registered for task type: {task_type}")
|
||||
return self.executors[task_type]
|
||||
|
||||
def get_task_loader(self, task_type: str) -> Callable[[Session], List[Any]]:
|
||||
"""Get task loader for task type."""
|
||||
if task_type not in self.task_loaders:
|
||||
raise ValueError(f"No task loader registered for task type: {task_type}")
|
||||
return self.task_loaders[task_type]
|
||||
|
||||
def get_registered_types(self) -> List[str]:
|
||||
"""Get list of registered task types."""
|
||||
return list(self.executors.keys())
|
||||
|
||||
4
backend/services/scheduler/executors/__init__.py
Normal file
4
backend/services/scheduler/executors/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Task executor implementations.
|
||||
"""
|
||||
|
||||
266
backend/services/scheduler/executors/monitoring_task_executor.py
Normal file
266
backend/services/scheduler/executors/monitoring_task_executor.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Monitoring Task Executor
|
||||
Handles execution of content strategy monitoring tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..core.executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from ..core.exception_handler import TaskExecutionError, DatabaseError, SchedulerExceptionHandler
|
||||
from ..utils.frequency_calculator import calculate_next_execution
|
||||
from models.monitoring_models import MonitoringTask, TaskExecutionLog
|
||||
from models.enhanced_strategy_models import EnhancedContentStrategy
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("monitoring_task_executor")
|
||||
|
||||
|
||||
class MonitoringTaskExecutor(TaskExecutor):
|
||||
"""
|
||||
Executor for content strategy monitoring tasks.
|
||||
|
||||
Handles:
|
||||
- ALwrity tasks (automated execution)
|
||||
- Human tasks (notifications/queuing)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logger
|
||||
self.exception_handler = SchedulerExceptionHandler()
|
||||
|
||||
async def execute_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
|
||||
"""
|
||||
Execute a monitoring task with user isolation.
|
||||
|
||||
Args:
|
||||
task: MonitoringTask instance (with strategy relationship loaded)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
TaskExecutionResult
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Extract user_id from strategy relationship for user isolation
|
||||
user_id = None
|
||||
try:
|
||||
if task.strategy and hasattr(task.strategy, 'user_id'):
|
||||
user_id = task.strategy.user_id
|
||||
elif task.strategy_id:
|
||||
# Fallback: query strategy if relationship not loaded
|
||||
strategy = db.query(EnhancedContentStrategy).filter(
|
||||
EnhancedContentStrategy.id == task.strategy_id
|
||||
).first()
|
||||
if strategy:
|
||||
user_id = strategy.user_id
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not extract user_id for task {task.id}: {e}")
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"Executing monitoring task: {task.id} | "
|
||||
f"user_id: {user_id} | "
|
||||
f"assignee: {task.assignee} | "
|
||||
f"frequency: {task.frequency}"
|
||||
)
|
||||
|
||||
# Create execution log with user_id for user isolation tracking
|
||||
execution_log = TaskExecutionLog(
|
||||
task_id=task.id,
|
||||
user_id=user_id,
|
||||
execution_date=datetime.utcnow(),
|
||||
status='running'
|
||||
)
|
||||
db.add(execution_log)
|
||||
db.flush()
|
||||
|
||||
# Execute based on assignee
|
||||
if task.assignee == 'ALwrity':
|
||||
result = await self._execute_alwrity_task(task, db)
|
||||
else:
|
||||
result = await self._execute_human_task(task, db)
|
||||
|
||||
# Update execution log
|
||||
execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
execution_log.status = 'success' if result.success else 'failed'
|
||||
execution_log.result_data = result.result_data
|
||||
execution_log.error_message = result.error_message
|
||||
execution_log.execution_time_ms = execution_time_ms
|
||||
|
||||
# Update task
|
||||
task.last_executed = datetime.utcnow()
|
||||
task.next_execution = self.calculate_next_execution(
|
||||
task,
|
||||
task.frequency,
|
||||
task.last_executed
|
||||
)
|
||||
|
||||
if result.success:
|
||||
task.status = 'completed'
|
||||
else:
|
||||
task.status = 'failed'
|
||||
|
||||
db.commit()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Set database session for exception handler
|
||||
self.exception_handler.db = db
|
||||
|
||||
# Create structured error
|
||||
error = TaskExecutionError(
|
||||
message=f"Error executing monitoring task {task.id}: {str(e)}",
|
||||
user_id=user_id,
|
||||
task_id=task.id,
|
||||
task_type="monitoring_task",
|
||||
execution_time_ms=execution_time_ms,
|
||||
context={
|
||||
"assignee": task.assignee,
|
||||
"frequency": task.frequency,
|
||||
"component": task.component_name
|
||||
},
|
||||
original_error=e
|
||||
)
|
||||
|
||||
# Handle exception with structured logging
|
||||
self.exception_handler.handle_exception(error)
|
||||
|
||||
# Update execution log with error (include user_id for isolation)
|
||||
try:
|
||||
execution_log = TaskExecutionLog(
|
||||
task_id=task.id,
|
||||
user_id=user_id,
|
||||
execution_date=datetime.utcnow(),
|
||||
status='failed',
|
||||
error_message=str(e),
|
||||
execution_time_ms=execution_time_ms,
|
||||
result_data={
|
||||
"error_type": error.error_type.value,
|
||||
"severity": error.severity.value,
|
||||
"context": error.context
|
||||
}
|
||||
)
|
||||
db.add(execution_log)
|
||||
|
||||
task.status = 'failed'
|
||||
task.last_executed = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
db_error = DatabaseError(
|
||||
message=f"Error saving execution log: {str(commit_error)}",
|
||||
user_id=user_id,
|
||||
task_id=task.id,
|
||||
original_error=commit_error
|
||||
)
|
||||
self.exception_handler.handle_exception(db_error)
|
||||
db.rollback()
|
||||
|
||||
return TaskExecutionResult(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
execution_time_ms=execution_time_ms,
|
||||
retryable=True,
|
||||
retry_delay=300
|
||||
)
|
||||
|
||||
async def _execute_alwrity_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
|
||||
"""
|
||||
Execute an ALwrity (automated) monitoring task.
|
||||
|
||||
This is where the actual monitoring logic would go.
|
||||
For now, we'll implement a placeholder that can be extended.
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"Executing ALwrity task: {task.task_title}")
|
||||
|
||||
# TODO: Implement actual monitoring logic based on:
|
||||
# - task.metric
|
||||
# - task.measurement_method
|
||||
# - task.success_criteria
|
||||
# - task.alert_threshold
|
||||
|
||||
# Placeholder: Simulate task execution
|
||||
result_data = {
|
||||
'metric_value': 0,
|
||||
'status': 'measured',
|
||||
'message': f"Task {task.task_title} executed successfully",
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return TaskExecutionResult(
|
||||
success=True,
|
||||
result_data=result_data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in ALwrity task execution: {e}")
|
||||
return TaskExecutionResult(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
retryable=True
|
||||
)
|
||||
|
||||
async def _execute_human_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
|
||||
"""
|
||||
Execute a Human monitoring task (notification/queuing).
|
||||
|
||||
For human tasks, we don't execute the task directly,
|
||||
but rather queue it for human review or send notifications.
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"Queuing human task: {task.task_title}")
|
||||
|
||||
# TODO: Implement notification/queuing system:
|
||||
# - Send email notification
|
||||
# - Add to user's task queue
|
||||
# - Create in-app notification
|
||||
|
||||
result_data = {
|
||||
'status': 'queued',
|
||||
'message': f"Task {task.task_title} queued for human review",
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return TaskExecutionResult(
|
||||
success=True,
|
||||
result_data=result_data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error queuing human task: {e}")
|
||||
return TaskExecutionResult(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
retryable=True
|
||||
)
|
||||
|
||||
def calculate_next_execution(
|
||||
self,
|
||||
task: MonitoringTask,
|
||||
frequency: str,
|
||||
last_execution: Optional[datetime] = None
|
||||
) -> datetime:
|
||||
"""
|
||||
Calculate next execution time based on frequency.
|
||||
|
||||
Args:
|
||||
task: MonitoringTask instance
|
||||
frequency: Frequency string (Daily, Weekly, Monthly, Quarterly)
|
||||
last_execution: Last execution datetime (defaults to now)
|
||||
|
||||
Returns:
|
||||
Next execution datetime
|
||||
"""
|
||||
return calculate_next_execution(
|
||||
frequency=frequency,
|
||||
base_time=last_execution or datetime.utcnow()
|
||||
)
|
||||
|
||||
4
backend/services/scheduler/utils/__init__.py
Normal file
4
backend/services/scheduler/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Scheduler utilities.
|
||||
"""
|
||||
|
||||
33
backend/services/scheduler/utils/frequency_calculator.py
Normal file
33
backend/services/scheduler/utils/frequency_calculator.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Frequency Calculator Utility
|
||||
Calculates next execution time based on frequency string.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def calculate_next_execution(frequency: str, base_time: Optional[datetime] = None) -> datetime:
|
||||
"""
|
||||
Calculate next execution time based on frequency.
|
||||
|
||||
Args:
|
||||
frequency: Frequency string ('Daily', 'Weekly', 'Monthly', 'Quarterly')
|
||||
base_time: Base time to calculate from (defaults to now if None)
|
||||
|
||||
Returns:
|
||||
Next execution datetime
|
||||
"""
|
||||
if base_time is None:
|
||||
base_time = datetime.utcnow()
|
||||
|
||||
frequency_map = {
|
||||
'Daily': timedelta(days=1),
|
||||
'Weekly': timedelta(weeks=1),
|
||||
'Monthly': timedelta(days=30),
|
||||
'Quarterly': timedelta(days=90)
|
||||
}
|
||||
|
||||
delta = frequency_map.get(frequency, timedelta(days=1))
|
||||
return base_time + delta
|
||||
|
||||
60
backend/services/scheduler/utils/task_loader.py
Normal file
60
backend/services/scheduler/utils/task_loader.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Task Loader Utilities
|
||||
Functions to load due tasks from database.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from models.monitoring_models import MonitoringTask
|
||||
from models.enhanced_strategy_models import EnhancedContentStrategy
|
||||
|
||||
|
||||
def load_due_monitoring_tasks(
|
||||
db: Session,
|
||||
user_id: Optional[int] = None
|
||||
) -> List[MonitoringTask]:
|
||||
"""
|
||||
Load all monitoring tasks that are due for execution.
|
||||
|
||||
Criteria:
|
||||
- status == 'active'
|
||||
- next_execution <= now (or is None for first execution)
|
||||
- Optional: user_id filter for specific user (for future admin features)
|
||||
|
||||
Note: Strategy relationship is eagerly loaded to ensure user_id is accessible
|
||||
during task execution for user isolation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: Optional user ID to filter tasks (if None, loads all users' tasks)
|
||||
|
||||
Returns:
|
||||
List of due MonitoringTask instances with strategy relationship loaded
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Join with strategy to ensure relationship is loaded and support user filtering
|
||||
query = db.query(MonitoringTask).join(
|
||||
EnhancedContentStrategy,
|
||||
MonitoringTask.strategy_id == EnhancedContentStrategy.id
|
||||
).options(
|
||||
joinedload(MonitoringTask.strategy) # Eagerly load strategy relationship
|
||||
).filter(
|
||||
and_(
|
||||
MonitoringTask.status == 'active',
|
||||
or_(
|
||||
MonitoringTask.next_execution <= now,
|
||||
MonitoringTask.next_execution.is_(None)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply user filter if provided
|
||||
if user_id is not None:
|
||||
query = query.filter(EnhancedContentStrategy.user_id == user_id)
|
||||
|
||||
return query.all()
|
||||
|
||||
189
backend/services/subscription/preflight_validator.py
Normal file
189
backend/services/subscription/preflight_validator.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Pre-flight Validation Utility for Multi-Operation Workflows
|
||||
|
||||
Provides transparent validation for operations that involve multiple API calls.
|
||||
Services can use this to validate entire workflows before making any external API calls.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
|
||||
def validate_research_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str,
|
||||
gpt_provider: str = "google"
|
||||
) -> None:
|
||||
"""
|
||||
Validate all operations for a research workflow before making ANY API calls.
|
||||
|
||||
This prevents wasteful external API calls (like Google Grounding) if subsequent
|
||||
LLM calls would fail due to token or call limits.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
gpt_provider: GPT provider from env var (defaults to "google")
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, raises HTTPException with 429 status
|
||||
"""
|
||||
try:
|
||||
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
|
||||
gpt_provider_lower = gpt_provider.lower()
|
||||
if gpt_provider_lower == "huggingface":
|
||||
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
||||
llm_provider_name = "huggingface"
|
||||
else:
|
||||
llm_provider_enum = APIProvider.GEMINI
|
||||
llm_provider_name = "gemini"
|
||||
|
||||
# Estimate tokens for each operation in research workflow
|
||||
# Google Grounding call: ~2000 tokens (input + output)
|
||||
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
|
||||
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
|
||||
'tokens_requested': 2000,
|
||||
'actual_provider_name': 'gemini',
|
||||
'operation_type': 'google_grounding'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'keyword_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'competitor_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'content_angle_generation'
|
||||
}
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
|
||||
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
|
||||
operation_type = usage_info.get('operation_type', 'unknown')
|
||||
|
||||
logger.error(f"[Pre-flight Validator] ❌ RESEARCH WORKFLOW BLOCKED")
|
||||
logger.error(f" ├─ User: {user_id}")
|
||||
logger.error(f" ├─ Blocked at: {operation_type}")
|
||||
logger.error(f" ├─ Provider: {provider}")
|
||||
logger.error(f" └─ Reason: {message}")
|
||||
|
||||
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating research operations: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate operations: {str(e)}",
|
||||
'message': f"Failed to validate operations: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_image_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate image generation operation before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, raises HTTPException with 429 status
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.STABILITY,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'stability',
|
||||
'operation_type': 'image_generation'
|
||||
}
|
||||
]
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate image generation: {str(e)}",
|
||||
'message': f"Failed to validate image generation: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3,10 +3,11 @@ Pricing Service for API Usage Tracking
|
||||
Manages API pricing, cost calculation, and subscription limits.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
@@ -17,13 +18,17 @@ from models.subscription_models import (
|
||||
class PricingService:
|
||||
"""Service for managing API pricing and cost calculations."""
|
||||
|
||||
# Class-level cache shared across all instances (critical for cache invalidation on subscription renewal)
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
|
||||
_limits_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._pricing_cache = {}
|
||||
self._plans_cache = {}
|
||||
# Lightweight in-process cache for limit checks
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
|
||||
self._limits_cache: Dict[str, Dict[str, Any]] = {}
|
||||
# Cache for schema feature detection (ai_text_generation_calls_limit column)
|
||||
self._ai_text_gen_col_checked: bool = False
|
||||
self._ai_text_gen_col_available: bool = False
|
||||
|
||||
# ------------------- Billing period helpers -------------------
|
||||
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
|
||||
@@ -68,6 +73,15 @@ class PricingService:
|
||||
self._ensure_subscription_current(subscription)
|
||||
# Continue to use YYYY-MM for summaries
|
||||
return datetime.now().strftime("%Y-%m")
|
||||
|
||||
@classmethod
|
||||
def clear_user_cache(cls, user_id: str) -> int:
|
||||
"""Clear all cached limit checks for a specific user. Returns number of entries cleared."""
|
||||
keys_to_remove = [key for key in cls._limits_cache.keys() if key.startswith(f"{user_id}:")]
|
||||
for key in keys_to_remove:
|
||||
del cls._limits_cache[key]
|
||||
logger.info(f"Cleared {len(keys_to_remove)} cache entries for user {user_id}")
|
||||
return len(keys_to_remove)
|
||||
|
||||
def initialize_default_pricing(self):
|
||||
"""Initialize default pricing for all API providers."""
|
||||
@@ -292,7 +306,8 @@ class PricingService:
|
||||
"tier": SubscriptionTier.BASIC,
|
||||
"price_monthly": 29.0,
|
||||
"price_yearly": 290.0,
|
||||
"gemini_calls_limit": 1000,
|
||||
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
|
||||
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
|
||||
"openai_calls_limit": 500,
|
||||
"anthropic_calls_limit": 200,
|
||||
"mistral_calls_limit": 500,
|
||||
@@ -300,11 +315,11 @@ class PricingService:
|
||||
"serper_calls_limit": 200,
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 50,
|
||||
"gemini_tokens_limit": 1000000,
|
||||
"openai_tokens_limit": 500000,
|
||||
"anthropic_tokens_limit": 200000,
|
||||
"mistral_tokens_limit": 500000,
|
||||
"stability_calls_limit": 5,
|
||||
"gemini_tokens_limit": 2000,
|
||||
"openai_tokens_limit": 2000,
|
||||
"anthropic_tokens_limit": 2000,
|
||||
"mistral_tokens_limit": 2000,
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
@@ -426,21 +441,60 @@ class PricingService:
|
||||
self._ensure_subscription_current(subscription)
|
||||
return self._plan_to_limits_dict(subscription.plan)
|
||||
|
||||
def _ensure_ai_text_gen_column_detection(self) -> None:
|
||||
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
|
||||
if self._ai_text_gen_col_checked:
|
||||
return
|
||||
try:
|
||||
# Try to query the column - if it exists, this will work
|
||||
self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
|
||||
self._ai_text_gen_col_available = True
|
||||
except Exception:
|
||||
self._ai_text_gen_col_available = False
|
||||
finally:
|
||||
self._ai_text_gen_col_checked = True
|
||||
|
||||
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""Convert subscription plan to limits dictionary."""
|
||||
# Detect if unified AI text generation limit column exists
|
||||
self._ensure_ai_text_gen_column_detection()
|
||||
|
||||
# Use unified AI text generation limit if column exists and is set
|
||||
ai_text_gen_limit = None
|
||||
if self._ai_text_gen_col_available:
|
||||
try:
|
||||
ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
|
||||
# If 0, treat as not set (unlimited for Enterprise or use fallback)
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = None
|
||||
except (AttributeError, Exception):
|
||||
# Column exists but access failed - use fallback
|
||||
ai_text_gen_limit = None
|
||||
|
||||
return {
|
||||
'plan_name': plan.name,
|
||||
'tier': plan.tier.value,
|
||||
'limits': {
|
||||
# Unified AI text generation limit (applies to all LLM providers)
|
||||
# If not set, fall back to first non-zero legacy limit for backwards compatibility
|
||||
'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
|
||||
plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
|
||||
plan.openai_calls_limit if plan.openai_calls_limit > 0 else
|
||||
plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
|
||||
plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
|
||||
),
|
||||
# Legacy per-provider limits (for backwards compatibility and analytics)
|
||||
'gemini_calls': plan.gemini_calls_limit,
|
||||
'openai_calls': plan.openai_calls_limit,
|
||||
'anthropic_calls': plan.anthropic_calls_limit,
|
||||
'mistral_calls': plan.mistral_calls_limit,
|
||||
# Other API limits
|
||||
'tavily_calls': plan.tavily_calls_limit,
|
||||
'serper_calls': plan.serper_calls_limit,
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
# Token limits
|
||||
'gemini_tokens': plan.gemini_tokens_limit,
|
||||
'openai_tokens': plan.openai_tokens_limit,
|
||||
'anthropic_tokens': plan.anthropic_tokens_limit,
|
||||
@@ -451,101 +505,293 @@ class PricingService:
|
||||
}
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits."""
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user limits
|
||||
limits = self.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return False, "No subscription plan found", {}
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
# Get current usage for this billing period
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
|
||||
# Check call limits
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
result = (False, f"API call limit reached for {provider_name}", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Check token limits for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {provider_name}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
try:
|
||||
limits = self.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
try:
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
# Check cost limits
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, "Monthly cost limit reached", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_calls,
|
||||
'call_limit': call_limit,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
|
||||
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
|
||||
"""Estimate token count for text based on provider."""
|
||||
@@ -581,6 +827,236 @@ class PricingService:
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
def check_comprehensive_limits(
|
||||
self,
|
||||
user_id: str,
|
||||
operations: List[Dict[str, Any]]
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
|
||||
|
||||
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
|
||||
before making the first external API call.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
operations: List of operations to validate, each with:
|
||||
- 'provider': APIProvider enum
|
||||
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
|
||||
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
|
||||
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, error_message explains which limit would be exceeded
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
|
||||
# Get user limits
|
||||
limits_dict = self.get_user_limits(user_id)
|
||||
if not limits_dict:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
limits_dict = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
limits = limits_dict.get('limits', {})
|
||||
|
||||
# Track cumulative usage across all operations
|
||||
total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
total_llm_tokens = {}
|
||||
total_images = usage.stability_calls or 0
|
||||
|
||||
# Log current usage summary
|
||||
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
|
||||
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
|
||||
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
|
||||
logger.info(f" └─ Image Calls: {total_images}")
|
||||
|
||||
# Validate each operation
|
||||
for op_idx, operation in enumerate(operations):
|
||||
provider = operation.get('provider')
|
||||
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
|
||||
tokens_requested = operation.get('tokens_requested', 0)
|
||||
actual_provider_name = operation.get('actual_provider_name')
|
||||
operation_type = operation.get('operation_type', 'unknown')
|
||||
|
||||
display_provider_name = actual_provider_name or provider_name
|
||||
|
||||
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
|
||||
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
|
||||
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
|
||||
|
||||
# Check if this is an LLM provider
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
# Check unified AI text generation limit for LLM providers
|
||||
if is_llm_provider:
|
||||
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
|
||||
if ai_text_gen_limit == 0:
|
||||
# Fallback to provider-specific limit
|
||||
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Count this operation as an LLM call
|
||||
projected_total_llm_calls = total_llm_calls + 1
|
||||
|
||||
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
|
||||
error_info = {
|
||||
'current_calls': total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check token limits for this provider
|
||||
# Use cumulative projected tokens from previous operations, or current from DB if first operation
|
||||
provider_tokens_key = f"{provider_name}_tokens"
|
||||
if provider_tokens_key in total_llm_tokens:
|
||||
# Use cumulative projected tokens from previous operations
|
||||
current_provider_tokens = total_llm_tokens[provider_tokens_key]
|
||||
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
|
||||
else:
|
||||
# First operation for this provider - get current from database
|
||||
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
|
||||
total_llm_tokens[provider_tokens_key] = current_provider_tokens
|
||||
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
|
||||
|
||||
token_limit = limits.get(provider_tokens_key, 0) or 0
|
||||
|
||||
if token_limit > 0 and tokens_requested > 0:
|
||||
projected_tokens = current_provider_tokens + tokens_requested
|
||||
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
|
||||
|
||||
if projected_tokens > token_limit:
|
||||
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
|
||||
error_info = {
|
||||
'current_tokens': current_provider_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Current: {current_provider_tokens}/{token_limit}, "
|
||||
f"Requested: {tokens_requested}, "
|
||||
f"Would exceed by: {projected_tokens - token_limit} tokens "
|
||||
f"({usage_percentage:.1f}% of limit)"
|
||||
)
|
||||
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
|
||||
return False, error_msg, {
|
||||
'error_type': 'token_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
else:
|
||||
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
|
||||
|
||||
# Update cumulative counts for next operation
|
||||
total_llm_calls = projected_total_llm_calls
|
||||
total_llm_tokens[provider_tokens_key] += tokens_requested
|
||||
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
|
||||
|
||||
# Check image generation limits
|
||||
elif provider == APIProvider.STABILITY:
|
||||
image_limit = limits.get('stability_calls', 0) or 0
|
||||
projected_images = total_images + 1
|
||||
|
||||
if image_limit > 0 and projected_images > image_limit:
|
||||
error_info = {
|
||||
'current_images': total_images,
|
||||
'limit': image_limit,
|
||||
'provider': 'stability',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
|
||||
'error_type': 'image_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
|
||||
call_limit = limits.get(provider_calls_key, 0) or 0
|
||||
|
||||
if call_limit > 0:
|
||||
projected_calls = current_provider_calls + 1
|
||||
if projected_calls > call_limit:
|
||||
error_info = {
|
||||
'current_calls': current_provider_calls,
|
||||
'limit': call_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
|
||||
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
|
||||
return True, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
|
||||
return False, f"Failed to validate limits: {str(e)}", {}
|
||||
|
||||
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get pricing configuration for a specific provider and model."""
|
||||
pricing = self.db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.provider == provider,
|
||||
APIProviderPricing.model_name == model_name
|
||||
).first()
|
||||
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
return {
|
||||
'provider': pricing.provider.value,
|
||||
'model_name': pricing.model_name,
|
||||
|
||||
@@ -502,7 +502,7 @@ class UsageTrackingService:
|
||||
return result
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status for the current billing period (after plan change)."""
|
||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||
try:
|
||||
billing_period = datetime.now().strftime("%Y-%m")
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
@@ -514,11 +514,52 @@ class UsageTrackingService:
|
||||
# Nothing to reset
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
# Clear LIMIT_REACHED so the user can resume; keep counters intact
|
||||
# CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
|
||||
# Clear LIMIT_REACHED status
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
|
||||
# Reset all LLM provider call counters
|
||||
summary.gemini_calls = 0
|
||||
summary.openai_calls = 0
|
||||
summary.anthropic_calls = 0
|
||||
summary.mistral_calls = 0
|
||||
|
||||
# Reset all LLM provider token counters
|
||||
summary.gemini_tokens = 0
|
||||
summary.openai_tokens = 0
|
||||
summary.anthropic_tokens = 0
|
||||
summary.mistral_tokens = 0
|
||||
|
||||
# Reset search/research provider counters
|
||||
summary.tavily_calls = 0
|
||||
summary.serper_calls = 0
|
||||
summary.metaphor_calls = 0
|
||||
summary.firecrawl_calls = 0
|
||||
|
||||
# Reset image generation counters
|
||||
summary.stability_calls = 0
|
||||
|
||||
# Reset cost counters
|
||||
summary.gemini_cost = 0.0
|
||||
summary.openai_cost = 0.0
|
||||
summary.anthropic_cost = 0.0
|
||||
summary.mistral_cost = 0.0
|
||||
summary.tavily_cost = 0.0
|
||||
summary.serper_cost = 0.0
|
||||
summary.metaphor_cost = 0.0
|
||||
summary.firecrawl_cost = 0.0
|
||||
summary.stability_cost = 0.0
|
||||
|
||||
# Reset totals
|
||||
summary.total_calls = 0
|
||||
summary.total_tokens = 0
|
||||
summary.total_cost = 0.0
|
||||
|
||||
summary.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"reset": True}
|
||||
|
||||
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
|
||||
return {"reset": True, "counters_reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
|
||||
@@ -58,19 +58,25 @@ const InitialRouteHandler: React.FC = () => {
|
||||
error: null,
|
||||
});
|
||||
|
||||
// Check subscription on mount
|
||||
// Check subscription on mount (non-blocking - don't wait for it to route)
|
||||
useEffect(() => {
|
||||
checkSubscription().catch((err) => {
|
||||
console.error('Error checking subscription:', err);
|
||||
|
||||
// Check if it's a connection error - handle it locally
|
||||
if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err,
|
||||
});
|
||||
}
|
||||
});
|
||||
// Delay subscription check slightly to allow auth token getter to be installed first
|
||||
const timeoutId = setTimeout(() => {
|
||||
checkSubscription().catch((err) => {
|
||||
console.error('Error checking subscription (non-blocking):', err);
|
||||
|
||||
// Check if it's a connection error - handle it locally
|
||||
if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err,
|
||||
});
|
||||
}
|
||||
// Don't block routing on subscription check errors - allow graceful degradation
|
||||
});
|
||||
}, 100); // Small delay to ensure TokenInstaller has run
|
||||
|
||||
return () => clearTimeout(timeoutId);
|
||||
}, []); // Remove checkSubscription dependency to prevent loop
|
||||
|
||||
// Initialize onboarding only after subscription is confirmed
|
||||
@@ -125,9 +131,10 @@ const InitialRouteHandler: React.FC = () => {
|
||||
);
|
||||
}
|
||||
|
||||
// Loading state - ensure we wait for onboarding init after subscription is confirmed
|
||||
const waitingForOnboardingInit = !!subscription && subscription.active && !subscriptionLoading && (loading || !data);
|
||||
if (subscriptionLoading || loading || waitingForOnboardingInit) {
|
||||
// Loading state - only wait for onboarding init, not subscription check
|
||||
// Subscription check is non-blocking and happens in background
|
||||
const waitingForOnboardingInit = loading || !data;
|
||||
if (loading || waitingForOnboardingInit) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
@@ -167,29 +174,79 @@ const InitialRouteHandler: React.FC = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (!subscription) {
|
||||
return null; // Should not happen, but just in case
|
||||
// Decision tree for SIGNED-IN users:
|
||||
// Priority: Subscription → Onboarding → Dashboard (as per user flow: Landing → Subscription → Onboarding → Dashboard)
|
||||
|
||||
// 1. If subscription is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Decision tree for SIGNED-IN users:
|
||||
// Priority: Subscription → Onboarding → Dashboard
|
||||
|
||||
// Check if user is new (no subscription record at all)
|
||||
// 2. No subscription data yet - handle gracefully
|
||||
// If onboarding is complete, allow access to dashboard (user already went through flow)
|
||||
// If onboarding not complete, check if subscription check is still loading or failed
|
||||
if (!subscription) {
|
||||
if (isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
}
|
||||
|
||||
// Onboarding not complete and no subscription data
|
||||
// If subscription check is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Subscription check completed but returned null/undefined
|
||||
// This likely means no subscription - redirect to pricing
|
||||
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
|
||||
// 3. Check subscription status first
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
// 1. No active subscription? → Must subscribe first (even if onboarding is complete)
|
||||
// No active subscription → Must subscribe first
|
||||
if (isNewUser || !subscription.active) {
|
||||
console.log('InitialRouteHandler: No active subscription → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
|
||||
// 2. Has active subscription, check onboarding status
|
||||
// 4. Has active subscription, check onboarding status
|
||||
if (!isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding');
|
||||
return <Navigate to="/onboarding" replace />;
|
||||
}
|
||||
|
||||
// 3. Has subscription AND completed onboarding → Dashboard
|
||||
// 5. Has subscription AND completed onboarding → Dashboard
|
||||
console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
};
|
||||
|
||||
@@ -7,6 +7,24 @@ export const setGlobalSubscriptionErrorHandler = (handler: (error: any) => boole
|
||||
globalSubscriptionErrorHandler = handler;
|
||||
};
|
||||
|
||||
// Export a function to trigger subscription error handler from outside axios interceptors
|
||||
export const triggerSubscriptionError = (error: any) => {
|
||||
const status = error?.response?.status;
|
||||
console.log('triggerSubscriptionError: Received error', {
|
||||
hasHandler: !!globalSubscriptionErrorHandler,
|
||||
status,
|
||||
dataKeys: error?.response?.data ? Object.keys(error.response.data) : null
|
||||
});
|
||||
|
||||
if (globalSubscriptionErrorHandler) {
|
||||
console.log('triggerSubscriptionError: Calling global subscription error handler');
|
||||
return globalSubscriptionErrorHandler(error);
|
||||
}
|
||||
|
||||
console.warn('triggerSubscriptionError: No global subscription error handler registered');
|
||||
return false;
|
||||
};
|
||||
|
||||
// Optional token getter installed from within the app after Clerk is available
|
||||
let authTokenGetter: (() => Promise<string | null>) | null = null;
|
||||
|
||||
@@ -64,13 +82,27 @@ apiClient.interceptors.request.use(
|
||||
async (config) => {
|
||||
console.log(`Making ${config.method?.toUpperCase()} request to ${config.url}`);
|
||||
try {
|
||||
const token = authTokenGetter ? await authTokenGetter() : null;
|
||||
if (!authTokenGetter) {
|
||||
console.warn(`[apiClient] ⚠️ authTokenGetter not set for ${config.url} - request may fail authentication`);
|
||||
console.warn(`[apiClient] This usually means TokenInstaller hasn't run yet. Request will likely fail with 401.`);
|
||||
} else {
|
||||
try {
|
||||
const token = await authTokenGetter();
|
||||
if (token) {
|
||||
config.headers = config.headers || {};
|
||||
(config.headers as any)['Authorization'] = `Bearer ${token}`;
|
||||
console.log(`[apiClient] ✅ Added auth token to request: ${config.url}`);
|
||||
} else {
|
||||
console.warn(`[apiClient] ⚠️ authTokenGetter returned null for ${config.url} - user may not be signed in`);
|
||||
console.warn(`[apiClient] User ID from localStorage: ${localStorage.getItem('user_id') || 'none'}`);
|
||||
}
|
||||
} catch (tokenError) {
|
||||
console.error(`[apiClient] ❌ Error getting auth token for ${config.url}:`, tokenError);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// non-fatal
|
||||
console.error(`[apiClient] ❌ Unexpected error in request interceptor for ${config.url}:`, e);
|
||||
// non-fatal - let the request proceed, backend will return 401 if needed
|
||||
}
|
||||
return config;
|
||||
},
|
||||
@@ -138,13 +170,17 @@ apiClient.interceptors.response.use(
|
||||
console.error('Token refresh failed:', retryError);
|
||||
}
|
||||
|
||||
// If retry failed and not in onboarding, redirect
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding') ||
|
||||
window.location.pathname === '/';
|
||||
if (!isOnboardingRoute) {
|
||||
// If retry failed, don't redirect during app initialization (root route)
|
||||
// Only redirect if we're on a protected route and definitely authenticated
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding');
|
||||
const isRootRoute = window.location.pathname === '/';
|
||||
|
||||
// Don't redirect from root route during app initialization - allow InitialRouteHandler to work
|
||||
if (!isRootRoute && !isOnboardingRoute) {
|
||||
// Only redirect if we're definitely not just initializing
|
||||
try { window.location.assign('/'); } catch {}
|
||||
} else {
|
||||
console.warn('401 Unauthorized - token refresh failed');
|
||||
console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,12 +240,14 @@ aiApiClient.interceptors.response.use(
|
||||
console.error('Token refresh failed:', retryError);
|
||||
}
|
||||
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding') ||
|
||||
window.location.pathname === '/';
|
||||
if (!isOnboardingRoute) {
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding');
|
||||
const isRootRoute = window.location.pathname === '/';
|
||||
|
||||
// Don't redirect from root route during app initialization
|
||||
if (!isRootRoute && !isOnboardingRoute) {
|
||||
try { window.location.assign('/'); } catch {}
|
||||
} else {
|
||||
console.warn('401 Unauthorized - token refresh failed');
|
||||
console.warn('401 Unauthorized - token refresh failed (during initialization, not redirecting)');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,13 +292,15 @@ longRunningApiClient.interceptors.response.use(
|
||||
},
|
||||
(error) => {
|
||||
if (error?.response?.status === 401) {
|
||||
// Only redirect on 401 if we're not in onboarding flow
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding') ||
|
||||
window.location.pathname === '/';
|
||||
if (!isOnboardingRoute) {
|
||||
// Only redirect on 401 if we're not in onboarding flow or root route
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding');
|
||||
const isRootRoute = window.location.pathname === '/';
|
||||
|
||||
// Don't redirect from root route during app initialization
|
||||
if (!isRootRoute && !isOnboardingRoute) {
|
||||
try { window.location.assign('/'); } catch {}
|
||||
} else {
|
||||
console.warn('401 Unauthorized during onboarding - token may need refresh');
|
||||
console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)');
|
||||
}
|
||||
}
|
||||
// Check if it's a subscription-related error and handle it globally
|
||||
@@ -304,13 +344,15 @@ pollingApiClient.interceptors.response.use(
|
||||
},
|
||||
(error) => {
|
||||
if (error?.response?.status === 401) {
|
||||
// Only redirect on 401 if we're not in onboarding flow
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding') ||
|
||||
window.location.pathname === '/';
|
||||
if (!isOnboardingRoute) {
|
||||
// Only redirect on 401 if we're not in onboarding flow or root route
|
||||
const isOnboardingRoute = window.location.pathname.includes('/onboarding');
|
||||
const isRootRoute = window.location.pathname === '/';
|
||||
|
||||
// Don't redirect from root route during app initialization
|
||||
if (!isRootRoute && !isOnboardingRoute) {
|
||||
try { window.location.assign('/'); } catch {}
|
||||
} else {
|
||||
console.warn('401 Unauthorized during onboarding - token may need refresh');
|
||||
console.warn('401 Unauthorized during initialization - token may need refresh (not redirecting)');
|
||||
}
|
||||
}
|
||||
// Check if it's a subscription-related error and handle it globally
|
||||
|
||||
@@ -66,6 +66,7 @@ export const BlogWriter: React.FC = () => {
|
||||
contentConfirmed,
|
||||
flowAnalysisCompleted,
|
||||
flowAnalysisResults,
|
||||
sectionImages,
|
||||
setOutline,
|
||||
setTitleOptions,
|
||||
setSelectedTitle,
|
||||
@@ -78,6 +79,7 @@ export const BlogWriter: React.FC = () => {
|
||||
setContentConfirmed,
|
||||
setFlowAnalysisCompleted,
|
||||
setFlowAnalysisResults,
|
||||
setSectionImages,
|
||||
handleResearchComplete,
|
||||
handleOutlineComplete,
|
||||
handleOutlineError,
|
||||
@@ -670,6 +672,8 @@ export const BlogWriter: React.FC = () => {
|
||||
flowAnalysisResults={flowAnalysisResults}
|
||||
outlineGenRef={outlineGenRef}
|
||||
blogWriterApi={blogWriterApi}
|
||||
sectionImages={sectionImages}
|
||||
setSectionImages={setSectionImages}
|
||||
contentConfirmed={contentConfirmed}
|
||||
seoAnalysis={seoAnalysis}
|
||||
seoMetadata={seoMetadata}
|
||||
|
||||
@@ -31,6 +31,8 @@ interface PhaseContentProps {
|
||||
seoMetadata: any;
|
||||
onTitleSelect: any;
|
||||
onCustomTitle: any;
|
||||
sectionImages?: Record<string, string>;
|
||||
setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void;
|
||||
}
|
||||
|
||||
export const PhaseContent: React.FC<PhaseContentProps> = ({
|
||||
@@ -58,7 +60,9 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
|
||||
seoAnalysis,
|
||||
seoMetadata,
|
||||
onTitleSelect,
|
||||
onCustomTitle
|
||||
onCustomTitle,
|
||||
sectionImages,
|
||||
setSectionImages
|
||||
}) => {
|
||||
return (
|
||||
<div style={{ display: 'flex', flex: 1, overflow: 'hidden' }}>
|
||||
@@ -100,6 +104,8 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
|
||||
optimizationResults={optimizationResults}
|
||||
researchCoverage={researchCoverage}
|
||||
onRefine={(op: any, id: any, payload: any) => blogWriterApi.refineOutline({ outline, operation: op, section_id: id, payload }).then((res: any) => setOutline(res.outline))}
|
||||
sectionImages={sectionImages}
|
||||
setSectionImages={setSectionImages}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
@@ -126,6 +132,7 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
|
||||
onSave={handleContentSave}
|
||||
continuityRefresh={continuityRefresh || undefined}
|
||||
flowAnalysisResults={flowAnalysisResults}
|
||||
sectionImages={sectionImages}
|
||||
/>
|
||||
) : (
|
||||
<div style={{ padding: '20px', textAlign: 'center' }}>
|
||||
@@ -151,6 +158,7 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
|
||||
onSave={handleContentSave}
|
||||
continuityRefresh={continuityRefresh || undefined}
|
||||
flowAnalysisResults={flowAnalysisResults}
|
||||
sectionImages={sectionImages}
|
||||
/>
|
||||
) : (
|
||||
<div style={{ padding: '20px', textAlign: 'center' }}>
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import {
|
||||
Dialog,
|
||||
DialogTitle,
|
||||
DialogContent,
|
||||
DialogActions,
|
||||
Button,
|
||||
Typography,
|
||||
Box,
|
||||
CircularProgress,
|
||||
Alert
|
||||
} from '@mui/material';
|
||||
import { usePlatformConnections } from '../../../components/OnboardingWizard/common/usePlatformConnections';
|
||||
|
||||
interface WixConnectModalProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onConnectionSuccess?: () => void;
|
||||
}
|
||||
|
||||
export const WixConnectModal: React.FC<WixConnectModalProps> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConnectionSuccess
|
||||
}) => {
|
||||
const { handleConnect, isLoading } = usePlatformConnections();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isConnecting, setIsConnecting] = useState(false);
|
||||
|
||||
// Handle OAuth success via postMessage (same pattern as onboarding)
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
|
||||
const handler = (event: MessageEvent) => {
|
||||
const trusted = [window.location.origin, 'https://littery-sonny-unscrutinisingly.ngrok-free.dev'];
|
||||
if (!trusted.includes(event.origin)) return;
|
||||
if (!event.data || typeof event.data !== 'object') return;
|
||||
|
||||
if (event.data.type === 'WIX_OAUTH_SUCCESS') {
|
||||
console.log('Wix OAuth success in modal');
|
||||
setIsConnecting(false);
|
||||
setError(null);
|
||||
// Close modal and notify parent
|
||||
if (onConnectionSuccess) {
|
||||
onConnectionSuccess();
|
||||
}
|
||||
onClose();
|
||||
}
|
||||
|
||||
if (event.data.type === 'WIX_OAUTH_ERROR') {
|
||||
console.error('Wix OAuth error in modal:', event.data.error);
|
||||
setIsConnecting(false);
|
||||
setError(event.data.error || 'Wix connection failed. Please try again.');
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('message', handler);
|
||||
return () => window.removeEventListener('message', handler);
|
||||
}, [isOpen, onClose, onConnectionSuccess]);
|
||||
|
||||
// Also check for URL param (fallback for same-tab redirect)
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
if (params.get('wix_connected') === 'true') {
|
||||
console.log('Wix connected via URL param in modal');
|
||||
setIsConnecting(false);
|
||||
setError(null);
|
||||
if (onConnectionSuccess) {
|
||||
onConnectionSuccess();
|
||||
}
|
||||
onClose();
|
||||
// Clean URL
|
||||
const clean = window.location.pathname + window.location.hash;
|
||||
window.history.replaceState({}, document.title, clean || '/');
|
||||
}
|
||||
}, [isOpen, onClose, onConnectionSuccess]);
|
||||
|
||||
const handleConnectClick = async () => {
|
||||
try {
|
||||
setIsConnecting(true);
|
||||
setError(null);
|
||||
await handleConnect('wix');
|
||||
// OAuth will redirect, so we don't need to do anything else here
|
||||
// The postMessage handler or URL param handler will close the modal
|
||||
} catch (err: any) {
|
||||
console.error('Error connecting to Wix:', err);
|
||||
setIsConnecting(false);
|
||||
setError(err?.message || 'Failed to start Wix connection. Please try again.');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={isOpen}
|
||||
onClose={onClose}
|
||||
maxWidth="sm"
|
||||
fullWidth
|
||||
PaperProps={{
|
||||
sx: {
|
||||
borderRadius: 2,
|
||||
boxShadow: '0 4px 20px rgba(0,0,0,0.15)'
|
||||
}
|
||||
}}
|
||||
>
|
||||
<DialogTitle sx={{ pb: 1 }}>
|
||||
<Typography variant="h6" sx={{ fontWeight: 600, color: '#1e293b' }}>
|
||||
Connect Your Wix Account
|
||||
</Typography>
|
||||
</DialogTitle>
|
||||
|
||||
<DialogContent>
|
||||
<Box sx={{ py: 1 }}>
|
||||
<Typography variant="body2" color="text.secondary" paragraph>
|
||||
Connect your Wix account to publish blog posts directly to your website.
|
||||
</Typography>
|
||||
|
||||
{error && (
|
||||
<Alert severity="error" sx={{ mb: 2 }}>
|
||||
{error}
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{isConnecting && (
|
||||
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2, py: 2 }}>
|
||||
<CircularProgress size={20} />
|
||||
<Typography variant="body2" color="text.secondary">
|
||||
Opening Wix authorization page...
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
<Box sx={{ mt: 2, p: 2, bgcolor: '#f8fafc', borderRadius: 1 }}>
|
||||
<Typography variant="caption" color="text.secondary">
|
||||
<strong>What happens next:</strong>
|
||||
</Typography>
|
||||
<Typography variant="caption" component="div" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
|
||||
<ol style={{ margin: '8px 0 0 20px', padding: 0 }}>
|
||||
<li>You'll be redirected to Wix to authorize ALwrity</li>
|
||||
<li>Grant permissions for blog creation and publishing</li>
|
||||
<li>You'll be redirected back to ALwrity</li>
|
||||
<li>Your blog post will be published automatically</li>
|
||||
</ol>
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
</DialogContent>
|
||||
|
||||
<DialogActions sx={{ px: 3, pb: 2 }}>
|
||||
<Button onClick={onClose} disabled={isConnecting}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="contained"
|
||||
onClick={handleConnectClick}
|
||||
disabled={isConnecting || isLoading}
|
||||
startIcon={isConnecting ? <CircularProgress size={16} /> : undefined}
|
||||
>
|
||||
{isConnecting ? 'Connecting...' : 'Connect to Wix'}
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default WixConnectModal;
|
||||
|
||||
@@ -12,6 +12,8 @@ interface Props {
|
||||
groundingInsights?: GroundingInsights | null;
|
||||
optimizationResults?: OptimizationResults | null;
|
||||
researchCoverage?: ResearchCoverage | null;
|
||||
sectionImages?: Record<string, string>;
|
||||
setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void;
|
||||
}
|
||||
|
||||
const EnhancedOutlineEditor: React.FC<Props> = ({
|
||||
@@ -21,14 +23,15 @@ const EnhancedOutlineEditor: React.FC<Props> = ({
|
||||
sourceMappingStats,
|
||||
groundingInsights,
|
||||
optimizationResults,
|
||||
researchCoverage
|
||||
researchCoverage,
|
||||
sectionImages = {},
|
||||
setSectionImages
|
||||
}) => {
|
||||
const [editingSection, setEditingSection] = useState<string | null>(null);
|
||||
const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set());
|
||||
const [hoveredSection, setHoveredSection] = useState<string | null>(null);
|
||||
const [showAddSection, setShowAddSection] = useState(false);
|
||||
const [imageModalState, setImageModalState] = useState<{ open: boolean; sectionId?: string }>(() => ({ open: false }));
|
||||
const [sectionImages, setSectionImages] = useState<Record<string, string>>({});
|
||||
const [newSectionData, setNewSectionData] = useState({
|
||||
heading: '',
|
||||
subheadings: '',
|
||||
@@ -117,8 +120,8 @@ const EnhancedOutlineEditor: React.FC<Props> = ({
|
||||
};
|
||||
})()}
|
||||
onImageGenerated={(imageBase64, sectionId) => {
|
||||
if (sectionId) {
|
||||
setSectionImages(prev => ({ ...prev, [sectionId]: imageBase64 }));
|
||||
if (sectionId && setSectionImages) {
|
||||
setSectionImages((prev: Record<string, string>) => ({ ...prev, [sectionId]: imageBase64 }));
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { useCopilotAction } from '@copilotkit/react-core';
|
||||
import { blogWriterApi, BlogSEOMetadataResponse } from '../../services/blogWriterApi';
|
||||
import { BlogSEOMetadataResponse } from '../../services/blogWriterApi';
|
||||
import { apiClient } from '../../api/client';
|
||||
import { wordpressAPI, WordPressSite, WordPressPublishRequest } from '../../api/wordpress';
|
||||
import { validateAndRefreshWixTokens } from '../../utils/wixTokenUtils';
|
||||
import WixConnectModal from './BlogWriterUtils/WixConnectModal';
|
||||
|
||||
interface PublisherProps {
|
||||
buildFullMarkdown: () => string;
|
||||
@@ -26,10 +29,15 @@ export const Publisher: React.FC<PublisherProps> = ({
|
||||
}) => {
|
||||
const [wixConnectionStatus, setWixConnectionStatus] = useState<WixConnectionStatus | null>(null);
|
||||
const [checkingWixStatus, setCheckingWixStatus] = useState(false);
|
||||
const [wordpressSites, setWordpressSites] = useState<WordPressSite[]>([]);
|
||||
const [checkingWordPressStatus, setCheckingWordPressStatus] = useState(false);
|
||||
const [showWixConnectModal, setShowWixConnectModal] = useState(false);
|
||||
const [pendingWixPublish, setPendingWixPublish] = useState<(() => Promise<any>) | null>(null);
|
||||
|
||||
// Check Wix connection status on component mount
|
||||
// Check platform connection statuses on component mount
|
||||
useEffect(() => {
|
||||
checkWixConnectionStatus();
|
||||
checkWordPressConnectionStatus();
|
||||
}, []);
|
||||
|
||||
const checkWixConnectionStatus = async () => {
|
||||
@@ -48,6 +56,137 @@ export const Publisher: React.FC<PublisherProps> = ({
|
||||
setCheckingWixStatus(false);
|
||||
}
|
||||
};
|
||||
|
||||
const checkWordPressConnectionStatus = async () => {
|
||||
setCheckingWordPressStatus(true);
|
||||
try {
|
||||
const status = await wordpressAPI.getStatus();
|
||||
setWordpressSites(status.sites || []);
|
||||
} catch (error) {
|
||||
console.error('Failed to check WordPress connection status:', error);
|
||||
setWordpressSites([]);
|
||||
} finally {
|
||||
setCheckingWordPressStatus(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to publish to Wix
|
||||
const publishToWix = async (md: string, metadata: BlogSEOMetadataResponse | null, accessToken?: string): Promise<any> => {
|
||||
// Get access token if not provided
|
||||
if (!accessToken) {
|
||||
const tokenResult = await validateAndRefreshWixTokens();
|
||||
if (!tokenResult.accessToken) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Wix tokens not available. Please connect your Wix account.',
|
||||
action_required: 'connect_wix'
|
||||
};
|
||||
}
|
||||
accessToken = tokenResult.accessToken;
|
||||
}
|
||||
|
||||
// Extract title from SEO metadata or markdown
|
||||
const title = metadata?.seo_title || (() => {
|
||||
const titleMatch = md.match(/^#\s+(.+)$/m);
|
||||
return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity';
|
||||
})();
|
||||
|
||||
// Extract cover image URL, skip if base64 (Wix needs HTTP URL)
|
||||
let coverImageUrl: string | undefined = undefined;
|
||||
if (metadata?.open_graph?.image) {
|
||||
const imageUrl = metadata.open_graph.image;
|
||||
// Skip base64 images - Wix import_image needs HTTP/HTTPS URL
|
||||
if (typeof imageUrl === 'string' && (imageUrl.startsWith('http://') || imageUrl.startsWith('https://'))) {
|
||||
coverImageUrl = imageUrl;
|
||||
} else {
|
||||
console.warn('Skipping cover image - Wix requires HTTP/HTTPS URL, received:', imageUrl?.substring(0, 50));
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// Publish using same endpoint as WixTestPage
|
||||
// Note: Wix requires category/tag IDs (UUIDs), not names
|
||||
// For now, skip categories/tags until we implement ID lookup/creation
|
||||
const response = await apiClient.post('/api/wix/test/publish/real', {
|
||||
title: title,
|
||||
content: md, // Use markdown, backend converts it
|
||||
cover_image_url: coverImageUrl,
|
||||
// TODO: Lookup/create category IDs from metadata?.blog_categories
|
||||
// TODO: Lookup/create tag IDs from metadata?.blog_tags
|
||||
category_ids: undefined,
|
||||
tag_ids: undefined,
|
||||
publish: true,
|
||||
access_token: accessToken,
|
||||
member_id: undefined // Let backend derive from token
|
||||
});
|
||||
|
||||
if (response.data.success) {
|
||||
return {
|
||||
success: true,
|
||||
url: response.data.url,
|
||||
post_id: response.data.post_id,
|
||||
message: 'Blog post published successfully to Wix!'
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
message: response.data.error || 'Failed to publish to Wix'
|
||||
};
|
||||
}
|
||||
} catch (error: any) {
|
||||
// If auth error, token may be invalid - try refreshing or reconnect
|
||||
if (error.response?.status === 401 || error.response?.status === 403) {
|
||||
// Try to refresh one more time
|
||||
const tokenResult = await validateAndRefreshWixTokens();
|
||||
if (tokenResult.needsReconnect) {
|
||||
const publishFunction = async () => {
|
||||
return await publishToWix(md, metadata);
|
||||
};
|
||||
setPendingWixPublish(() => publishFunction);
|
||||
setShowWixConnectModal(true);
|
||||
return {
|
||||
success: false,
|
||||
message: 'Wix tokens expired. Please reconnect your Wix account.',
|
||||
action_required: 'reconnect_wix'
|
||||
};
|
||||
}
|
||||
// If refresh worked, retry once
|
||||
if (tokenResult.accessToken) {
|
||||
return await publishToWix(md, metadata, tokenResult.accessToken);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
success: false,
|
||||
message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}`
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Handle Wix connection success - retry publish
|
||||
const handleWixConnectionSuccess = async () => {
|
||||
if (pendingWixPublish) {
|
||||
const publishFn = pendingWixPublish;
|
||||
setPendingWixPublish(null);
|
||||
// Small delay to ensure tokens are saved in sessionStorage
|
||||
setTimeout(async () => {
|
||||
try {
|
||||
// Retry the publish - this will be executed and return result
|
||||
// Note: The result won't show in CopilotKit UI since we're outside the action handler
|
||||
// But the publish will succeed and user will see their blog on Wix
|
||||
const result = await publishFn();
|
||||
console.log('Wix publish after connection:', result);
|
||||
// Optionally show a success notification
|
||||
if (result.success) {
|
||||
// Publish succeeded - user's blog is now on Wix
|
||||
console.log('Blog published to Wix successfully after connection');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error retrying publish after connection:', error);
|
||||
}
|
||||
}, 500);
|
||||
}
|
||||
};
|
||||
// Enhanced publish action with Wix support
|
||||
useCopilotActionTyped({
|
||||
name: 'publishToPlatform',
|
||||
@@ -61,58 +200,101 @@ export const Publisher: React.FC<PublisherProps> = ({
|
||||
const html = convertMarkdownToHTML(md);
|
||||
|
||||
if (platform === 'wix') {
|
||||
// Check Wix connection status first
|
||||
if (!wixConnectionStatus?.connected) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Wix account not connected. Please connect your Wix account first using the Wix Test Page.',
|
||||
// Proactively validate and refresh tokens
|
||||
const tokenResult = await validateAndRefreshWixTokens();
|
||||
|
||||
if (tokenResult.needsReconnect || !tokenResult.accessToken) {
|
||||
// Store the publish function to retry after connection
|
||||
const publishFunction = async () => {
|
||||
return await publishToWix(md, seoMetadata);
|
||||
};
|
||||
setPendingWixPublish(() => publishFunction);
|
||||
setShowWixConnectModal(true);
|
||||
return {
|
||||
success: false,
|
||||
message: 'Wix account not connected. Please connect your Wix account to publish.',
|
||||
action_required: 'connect_wix'
|
||||
};
|
||||
}
|
||||
|
||||
if (!wixConnectionStatus?.has_permissions) {
|
||||
|
||||
// We have a valid access token, proceed with publishing
|
||||
return await publishToWix(md, seoMetadata, tokenResult.accessToken);
|
||||
} else if (platform === 'wordpress') {
|
||||
// WordPress publishing
|
||||
if (!seoMetadata) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Insufficient Wix permissions. Please reconnect your Wix account.',
|
||||
action_required: 'reconnect_wix'
|
||||
message: 'Generate SEO metadata first. Use the "Next: Generate SEO Metadata" suggestion to create metadata before publishing.'
|
||||
};
|
||||
}
|
||||
|
||||
// Extract title from markdown (first heading or use default)
|
||||
const titleMatch = md.match(/^#\s+(.+)$/m);
|
||||
const title = titleMatch ? titleMatch[1] : 'Blog Post from ALwrity';
|
||||
|
||||
|
||||
// Check if user has connected WordPress sites
|
||||
if (wordpressSites.length === 0) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'No WordPress sites connected. Please connect a WordPress site first. Go to Settings > Integrations to add your WordPress site.',
|
||||
action_required: 'connect_wordpress'
|
||||
};
|
||||
}
|
||||
|
||||
// Find first active site, or use first site if none are active
|
||||
const activeSite = wordpressSites.find(site => site.is_active) || wordpressSites[0];
|
||||
if (!activeSite) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'No active WordPress sites found. Please activate a WordPress site connection.',
|
||||
action_required: 'activate_wordpress'
|
||||
};
|
||||
}
|
||||
|
||||
// Extract title from SEO metadata or markdown
|
||||
const title = seoMetadata.seo_title || (() => {
|
||||
const titleMatch = md.match(/^#\s+(.+)$/m);
|
||||
return titleMatch ? titleMatch[1] : 'Blog Post from ALwrity';
|
||||
})();
|
||||
|
||||
// Extract excerpt from SEO metadata
|
||||
const excerpt = seoMetadata.meta_description || '';
|
||||
|
||||
// Build WordPress publish request
|
||||
const publishRequest: WordPressPublishRequest = {
|
||||
site_id: activeSite.id,
|
||||
title: title,
|
||||
content: html,
|
||||
excerpt: excerpt,
|
||||
status: 'publish',
|
||||
meta_description: seoMetadata.meta_description || excerpt,
|
||||
tags: seoMetadata.blog_tags || [],
|
||||
categories: seoMetadata.blog_categories || []
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await apiClient.post('/api/wix/publish', {
|
||||
title: title,
|
||||
content: md,
|
||||
publish: true
|
||||
});
|
||||
const result = await wordpressAPI.publishContent(publishRequest);
|
||||
|
||||
if (response.data.success) {
|
||||
return {
|
||||
success: true,
|
||||
url: response.data.url,
|
||||
post_id: response.data.post_id,
|
||||
message: 'Blog post published successfully to Wix!'
|
||||
if (result.success) {
|
||||
return {
|
||||
success: true,
|
||||
url: result.post_url || `${activeSite.site_url}/?p=${result.post_id}`,
|
||||
post_id: result.post_id,
|
||||
message: `Blog post published successfully to WordPress site "${activeSite.site_name}"!`
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
message: response.data.error || 'Failed to publish to Wix'
|
||||
return {
|
||||
success: false,
|
||||
message: result.error || 'Failed to publish to WordPress'
|
||||
};
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
success: false,
|
||||
message: `Failed to publish to Wix: ${error.response?.data?.detail || error.message}`
|
||||
return {
|
||||
success: false,
|
||||
message: `Failed to publish to WordPress: ${error.response?.data?.detail || error.message || 'Unknown error'}`
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// WordPress or other platforms
|
||||
if (!seoMetadata) return { success: false, message: 'Generate SEO metadata first' };
|
||||
const res = await blogWriterApi.publish({ platform, html, metadata: seoMetadata, schedule_time });
|
||||
return { success: true, url: res.url };
|
||||
return {
|
||||
success: false,
|
||||
message: `Unsupported platform: ${platform}. Supported platforms are 'wix' and 'wordpress'.`
|
||||
};
|
||||
}
|
||||
},
|
||||
render: ({ status, result }: any) => {
|
||||
@@ -153,6 +335,13 @@ export const Publisher: React.FC<PublisherProps> = ({
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
{(result?.action_required === 'connect_wordpress' || result?.action_required === 'activate_wordpress') && (
|
||||
<div style={{ marginTop: 8 }}>
|
||||
<a href="/settings/integrations" target="_blank" rel="noopener noreferrer">
|
||||
Manage WordPress Connections
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -161,7 +350,18 @@ export const Publisher: React.FC<PublisherProps> = ({
|
||||
}
|
||||
});
|
||||
|
||||
return null; // This component only provides the copilot action
|
||||
return (
|
||||
<>
|
||||
<WixConnectModal
|
||||
isOpen={showWixConnectModal}
|
||||
onClose={() => {
|
||||
setShowWixConnectModal(false);
|
||||
setPendingWixPublish(null);
|
||||
}}
|
||||
onConnectionSuccess={handleWixConnectionSuccess}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default Publisher;
|
||||
|
||||
@@ -145,11 +145,7 @@ export const useSuggestions = ({
|
||||
priority: 'high'
|
||||
});
|
||||
items.push({
|
||||
title: 'Content Analysis',
|
||||
message: 'Analyze the flow and quality of my blog content to get improvement suggestions'
|
||||
});
|
||||
items.push({
|
||||
title: 'Content Analysis',
|
||||
title: '📊 Content Analysis',
|
||||
message: 'Analyze the flow and quality of my blog content to get improvement suggestions'
|
||||
});
|
||||
} else if (seoAnalysis && !seoRecommendationsApplied) {
|
||||
@@ -160,7 +156,7 @@ export const useSuggestions = ({
|
||||
priority: 'high'
|
||||
});
|
||||
items.push({
|
||||
title: 'Content Analysis',
|
||||
title: '📊 Content Analysis',
|
||||
message: 'Run analyzeContentQuality to review narrative flow and get final improvement suggestions before publishing.'
|
||||
});
|
||||
items.push({
|
||||
@@ -175,33 +171,21 @@ export const useSuggestions = ({
|
||||
message: 'SEO recommendations are applied. Execute generateSEOMetadata immediately so we can prepare titles, descriptions, and schema without further prompts.',
|
||||
priority: 'high'
|
||||
});
|
||||
} else {
|
||||
items.push({
|
||||
title: 'Next: Publish',
|
||||
message: 'The blog is SEO-optimized. Use publishToPlatform with your preferred destination (wix|wordpress) right away—no additional confirmation needed.',
|
||||
priority: 'high'
|
||||
title: '📊 Content Analysis',
|
||||
message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.'
|
||||
});
|
||||
}
|
||||
|
||||
items.push({
|
||||
title: 'Content Analysis',
|
||||
message: 'Run analyzeContentQuality to validate flow, consistency, and progression before publishing.'
|
||||
});
|
||||
items.push({
|
||||
title: 'Publish',
|
||||
message: seoMetadata
|
||||
? 'Publish my blog to your preferred platform using publishToPlatform.'
|
||||
: 'Generate SEO metadata first, then publish your blog.'
|
||||
});
|
||||
|
||||
if (seoMetadata) {
|
||||
} else {
|
||||
// SEO metadata is ready - show publishing options
|
||||
items.push({
|
||||
title: '🚀 Publish to Wix',
|
||||
message: 'Publish my blog to Wix using publishToPlatform with platform "wix".'
|
||||
message: 'Publish my blog to Wix using publishToPlatform with platform "wix".',
|
||||
priority: 'high'
|
||||
});
|
||||
items.push({
|
||||
title: '🌐 Publish to WordPress',
|
||||
message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".'
|
||||
message: 'Publish my blog to WordPress using publishToPlatform with platform "wordpress".',
|
||||
priority: 'high'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ interface BlogEditorProps {
|
||||
onSave?: (content: any) => void;
|
||||
continuityRefresh?: number;
|
||||
flowAnalysisResults?: any;
|
||||
sectionImages?: Record<string, string>;
|
||||
}
|
||||
|
||||
const BlogEditor: React.FC<BlogEditorProps> = ({
|
||||
@@ -43,7 +44,8 @@ const BlogEditor: React.FC<BlogEditorProps> = ({
|
||||
onContentUpdate,
|
||||
onSave,
|
||||
continuityRefresh,
|
||||
flowAnalysisResults
|
||||
flowAnalysisResults,
|
||||
sectionImages = {}
|
||||
}) => {
|
||||
const [blogTitle, setBlogTitle] = useState(initialTitle || 'Your Amazing Blog Title');
|
||||
const [sections, setSections] = useState<any[]>([]);
|
||||
@@ -143,17 +145,25 @@ const BlogEditor: React.FC<BlogEditorProps> = ({
|
||||
<Divider sx={{ mt: 3, opacity: 0.3 }} />
|
||||
</div>
|
||||
<div>
|
||||
{sections.map((section) => (
|
||||
<BlogSection
|
||||
key={section.id}
|
||||
{...section}
|
||||
onContentUpdate={onContentUpdate}
|
||||
expandedSections={expandedSections}
|
||||
toggleSectionExpansion={toggleSectionExpansion}
|
||||
refreshToken={continuityRefresh}
|
||||
flowAnalysisResults={flowAnalysisResults}
|
||||
/>
|
||||
))}
|
||||
{sections.map((section, index) => {
|
||||
// Robust image mapping: prefer outline index id (order is consistent across phases)
|
||||
const imageIdByIndex = outline[index]?.id;
|
||||
const outlineSection = outline.find(s => (s.id === section.id) || (s.heading === section.title));
|
||||
const imageId = imageIdByIndex || outlineSection?.id || section.id;
|
||||
const sectionImage = sectionImages?.[imageId] || null;
|
||||
return (
|
||||
<BlogSection
|
||||
key={section.id}
|
||||
{...section}
|
||||
onContentUpdate={onContentUpdate}
|
||||
expandedSections={expandedSections}
|
||||
toggleSectionExpansion={toggleSectionExpansion}
|
||||
refreshToken={continuityRefresh}
|
||||
flowAnalysisResults={flowAnalysisResults}
|
||||
sectionImage={sectionImage}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Paper>
|
||||
</div>
|
||||
|
||||
@@ -40,6 +40,7 @@ interface BlogSectionProps {
|
||||
toggleSectionExpansion: (sectionId: any) => void;
|
||||
refreshToken?: number;
|
||||
flowAnalysisResults?: any;
|
||||
sectionImage?: string;
|
||||
}
|
||||
|
||||
const BlogSection: React.FC<BlogSectionProps> = ({
|
||||
@@ -53,7 +54,8 @@ const BlogSection: React.FC<BlogSectionProps> = ({
|
||||
expandedSections,
|
||||
toggleSectionExpansion,
|
||||
refreshToken,
|
||||
flowAnalysisResults
|
||||
flowAnalysisResults,
|
||||
sectionImage
|
||||
}) => {
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [sectionTitle, setSectionTitle] = useState(title);
|
||||
@@ -181,6 +183,31 @@ const BlogSection: React.FC<BlogSectionProps> = ({
|
||||
)}
|
||||
|
||||
</div>
|
||||
|
||||
{/* Section Image Display */}
|
||||
{sectionImage && (
|
||||
<div style={{ marginBottom: '16px', marginTop: '8px' }}>
|
||||
<div style={{
|
||||
border: '1px solid #e0e0e0',
|
||||
borderRadius: '8px',
|
||||
overflow: 'hidden',
|
||||
maxWidth: '100%',
|
||||
backgroundColor: '#fff'
|
||||
}}>
|
||||
<img
|
||||
src={`data:image/png;base64,${sectionImage}`}
|
||||
alt={`Cover image for ${sectionTitle}`}
|
||||
style={{
|
||||
width: '100%',
|
||||
height: 'auto',
|
||||
display: 'block',
|
||||
maxHeight: '400px',
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className="relative"
|
||||
|
||||
@@ -119,25 +119,44 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
|
||||
const fetchDetailedStats = async () => {
|
||||
try {
|
||||
const response = await apiClient.get('/api/content-planning/monitoring/api-stats');
|
||||
const result = response.data;
|
||||
if (result.status === 'success') {
|
||||
setDetailedStats(result.data);
|
||||
if (result.data?.cache_performance) {
|
||||
setCachePerf(result.data.cache_performance);
|
||||
const result = response?.data;
|
||||
|
||||
// Validate response structure
|
||||
if (!result || result.status !== 'success' || !result.data) {
|
||||
console.warn('Invalid response structure from api-stats endpoint:', result);
|
||||
setChartData([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = result.data;
|
||||
setDetailedStats(data);
|
||||
|
||||
if (data?.cache_performance) {
|
||||
setCachePerf(data.cache_performance);
|
||||
}
|
||||
|
||||
// Generate chart data
|
||||
const chartData = result.data.top_endpoints.slice(0, 5).map((endpoint: any, index: number) => ({
|
||||
name: endpoint.endpoint.split(' ')[1].split('/').pop() || 'API',
|
||||
requests: endpoint.count,
|
||||
avgTime: endpoint.avg_time,
|
||||
errors: endpoint.errors,
|
||||
hitRate: endpoint.cache_hit_rate
|
||||
// Generate chart data - safely handle missing top_endpoints
|
||||
if (data?.top_endpoints && Array.isArray(data.top_endpoints) && data.top_endpoints.length > 0) {
|
||||
try {
|
||||
const chartData = data.top_endpoints.slice(0, 5).map((endpoint: any) => ({
|
||||
name: endpoint?.endpoint?.split(' ')[1]?.split('/').pop() || 'API',
|
||||
requests: endpoint?.count || 0,
|
||||
avgTime: endpoint?.avg_time || 0,
|
||||
errors: endpoint?.errors || 0,
|
||||
hitRate: endpoint?.cache_hit_rate || 0
|
||||
}));
|
||||
setChartData(chartData);
|
||||
} catch (mapError) {
|
||||
console.error('Error mapping chart data:', mapError);
|
||||
setChartData([]);
|
||||
}
|
||||
} else {
|
||||
// If top_endpoints is missing or not an array, set empty chart data
|
||||
setChartData([]);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Error fetching detailed stats:', err);
|
||||
setChartData([]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -353,7 +372,7 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
|
||||
)}
|
||||
|
||||
{/* Recent Errors Section */}
|
||||
{detailedStats?.recent_errors && detailedStats.recent_errors.length > 0 && (
|
||||
{detailedStats?.recent_errors && Array.isArray(detailedStats.recent_errors) && detailedStats.recent_errors.length > 0 && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
@@ -395,6 +414,8 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
|
||||
>
|
||||
Close
|
||||
</Button>
|
||||
<Tooltip title={loading ? "Refreshing data..." : "Refresh monitoring data"}>
|
||||
<span>
|
||||
<Button
|
||||
onClick={fetchDetailedStats}
|
||||
variant="contained"
|
||||
@@ -403,6 +424,8 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
|
||||
>
|
||||
Refresh Data
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
</>
|
||||
|
||||
@@ -56,18 +56,10 @@ export interface PromptSuggestion {
|
||||
}
|
||||
|
||||
export async function fetchPromptSuggestions(payload: any): Promise<PromptSuggestion[]> {
|
||||
const res = await fetch('/api/images/suggest-prompts', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(text || 'Failed to fetch prompt suggestions');
|
||||
}
|
||||
const data = await res.json();
|
||||
return data.suggestions || [];
|
||||
// Use apiClient directly (same pattern as SEO analysis in SEOAnalysisModal.tsx)
|
||||
// The apiClient interceptor will handle auth token injection automatically
|
||||
const response = await apiClient.post('/api/images/suggest-prompts', payload);
|
||||
return response.data.suggestions || [];
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
Modal,
|
||||
Fade,
|
||||
Backdrop,
|
||||
Snackbar,
|
||||
} from '@mui/material';
|
||||
import {
|
||||
Check as CheckIcon,
|
||||
@@ -35,6 +36,7 @@ import {
|
||||
Star as StarIcon,
|
||||
WorkspacePremium as PremiumIcon,
|
||||
Info as InfoIcon,
|
||||
Warning,
|
||||
Psychology,
|
||||
Search,
|
||||
FactCheck,
|
||||
@@ -83,6 +85,7 @@ const PricingPage: React.FC = () => {
|
||||
const [subscribing, setSubscribing] = useState(false);
|
||||
const [paymentModalOpen, setPaymentModalOpen] = useState(false);
|
||||
const [showSignInPrompt, setShowSignInPrompt] = useState(false);
|
||||
const [successSnackbar, setSuccessSnackbar] = useState({ open: false, message: '', countdown: 3 });
|
||||
const [knowMoreModal, setKnowMoreModal] = useState<{ open: boolean; title: string; content: React.ReactNode }>({
|
||||
open: false,
|
||||
title: '',
|
||||
@@ -172,27 +175,70 @@ const PricingPage: React.FC = () => {
|
||||
setSubscribing(true);
|
||||
const userId = localStorage.getItem('user_id') || 'anonymous';
|
||||
|
||||
await apiClient.post(`/api/subscription/subscribe/${userId}`, {
|
||||
const response = await apiClient.post(`/api/subscription/subscribe/${userId}`, {
|
||||
plan_id: selectedPlan,
|
||||
billing_cycle: yearlyBilling ? 'yearly' : 'monthly'
|
||||
});
|
||||
|
||||
// Refresh subscription status
|
||||
console.log('Subscription renewed successfully:', response.data);
|
||||
|
||||
// Refresh subscription status immediately
|
||||
window.dispatchEvent(new CustomEvent('subscription-updated'));
|
||||
|
||||
// Also trigger user authenticated event to refresh subscription context
|
||||
window.dispatchEvent(new CustomEvent('user-authenticated'));
|
||||
|
||||
setPaymentModalOpen(false);
|
||||
|
||||
// After subscription, check if onboarding is complete
|
||||
// If not complete, redirect to onboarding; otherwise to dashboard
|
||||
const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true';
|
||||
if (onboardingComplete) {
|
||||
navigate('/dashboard');
|
||||
} else {
|
||||
navigate('/onboarding');
|
||||
}
|
||||
// Get plan name for success message
|
||||
const planName = plans.find(p => p.id === selectedPlan)?.name || 'subscription';
|
||||
|
||||
// Show success message with countdown
|
||||
setSuccessSnackbar({
|
||||
open: true,
|
||||
message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in 3 seconds...`,
|
||||
countdown: 3
|
||||
});
|
||||
|
||||
// Countdown timer
|
||||
let countdown = 3;
|
||||
const countdownInterval = setInterval(() => {
|
||||
countdown -= 1;
|
||||
if (countdown > 0) {
|
||||
setSuccessSnackbar(prev => ({
|
||||
...prev,
|
||||
message: `🎉 ${planName} plan activated! Your usage limits have been reset. Returning to your work in ${countdown} second${countdown !== 1 ? 's' : ''}...`,
|
||||
countdown
|
||||
}));
|
||||
} else {
|
||||
clearInterval(countdownInterval);
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
// Auto-redirect after 3 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(countdownInterval);
|
||||
|
||||
// After subscription, check if onboarding is complete
|
||||
// If not complete, redirect to onboarding; otherwise to dashboard
|
||||
const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true';
|
||||
if (onboardingComplete) {
|
||||
// Try to go back to where the user was (e.g., blog writer)
|
||||
// If no history, go to dashboard
|
||||
const referrer = sessionStorage.getItem('subscription_referrer');
|
||||
if (referrer && referrer !== '/pricing') {
|
||||
navigate(referrer);
|
||||
} else {
|
||||
navigate('/dashboard');
|
||||
}
|
||||
} else {
|
||||
navigate('/onboarding');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (err) {
|
||||
console.error('Error subscribing:', err);
|
||||
setError('Failed to process subscription');
|
||||
setSuccessSnackbar({ open: false, message: '', countdown: 0 });
|
||||
} finally {
|
||||
setSubscribing(false);
|
||||
}
|
||||
@@ -900,32 +946,71 @@ const PricingPage: React.FC = () => {
|
||||
top: '50%',
|
||||
left: '50%',
|
||||
transform: 'translate(-50%, -50%)',
|
||||
width: 400,
|
||||
width: 450,
|
||||
bgcolor: 'background.paper',
|
||||
border: '2px solid #000',
|
||||
boxShadow: 24,
|
||||
p: 4,
|
||||
borderRadius: 2,
|
||||
}}>
|
||||
<Typography variant="h6" component="h2" gutterBottom>
|
||||
<Typography variant="h6" component="h2" gutterBottom sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
|
||||
<Warning sx={{ color: 'warning.main' }} />
|
||||
Alpha Testing Subscription
|
||||
</Typography>
|
||||
<Typography variant="body1" sx={{ mb: 3 }}>
|
||||
Thank you for participating in our alpha testing! For the Basic plan, we're crediting $29 to your account.
|
||||
</Typography>
|
||||
<Typography variant="body2" color="text.secondary" sx={{ mb: 3 }}>
|
||||
In production, this would integrate with Stripe/Paddle for real payment processing.
|
||||
|
||||
{/* Alpha Testing Notice */}
|
||||
<Alert severity="warning" sx={{ mb: 2 }}>
|
||||
<Typography variant="body2" sx={{ fontWeight: 600, mb: 0.5 }}>
|
||||
⚠️ Alpha Testing Mode - No Payment Required
|
||||
</Typography>
|
||||
<Typography variant="caption" sx={{ display: 'block' }}>
|
||||
Payment integration is coming soon. For now, subscriptions are activated without charge.
|
||||
</Typography>
|
||||
</Alert>
|
||||
|
||||
<Typography variant="body1" sx={{ mb: 2 }}>
|
||||
Thank you for participating in our alpha testing! We're crediting the Basic plan ($29 value) to your account.
|
||||
</Typography>
|
||||
|
||||
{/* TODO: Payment Integration Notice */}
|
||||
<Box sx={{
|
||||
p: 2,
|
||||
mb: 3,
|
||||
bgcolor: 'info.lighter',
|
||||
borderRadius: 1,
|
||||
border: '1px solid',
|
||||
borderColor: 'info.light'
|
||||
}}>
|
||||
<Typography variant="body2" color="info.dark">
|
||||
<strong>Coming in Production:</strong>
|
||||
</Typography>
|
||||
<Typography variant="caption" color="info.dark" sx={{ display: 'block', mt: 0.5 }}>
|
||||
• Secure Stripe/PayPal payment processing<br />
|
||||
• Automatic renewal management<br />
|
||||
• Payment verification & receipts<br />
|
||||
• Upgrade/downgrade options
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
{/* Note: Current behavior allows renewal without payment verification */}
|
||||
{/* This is intentional for alpha testing but will be secured in production */}
|
||||
|
||||
<Box sx={{ display: 'flex', justifyContent: 'flex-end', gap: 2 }}>
|
||||
<Button onClick={() => setPaymentModalOpen(false)}>
|
||||
<Button onClick={() => setPaymentModalOpen(false)} variant="outlined">
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="contained"
|
||||
onClick={handlePaymentConfirm}
|
||||
disabled={subscribing}
|
||||
sx={{
|
||||
background: 'linear-gradient(135deg, #667eea 0%, #764ba2 100%)',
|
||||
'&:hover': {
|
||||
background: 'linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%)',
|
||||
}
|
||||
}}
|
||||
>
|
||||
{subscribing ? <CircularProgress size={20} /> : 'Confirm Subscription'}
|
||||
{subscribing ? <CircularProgress size={20} sx={{ color: 'white' }} /> : 'Confirm Subscription'}
|
||||
</Button>
|
||||
</Box>
|
||||
</Box>
|
||||
@@ -981,6 +1066,37 @@ const PricingPage: React.FC = () => {
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
|
||||
{/* Success Snackbar */}
|
||||
<Snackbar
|
||||
open={successSnackbar.open}
|
||||
autoHideDuration={3000}
|
||||
onClose={() => setSuccessSnackbar({ open: false, message: '', countdown: 0 })}
|
||||
anchorOrigin={{ vertical: 'top', horizontal: 'center' }}
|
||||
sx={{
|
||||
top: { xs: 16, sm: 24 },
|
||||
'& .MuiSnackbarContent-root': {
|
||||
minWidth: { xs: '90vw', sm: '500px' }
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Alert
|
||||
severity="success"
|
||||
variant="filled"
|
||||
onClose={() => setSuccessSnackbar({ open: false, message: '', countdown: 0 })}
|
||||
sx={{
|
||||
width: '100%',
|
||||
fontSize: '1rem',
|
||||
alignItems: 'center',
|
||||
boxShadow: '0 8px 24px rgba(76, 175, 80, 0.4)',
|
||||
'& .MuiAlert-icon': {
|
||||
fontSize: '2rem'
|
||||
}
|
||||
}}
|
||||
>
|
||||
{successSnackbar.message}
|
||||
</Alert>
|
||||
</Snackbar>
|
||||
</Container>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -39,6 +39,25 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
|
||||
subscriptionData,
|
||||
errorData
|
||||
}) => {
|
||||
// Debug logging to verify modal state
|
||||
React.useEffect(() => {
|
||||
if (open) {
|
||||
console.log('SubscriptionExpiredModal: Modal opened', {
|
||||
open,
|
||||
errorData,
|
||||
hasUsageInfo: !!errorData?.usage_info
|
||||
});
|
||||
}
|
||||
}, [open, errorData]);
|
||||
|
||||
const handleDialogClose = (_event: object, reason?: string) => {
|
||||
if (reason === 'backdropClick') {
|
||||
console.log('SubscriptionExpiredModal: Ignoring backdrop click close');
|
||||
return;
|
||||
}
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleRenewClick = () => {
|
||||
onRenewSubscription();
|
||||
onClose();
|
||||
@@ -47,16 +66,21 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onClose={onClose}
|
||||
onClose={handleDialogClose}
|
||||
maxWidth="sm"
|
||||
fullWidth
|
||||
disableEscapeKeyDown
|
||||
PaperProps={{
|
||||
sx: {
|
||||
borderRadius: 3,
|
||||
background: 'linear-gradient(135deg, #fff 0%, #f8fafc 100%)',
|
||||
boxShadow: '0 25px 50px -12px rgba(0, 0, 0, 0.25)',
|
||||
zIndex: 9999, // Ensure modal appears above everything
|
||||
}
|
||||
}}
|
||||
sx={{
|
||||
zIndex: 9999, // Ensure modal backdrop appears above everything
|
||||
}}
|
||||
>
|
||||
<DialogTitle sx={{ textAlign: 'center', pb: 1 }}>
|
||||
<Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', gap: 2 }}>
|
||||
@@ -93,56 +117,156 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
|
||||
borderRadius: 2
|
||||
}}
|
||||
>
|
||||
<Typography variant="body1" sx={{ mb: 2, color: 'text.secondary' }}>
|
||||
{/* Main error message */}
|
||||
<Typography variant="body1" sx={{ mb: 2, color: 'text.secondary', lineHeight: 1.6 }}>
|
||||
{errorData?.message || (errorData?.usage_info
|
||||
? 'You\'ve reached your monthly usage limit for this plan. Upgrade your plan to get higher limits.'
|
||||
: 'To continue using Alwrity and access all features, you need to renew your subscription.'
|
||||
)}
|
||||
</Typography>
|
||||
|
||||
{/* Detailed usage information */}
|
||||
{errorData?.usage_info && (
|
||||
<Box sx={{ mb: 2, p: 2, background: 'rgba(255,255,255,0.7)', borderRadius: 1 }}>
|
||||
<Typography variant="body2" sx={{ fontWeight: 600, mb: 1, color: 'text.primary' }}>
|
||||
<Box sx={{ mb: 2, p: 2.5, background: 'rgba(255,255,255,0.9)', borderRadius: 2, border: '1px solid #e2e8f0' }}>
|
||||
<Typography variant="subtitle2" sx={{ fontWeight: 700, mb: 2, color: 'text.primary', display: 'flex', alignItems: 'center', gap: 1 }}>
|
||||
<Warning sx={{ fontSize: 18, color: 'warning.main' }} />
|
||||
Usage Information:
|
||||
</Typography>
|
||||
{errorData.usage_info.call_usage_percentage && (
|
||||
<Typography variant="body2" sx={{ color: 'text.secondary' }}>
|
||||
You've used {errorData.usage_info.call_usage_percentage.toFixed(1)}% of your monthly limit
|
||||
</Typography>
|
||||
|
||||
{/* Provider and operation type */}
|
||||
<Box sx={{ display: 'flex', gap: 2, mb: 2, flexWrap: 'wrap' }}>
|
||||
{errorData.provider && (
|
||||
<Box sx={{
|
||||
flex: '1 1 auto',
|
||||
px: 2,
|
||||
py: 1.5,
|
||||
background: 'linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%)',
|
||||
borderRadius: 1.5,
|
||||
border: '1px solid #a5b4fc'
|
||||
}}>
|
||||
<Typography variant="caption" sx={{ color: '#4338ca', fontWeight: 600, display: 'block', mb: 0.5 }}>
|
||||
Provider:
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: '#312e81', fontWeight: 700 }}>
|
||||
{errorData.provider}
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{errorData.usage_info.operation_type && (
|
||||
<Box sx={{
|
||||
flex: '1 1 auto',
|
||||
px: 2,
|
||||
py: 1.5,
|
||||
background: 'linear-gradient(135deg, #fef3c7 0%, #fde68a 100%)',
|
||||
borderRadius: 1.5,
|
||||
border: '1px solid #fbbf24'
|
||||
}}>
|
||||
<Typography variant="caption" sx={{ color: '#92400e', fontWeight: 600, display: 'block', mb: 0.5 }}>
|
||||
Operation:
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: '#78350f', fontWeight: 700, textTransform: 'capitalize' }}>
|
||||
{errorData.usage_info.operation_type.replace(/_/g, ' ')}
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* Token usage details (if available) */}
|
||||
{(errorData.usage_info.current_tokens !== undefined || errorData.usage_info.current_calls !== undefined) && (
|
||||
<Box sx={{
|
||||
p: 2,
|
||||
background: 'linear-gradient(135deg, #fee2e2 0%, #fecaca 100%)',
|
||||
borderRadius: 1.5,
|
||||
border: '1px solid #f87171',
|
||||
mb: 2
|
||||
}}>
|
||||
{errorData.usage_info.current_tokens !== undefined && (
|
||||
<>
|
||||
<Typography variant="body2" sx={{ color: '#7f1d1d', fontWeight: 600, mb: 1 }}>
|
||||
Token Usage:
|
||||
</Typography>
|
||||
<Box sx={{ display: 'flex', alignItems: 'baseline', gap: 1, mb: 0.5 }}>
|
||||
<Typography variant="h6" sx={{ color: '#991b1b', fontWeight: 700 }}>
|
||||
{errorData.usage_info.current_tokens?.toLocaleString() || 0}
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: '#7f1d1d' }}>
|
||||
/ {errorData.usage_info.limit?.toLocaleString() || 0}
|
||||
</Typography>
|
||||
<Typography variant="caption" sx={{ color: '#7f1d1d', ml: 'auto' }}>
|
||||
({((errorData.usage_info.current_tokens / errorData.usage_info.limit) * 100).toFixed(1)}% used)
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
{errorData.usage_info.requested_tokens && (
|
||||
<Typography variant="caption" sx={{ color: '#7f1d1d', display: 'block', mt: 1 }}>
|
||||
Requested: {errorData.usage_info.requested_tokens.toLocaleString()} tokens
|
||||
{errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens > errorData.usage_info.limit && (
|
||||
<span style={{ fontWeight: 700, marginLeft: 4 }}>
|
||||
(Would exceed by: {((errorData.usage_info.current_tokens + errorData.usage_info.requested_tokens) - errorData.usage_info.limit).toLocaleString()} tokens)
|
||||
</span>
|
||||
)}
|
||||
</Typography>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{errorData.usage_info.current_calls !== undefined && (
|
||||
<>
|
||||
<Typography variant="body2" sx={{ color: '#7f1d1d', fontWeight: 600, mb: 1, mt: errorData.usage_info.current_tokens !== undefined ? 2 : 0 }}>
|
||||
API Call Usage:
|
||||
</Typography>
|
||||
<Box sx={{ display: 'flex', alignItems: 'baseline', gap: 1 }}>
|
||||
<Typography variant="h6" sx={{ color: '#991b1b', fontWeight: 700 }}>
|
||||
{errorData.usage_info.current_calls?.toLocaleString() || 0}
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: '#7f1d1d' }}>
|
||||
/ {errorData.usage_info.call_limit?.toLocaleString() || 0}
|
||||
</Typography>
|
||||
<Typography variant="caption" sx={{ color: '#7f1d1d', ml: 'auto' }}>
|
||||
({((errorData.usage_info.current_calls / errorData.usage_info.call_limit) * 100).toFixed(1)}% used)
|
||||
</Typography>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Box>
|
||||
)}
|
||||
{errorData.provider && (
|
||||
<Typography variant="body2" sx={{ color: 'text.secondary' }}>
|
||||
Provider: {errorData.provider}
|
||||
</Typography>
|
||||
|
||||
{/* Error type badge */}
|
||||
{errorData.usage_info.error_type && (
|
||||
<Box sx={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<Box sx={{
|
||||
px: 2,
|
||||
py: 0.5,
|
||||
background: '#dc2626',
|
||||
borderRadius: 1,
|
||||
display: 'inline-block'
|
||||
}}>
|
||||
<Typography variant="caption" sx={{ color: 'white', fontWeight: 700, textTransform: 'uppercase', letterSpacing: 0.5 }}>
|
||||
{errorData.usage_info.error_type.replace(/_/g, ' ')}
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Current plan information */}
|
||||
{subscriptionData && (
|
||||
<Box sx={{ display: 'flex', justifyContent: 'center', gap: 2, flexWrap: 'wrap' }}>
|
||||
{subscriptionData.plan && (
|
||||
<Box sx={{
|
||||
px: 2,
|
||||
py: 1,
|
||||
background: 'rgba(255,255,255,0.7)',
|
||||
borderRadius: 1,
|
||||
border: '1px solid #e2e8f0'
|
||||
px: 3,
|
||||
py: 1.5,
|
||||
background: 'rgba(255,255,255,0.9)',
|
||||
borderRadius: 1.5,
|
||||
border: '2px solid #e2e8f0'
|
||||
}}>
|
||||
<Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 500 }}>
|
||||
Current Plan: {subscriptionData.plan}
|
||||
<Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 600, display: 'block', mb: 0.5 }}>
|
||||
Current Plan:
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
{subscriptionData.tier && subscriptionData.tier !== subscriptionData.plan && (
|
||||
<Box sx={{
|
||||
px: 2,
|
||||
py: 1,
|
||||
background: 'rgba(255,255,255,0.7)',
|
||||
borderRadius: 1,
|
||||
border: '1px solid #e2e8f0'
|
||||
}}>
|
||||
<Typography variant="caption" sx={{ color: 'text.secondary', fontWeight: 500 }}>
|
||||
Tier: {subscriptionData.tier}
|
||||
<Typography variant="body2" sx={{ color: 'text.primary', fontWeight: 700, textTransform: 'capitalize' }}>
|
||||
{subscriptionData.plan}
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
@@ -105,12 +105,13 @@ const DashboardHeader: React.FC<DashboardHeaderProps> = ({
|
||||
/* Enhanced Start Button with Phase 1 Improvements */
|
||||
<Box sx={{ position: 'relative', display: 'inline-flex' }}>
|
||||
<Tooltip title={tooltipMessage} arrow placement="bottom">
|
||||
<Button
|
||||
variant="contained"
|
||||
size={isFirstVisit ? "medium" : "small"}
|
||||
startIcon={<PlayArrow />}
|
||||
onClick={workflowControls.onStartWorkflow}
|
||||
disabled={workflowControls.isLoading}
|
||||
<span>
|
||||
<Button
|
||||
variant="contained"
|
||||
size={isFirstVisit ? "medium" : "small"}
|
||||
startIcon={<PlayArrow />}
|
||||
onClick={workflowControls.onStartWorkflow}
|
||||
disabled={workflowControls.isLoading}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
overflow: 'hidden',
|
||||
@@ -180,8 +181,9 @@ const DashboardHeader: React.FC<DashboardHeaderProps> = ({
|
||||
},
|
||||
}}
|
||||
>
|
||||
{isFirstVisit ? '🚀 Start Journey' : 'Start'}
|
||||
</Button>
|
||||
{isFirstVisit ? '🚀 Start Journey' : 'Start'}
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
<Box
|
||||
sx={{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback } from 'react';
|
||||
import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback, useRef } from 'react';
|
||||
import { apiClient, setGlobalSubscriptionErrorHandler } from '../api/client';
|
||||
import SubscriptionExpiredModal from '../components/SubscriptionExpiredModal';
|
||||
|
||||
@@ -60,6 +60,8 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
// New: Grace window after plan changes to avoid noisy UX
|
||||
const [graceUntil, setGraceUntil] = useState<number>(0);
|
||||
const [planSignature, setPlanSignature] = useState<string>("");
|
||||
// Flag to track if current modal is a usage limit modal (should never be auto-closed)
|
||||
const [isUsageLimitModal, setIsUsageLimitModal] = useState<boolean>(false);
|
||||
|
||||
const checkSubscription = useCallback(async () => {
|
||||
// Throttle subscription checks to prevent excessive API calls
|
||||
@@ -86,6 +88,10 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait a moment to ensure auth token getter is installed
|
||||
// This prevents 401 errors during app initialization
|
||||
await new Promise(resolve => setTimeout(resolve, 200));
|
||||
|
||||
console.log('SubscriptionContext: Checking subscription for user:', userId);
|
||||
const response = await apiClient.get(`/api/subscription/status/${userId}`);
|
||||
const subscriptionData = response.data.data;
|
||||
@@ -101,29 +107,42 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
setPlanSignature(newSignature);
|
||||
setGraceUntil(Date.now() + 5 * 60 * 1000);
|
||||
// Close any existing modal as plan just changed
|
||||
if (showModal) {
|
||||
// BUT: Don't close usage limit modals - they're important even after plan changes
|
||||
if (showModal && !isUsageLimitModal) {
|
||||
console.log('SubscriptionContext: Plan changed, closing non-usage-limit modal');
|
||||
setShowModal(false);
|
||||
setModalErrorData(null);
|
||||
} else if (showModal && isUsageLimitModal) {
|
||||
console.log('SubscriptionContext: Plan changed but usage limit modal is open, keeping it open');
|
||||
}
|
||||
}
|
||||
} catch (_e) {}
|
||||
|
||||
// If we have a valid subscription and the modal is open, close it
|
||||
// BUT: NEVER close usage limit modals - user needs to see they hit a limit even with active subscription
|
||||
if (subscriptionData && subscriptionData.active && showModal) {
|
||||
console.log('SubscriptionContext: Valid subscription detected, closing modal');
|
||||
setShowModal(false);
|
||||
setModalErrorData(null);
|
||||
setLastModalShowTime(0); // Reset the cooldown timer
|
||||
}
|
||||
|
||||
// Also check if this is a usage limit error that should be suppressed
|
||||
if (subscriptionData && subscriptionData.active && modalErrorData) {
|
||||
const now = Date.now();
|
||||
const timeSinceLastModal = now - lastModalShowTime;
|
||||
|
||||
// If it's been less than 10 minutes since modal was shown for usage limits, keep it closed
|
||||
if (timeSinceLastModal < 600000 && modalErrorData.usage_info) {
|
||||
console.log('SubscriptionContext: Recent usage limit modal, keeping it closed');
|
||||
// Check if this is a usage limit modal (using flag or checking error data)
|
||||
const hasUsageInfo = modalErrorData?.usage_info ||
|
||||
(modalErrorData?.current_tokens !== undefined) ||
|
||||
(modalErrorData?.current_calls !== undefined) ||
|
||||
(modalErrorData?.limit !== undefined) ||
|
||||
(modalErrorData?.requested_tokens !== undefined);
|
||||
|
||||
const isUsageLimit = isUsageLimitModal || hasUsageInfo;
|
||||
|
||||
if (isUsageLimit) {
|
||||
console.log('SubscriptionContext: Usage limit modal detected - KEEPING OPEN (never auto-close usage limit modals)', {
|
||||
isUsageLimitModal,
|
||||
hasUsageInfo,
|
||||
modalErrorDataKeys: modalErrorData ? Object.keys(modalErrorData) : []
|
||||
});
|
||||
// Do NOT close - usage limit modals should stay open until user dismisses them
|
||||
} else {
|
||||
console.log('SubscriptionContext: Non-usage-limit modal detected, closing since subscription is active');
|
||||
setShowModal(false);
|
||||
setModalErrorData(null);
|
||||
setIsUsageLimitModal(false);
|
||||
setLastModalShowTime(0); // Reset the cooldown timer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +175,7 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
setLastModalShowTime(now);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
} catch (err: any) {
|
||||
console.error('Error checking subscription:', err);
|
||||
|
||||
// Check if it's a connection error that should be handled at the app level
|
||||
@@ -165,6 +184,16 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
throw err;
|
||||
}
|
||||
|
||||
// Handle 401 errors gracefully during initialization - don't block routing
|
||||
// 401 might happen if auth token getter isn't ready yet
|
||||
if (err?.response?.status === 401) {
|
||||
console.warn('Subscription check failed with 401 - auth may not be ready yet, will retry later');
|
||||
setError(null); // Don't set error for 401 during init
|
||||
setLoading(false);
|
||||
// Don't throw - allow routing to proceed, subscription check will retry later
|
||||
return;
|
||||
}
|
||||
|
||||
setError(err instanceof Error ? err.message : 'Failed to check subscription');
|
||||
|
||||
// Don't default to free tier on error - preserve existing subscription or leave null
|
||||
@@ -173,21 +202,30 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil]);
|
||||
}, [lastCheckTime, planSignature, showModal, modalErrorData, lastModalShowTime, graceUntil, isUsageLimitModal]);
|
||||
|
||||
const refreshSubscription = useCallback(async () => {
|
||||
await checkSubscription();
|
||||
}, [checkSubscription]);
|
||||
|
||||
const showExpiredModal = useCallback(() => {
|
||||
setIsUsageLimitModal(false);
|
||||
setShowModal(true);
|
||||
}, []);
|
||||
|
||||
const hideExpiredModal = useCallback(() => {
|
||||
console.log('SubscriptionExpiredModal: User manually closed modal');
|
||||
setShowModal(false);
|
||||
setIsUsageLimitModal(false); // Reset flag when user closes modal
|
||||
setModalErrorData(null);
|
||||
}, []);
|
||||
|
||||
const handleRenewSubscription = useCallback(() => {
|
||||
// Save current location so we can return after renewal
|
||||
const currentPath = window.location.pathname;
|
||||
sessionStorage.setItem('subscription_referrer', currentPath);
|
||||
|
||||
console.log('SubscriptionContext: Navigating to pricing page, saved referrer:', currentPath);
|
||||
window.location.href = '/pricing';
|
||||
}, []);
|
||||
|
||||
@@ -203,42 +241,131 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
// If we have subscription data and it's active, always suppress modal for usage limits
|
||||
if (subscription && subscription.active) {
|
||||
console.log('SubscriptionContext: Active subscription; suppressing usage-limit modal');
|
||||
return true; // Do not show modal for active plan usage limits
|
||||
// Check if this is a usage limit error (status 429) vs subscription expired (402)
|
||||
let errorData = error.response?.data || {};
|
||||
|
||||
// DEBUG: Log the raw error data structure
|
||||
console.log('SubscriptionContext: Raw error data', {
|
||||
type: typeof errorData,
|
||||
isArray: Array.isArray(errorData),
|
||||
data: errorData,
|
||||
stringified: JSON.stringify(errorData)
|
||||
});
|
||||
|
||||
// If errorData is an array, extract the first element (common FastAPI response format)
|
||||
if (Array.isArray(errorData)) {
|
||||
console.log('SubscriptionContext: errorData is array, extracting first element');
|
||||
errorData = errorData[0] || {};
|
||||
}
|
||||
|
||||
// If we don't have subscription data yet, defer the decision
|
||||
if (!subscription) {
|
||||
console.log('SubscriptionContext: No subscription data yet, deferring modal decision');
|
||||
setDeferredError(error);
|
||||
return true; // Handle the error but don't show modal yet
|
||||
}
|
||||
|
||||
// If subscription is not active, show modal immediately
|
||||
if (!subscription.active) {
|
||||
console.log('SubscriptionContext: Inactive subscription, showing modal immediately');
|
||||
const errorData = error.response?.data || {};
|
||||
setModalErrorData({
|
||||
provider: errorData.provider,
|
||||
usage_info: errorData.usage_info,
|
||||
message: errorData.message || errorData.error
|
||||
|
||||
// Check for usage_info in various possible locations
|
||||
const usageInfo = errorData.usage_info ||
|
||||
(errorData.current_calls !== undefined ? errorData : null) ||
|
||||
null;
|
||||
|
||||
// Usage limit error: 429 status with usage info OR 429 status without explicit expiration
|
||||
const isUsageLimitError = status === 429 && (usageInfo || errorData.provider || errorData.message);
|
||||
const isSubscriptionExpired = status === 402 || (status === 429 && !isUsageLimitError);
|
||||
|
||||
console.log('SubscriptionContext: Error analysis', {
|
||||
status,
|
||||
isUsageLimitError,
|
||||
isSubscriptionExpired,
|
||||
hasUsageInfo: !!usageInfo,
|
||||
errorDataType: typeof errorData,
|
||||
errorDataKeys: typeof errorData === 'object' && !Array.isArray(errorData) ? Object.keys(errorData) : 'not-an-object',
|
||||
errorData: errorData
|
||||
});
|
||||
|
||||
// For usage limit errors (429 with usage_info), always show modal - even for active subscriptions
|
||||
// Ignore grace window and cooldown for usage limit errors (user needs to know immediately)
|
||||
if (isUsageLimitError) {
|
||||
const modalData = {
|
||||
provider: errorData.provider || usageInfo?.provider || 'unknown',
|
||||
usage_info: usageInfo || errorData,
|
||||
message: errorData.message || errorData.error || 'You have reached your usage limit.'
|
||||
};
|
||||
|
||||
console.log('SubscriptionContext: Usage limit exceeded, showing modal (ignoring grace window/cooldown)', {
|
||||
modalData,
|
||||
errorData: Object.keys(errorData),
|
||||
usageInfo: usageInfo ? Object.keys(usageInfo) : null
|
||||
});
|
||||
|
||||
// Set flag to mark this as a usage limit modal (should never be auto-closed)
|
||||
setIsUsageLimitModal(true);
|
||||
setModalErrorData(modalData);
|
||||
setShowModal(true);
|
||||
setLastModalShowTime(now);
|
||||
|
||||
console.log('SubscriptionContext: Modal state updated - showModal should be true, isUsageLimitModal = true');
|
||||
return true;
|
||||
}
|
||||
|
||||
// For subscription expired errors, handle based on subscription status
|
||||
if (isSubscriptionExpired) {
|
||||
// If we have subscription data and it's active, this shouldn't happen but suppress anyway
|
||||
if (subscription && subscription.active) {
|
||||
console.log('SubscriptionContext: Active subscription but got expired error, suppressing modal');
|
||||
return true;
|
||||
}
|
||||
|
||||
// If we don't have subscription data yet, defer the decision
|
||||
if (!subscription) {
|
||||
console.log('SubscriptionContext: No subscription data yet, deferring modal decision');
|
||||
setDeferredError(error);
|
||||
return true; // Handle the error but don't show modal yet
|
||||
}
|
||||
|
||||
// If subscription is not active, show modal immediately
|
||||
if (!subscription.active) {
|
||||
console.log('SubscriptionContext: Inactive subscription, showing modal immediately');
|
||||
setIsUsageLimitModal(false);
|
||||
setModalErrorData({
|
||||
provider: errorData.provider,
|
||||
usage_info: errorData.usage_info,
|
||||
message: errorData.message || errorData.error
|
||||
});
|
||||
setShowModal(true);
|
||||
setLastModalShowTime(now);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false; // Not a subscription error
|
||||
}, [subscription]);
|
||||
|
||||
// Register the global error handler with the API client
|
||||
// Use a ref to ensure the latest handler is always used
|
||||
const handlerRef = useRef(globalSubscriptionErrorHandler);
|
||||
useEffect(() => {
|
||||
handlerRef.current = globalSubscriptionErrorHandler;
|
||||
}, [globalSubscriptionErrorHandler]);
|
||||
|
||||
useEffect(() => {
|
||||
console.log('SubscriptionContext: Registering global subscription error handler');
|
||||
setGlobalSubscriptionErrorHandler(globalSubscriptionErrorHandler);
|
||||
}, [globalSubscriptionErrorHandler]);
|
||||
setGlobalSubscriptionErrorHandler((error: any) => {
|
||||
// Always use the latest handler from ref
|
||||
return handlerRef.current(error);
|
||||
});
|
||||
|
||||
// Cleanup: Don't remove the handler on unmount - it should persist
|
||||
// This ensures errors can still be caught even during component transitions
|
||||
}, []); // Empty deps - only register once, but handler ref updates automatically
|
||||
|
||||
useEffect(() => {
|
||||
const eventHandler = (event: Event) => {
|
||||
const customEvent = event as CustomEvent;
|
||||
console.log('SubscriptionContext: Received subscription-error event fallback', customEvent.detail);
|
||||
handlerRef.current(customEvent.detail);
|
||||
};
|
||||
|
||||
window.addEventListener('subscription-error', eventHandler as EventListener);
|
||||
return () => {
|
||||
window.removeEventListener('subscription-error', eventHandler as EventListener);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Check subscription on mount
|
||||
|
||||
@@ -33,6 +33,9 @@ export const useBlogWriterState = () => {
|
||||
// Content confirmation state
|
||||
const [contentConfirmed, setContentConfirmed] = useState<boolean>(false);
|
||||
|
||||
// Section images state - persists images generated in outline phase to content phase
|
||||
const [sectionImages, setSectionImages] = useState<Record<string, string>>({});
|
||||
|
||||
// Cache recovery - restore most recent research on page load
|
||||
useEffect(() => {
|
||||
const cachedEntries = researchCache.getAllCachedEntries();
|
||||
@@ -211,6 +214,7 @@ export const useBlogWriterState = () => {
|
||||
contentConfirmed,
|
||||
flowAnalysisCompleted,
|
||||
flowAnalysisResults,
|
||||
sectionImages,
|
||||
|
||||
// Setters
|
||||
setResearch,
|
||||
@@ -233,6 +237,7 @@ export const useBlogWriterState = () => {
|
||||
setContentConfirmed,
|
||||
setFlowAnalysisCompleted,
|
||||
setFlowAnalysisResults,
|
||||
setSectionImages,
|
||||
|
||||
// Handlers
|
||||
handleResearchComplete,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { blogWriterApi, TaskStatusResponse } from '../services/blogWriterApi';
|
||||
import { triggerSubscriptionError } from '../api/client';
|
||||
|
||||
export interface UsePollingOptions {
|
||||
interval?: number; // Polling interval in milliseconds
|
||||
@@ -108,6 +109,43 @@ export function usePolling(
|
||||
console.log('❌ Task failed - stopping polling immediately');
|
||||
setError(status.error || 'Task failed');
|
||||
onError?.(status.error || 'Task failed');
|
||||
|
||||
// Check if this is a subscription error and trigger modal
|
||||
if (status.error_status === 429 || status.error_status === 402) {
|
||||
console.log('usePolling: Detected subscription error in task status', {
|
||||
error_status: status.error_status,
|
||||
error_data: status.error_data,
|
||||
error: status.error
|
||||
});
|
||||
|
||||
// Create a mock error object with the subscription error data
|
||||
const errorData = status.error_data || {};
|
||||
|
||||
// Ensure usage_info is properly nested - it might be at the top level or nested
|
||||
const usageInfo = errorData.usage_info ||
|
||||
(errorData.current_calls !== undefined ? errorData : null) ||
|
||||
errorData;
|
||||
|
||||
const mockError = {
|
||||
response: {
|
||||
status: status.error_status,
|
||||
data: {
|
||||
error: errorData.error || status.error || 'Subscription limit exceeded',
|
||||
message: errorData.message || errorData.error || status.error || 'You have reached your usage limit.',
|
||||
provider: errorData.provider || usageInfo?.provider || 'unknown',
|
||||
usage_info: usageInfo
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
console.log('usePolling: Triggering subscription error handler with:', mockError);
|
||||
const handled = triggerSubscriptionError(mockError);
|
||||
|
||||
if (!handled) {
|
||||
console.warn('usePolling: Subscription error handler did not handle the error');
|
||||
}
|
||||
}
|
||||
|
||||
stopPolling();
|
||||
return; // Exit early to prevent further processing
|
||||
}
|
||||
@@ -117,6 +155,38 @@ export function usePolling(
|
||||
const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred';
|
||||
console.error('Polling error:', errorMessage);
|
||||
|
||||
// Check if this is an axios error with subscription limit status
|
||||
// This is a fallback in case the interceptor doesn't catch it
|
||||
const axiosError = err as any;
|
||||
if (axiosError?.response?.status === 429 || axiosError?.response?.status === 402) {
|
||||
console.log('usePolling: Detected subscription error in axios error response', {
|
||||
status: axiosError.response.status,
|
||||
data: axiosError.response.data
|
||||
});
|
||||
|
||||
// Trigger subscription error handler (modal will show)
|
||||
const handled = triggerSubscriptionError(axiosError);
|
||||
console.log('usePolling: triggerSubscriptionError returned', handled);
|
||||
|
||||
if (handled) {
|
||||
console.log('usePolling: Subscription error handled, stopping polling');
|
||||
const errorMsg = axiosError.response?.data?.message ||
|
||||
axiosError.response?.data?.error ||
|
||||
'Subscription limit exceeded';
|
||||
setError(errorMsg);
|
||||
onError?.(errorMsg);
|
||||
stopPolling();
|
||||
return; // Exit early - don't continue processing
|
||||
} else {
|
||||
console.warn('usePolling: Subscription error not handled by global handler, dispatching fallback event');
|
||||
try {
|
||||
window.dispatchEvent(new CustomEvent('subscription-error', { detail: axiosError }));
|
||||
} catch (eventError) {
|
||||
console.error('usePolling: Failed to dispatch subscription-error event', eventError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop polling for task failures and rate limiting
|
||||
if (errorMessage.includes('404') || errorMessage.includes('Task not found')) {
|
||||
setError('Task not found - it may have expired or been cleaned up');
|
||||
|
||||
@@ -219,9 +219,22 @@ export interface BlogSEOMetadataResponse {
|
||||
success: boolean;
|
||||
title_options: string[];
|
||||
meta_descriptions: string[];
|
||||
seo_title?: string;
|
||||
meta_description?: string;
|
||||
url_slug?: string;
|
||||
blog_tags: string[];
|
||||
blog_categories: string[];
|
||||
social_hashtags: string[];
|
||||
open_graph: Record<string, any>;
|
||||
twitter_card: Record<string, any>;
|
||||
schema: Record<string, any>;
|
||||
json_ld_schema?: Record<string, any>;
|
||||
schema?: Record<string, any>; // Legacy field name
|
||||
canonical_url?: string;
|
||||
reading_time?: number;
|
||||
focus_keyword?: string;
|
||||
generated_at?: string;
|
||||
optimization_score?: number;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface BlogPublishResponse {
|
||||
@@ -241,6 +254,26 @@ export interface TaskStatusResponse {
|
||||
}>;
|
||||
result?: BlogResearchResponse;
|
||||
error?: string;
|
||||
// Subscription error details (set by backend when subscription limit is exceeded)
|
||||
error_status?: number; // HTTP status code (429 for usage limit, 402 for subscription expired)
|
||||
error_data?: {
|
||||
error?: string;
|
||||
message?: string;
|
||||
provider?: string;
|
||||
usage_info?: {
|
||||
provider?: string;
|
||||
current_calls?: number;
|
||||
limit?: number;
|
||||
type?: string;
|
||||
breakdown?: {
|
||||
gemini?: number;
|
||||
openai?: number;
|
||||
anthropic?: number;
|
||||
mistral?: number;
|
||||
};
|
||||
};
|
||||
[key: string]: any; // Allow additional fields
|
||||
};
|
||||
}
|
||||
|
||||
export const blogWriterApi = {
|
||||
|
||||
198
frontend/src/utils/wixTokenUtils.ts
Normal file
198
frontend/src/utils/wixTokenUtils.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
/**
|
||||
* Wix Token Utilities
|
||||
* Functions for validating and refreshing Wix OAuth tokens
|
||||
*/
|
||||
|
||||
import { apiClient } from '../api/client';
|
||||
|
||||
interface WixTokens {
|
||||
accessToken?: {
|
||||
value: string;
|
||||
expiresAt?: string;
|
||||
};
|
||||
refreshToken?: {
|
||||
value: string;
|
||||
};
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
}
|
||||
|
||||
interface TokenValidationResult {
|
||||
valid: boolean;
|
||||
accessToken: string | null;
|
||||
needsRefresh: boolean;
|
||||
needsReconnect: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Wix tokens from sessionStorage
|
||||
*/
|
||||
export function getWixTokens(): WixTokens | null {
|
||||
try {
|
||||
const tokensRaw = sessionStorage.getItem('wix_tokens');
|
||||
if (!tokensRaw) return null;
|
||||
return JSON.parse(tokensRaw);
|
||||
} catch (error) {
|
||||
console.error('Error parsing Wix tokens:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract access token from token structure
|
||||
*/
|
||||
export function extractAccessToken(tokens: WixTokens | null): string | null {
|
||||
if (!tokens) return null;
|
||||
return tokens.accessToken?.value || tokens.access_token || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract refresh token from token structure
|
||||
*/
|
||||
export function extractRefreshToken(tokens: WixTokens | null): string | null {
|
||||
if (!tokens) return null;
|
||||
return tokens.refreshToken?.value || tokens.refresh_token || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Wix access token using refresh token
|
||||
*/
|
||||
export async function refreshWixToken(refreshToken: string): Promise<WixTokens | null> {
|
||||
try {
|
||||
const response = await apiClient.post('/api/wix/refresh-token', {
|
||||
refresh_token: refreshToken
|
||||
});
|
||||
|
||||
if (response.data.success) {
|
||||
// Create new token structure matching Wix SDK format
|
||||
const newTokens: WixTokens = {
|
||||
accessToken: {
|
||||
value: response.data.access_token
|
||||
},
|
||||
refreshToken: {
|
||||
value: response.data.refresh_token || refreshToken // Keep old refresh token if new one not provided
|
||||
},
|
||||
access_token: response.data.access_token,
|
||||
refresh_token: response.data.refresh_token || refreshToken
|
||||
};
|
||||
|
||||
// Update sessionStorage
|
||||
try {
|
||||
sessionStorage.setItem('wix_tokens', JSON.stringify(newTokens));
|
||||
sessionStorage.setItem('wix_connected', 'true');
|
||||
} catch (e) {
|
||||
console.error('Error saving refreshed tokens:', e);
|
||||
}
|
||||
|
||||
return newTokens;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error: any) {
|
||||
console.error('Error refreshing Wix token:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if token is expired based on expiresAt timestamp
|
||||
*/
|
||||
function isTokenExpired(tokens: WixTokens): boolean {
|
||||
if (tokens.accessToken?.expiresAt) {
|
||||
try {
|
||||
const expiresAt = new Date(tokens.accessToken.expiresAt);
|
||||
return expiresAt < new Date();
|
||||
} catch (e) {
|
||||
// If we can't parse, assume not expired (will validate during publish)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If no expiration info, we can't tell - assume valid for now
|
||||
// Real validation happens during actual API call
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and refresh Wix tokens proactively
|
||||
* Returns access token if valid, or null if needs reconnection
|
||||
*
|
||||
* Strategy:
|
||||
* 1. Check if tokens exist
|
||||
* 2. Check if token is expired (if expiration info available)
|
||||
* 3. If expired, attempt refresh
|
||||
* 4. If refresh fails or no refresh token, needs reconnection
|
||||
* 5. Real validation happens during actual publish (we catch 401/403 errors)
|
||||
*/
|
||||
export async function validateAndRefreshWixTokens(): Promise<TokenValidationResult> {
|
||||
const tokens = getWixTokens();
|
||||
|
||||
if (!tokens) {
|
||||
return {
|
||||
valid: false,
|
||||
accessToken: null,
|
||||
needsRefresh: false,
|
||||
needsReconnect: true
|
||||
};
|
||||
}
|
||||
|
||||
const accessToken = extractAccessToken(tokens);
|
||||
const refreshToken = extractRefreshToken(tokens);
|
||||
|
||||
if (!accessToken) {
|
||||
return {
|
||||
valid: false,
|
||||
accessToken: null,
|
||||
needsRefresh: false,
|
||||
needsReconnect: true
|
||||
};
|
||||
}
|
||||
|
||||
// Check if token is expired (if we have expiration info)
|
||||
const expired = isTokenExpired(tokens);
|
||||
|
||||
if (!expired) {
|
||||
// Token appears valid (not expired or no expiration info)
|
||||
// We'll do real validation during publish
|
||||
return {
|
||||
valid: true,
|
||||
accessToken: accessToken,
|
||||
needsRefresh: false,
|
||||
needsReconnect: false
|
||||
};
|
||||
}
|
||||
|
||||
// Token is expired, try to refresh
|
||||
if (!refreshToken) {
|
||||
return {
|
||||
valid: false,
|
||||
accessToken: null,
|
||||
needsRefresh: false,
|
||||
needsReconnect: true
|
||||
};
|
||||
}
|
||||
|
||||
// Attempt to refresh token
|
||||
const refreshedTokens = await refreshWixToken(refreshToken);
|
||||
|
||||
if (refreshedTokens) {
|
||||
const newAccessToken = extractAccessToken(refreshedTokens);
|
||||
if (newAccessToken) {
|
||||
return {
|
||||
valid: true,
|
||||
accessToken: newAccessToken,
|
||||
needsRefresh: true,
|
||||
needsReconnect: false
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh failed, needs reconnection
|
||||
return {
|
||||
valid: false,
|
||||
accessToken: null,
|
||||
needsRefresh: false,
|
||||
needsReconnect: true
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user