Subscription dashboard improvements, AI text generation limit, and other fixes.
This commit is contained in:
@@ -5,10 +5,11 @@ Main router for blog writing operations including research, outline generation,
|
||||
content creation, SEO analysis, and publishing.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -64,10 +65,21 @@ class SEOApplyRecommendationsRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/seo/apply-recommendations")
|
||||
async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]:
|
||||
async def apply_seo_recommendations(
|
||||
request: SEOApplyRecommendationsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply actionable SEO recommendations and return updated content."""
|
||||
try:
|
||||
result = await recommendation_applier.apply_recommendations(request.dict())
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
result = await recommendation_applier.apply_recommendations(request.dict(), user_id=user_id)
|
||||
if not result.get("success"):
|
||||
raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations"))
|
||||
return result
|
||||
@@ -87,13 +99,24 @@ async def health() -> Dict[str, Any]:
|
||||
|
||||
# Research Endpoints
|
||||
@router.post("/research/start")
|
||||
async def start_research(request: BlogResearchRequest) -> Dict[str, Any]:
|
||||
async def start_research(
|
||||
request: BlogResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Start a research operation and return a task ID for polling."""
|
||||
try:
|
||||
# TODO: Get user_id from authentication context
|
||||
user_id = "anonymous" # This should come from auth middleware
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
task_id = await task_manager.start_research_task(request, user_id)
|
||||
return {"task_id": task_id, "status": "started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start research: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -107,6 +130,50 @@ async def get_research_status(task_id: str) -> Dict[str, Any]:
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
error_status = status.get('error_status', 429)
|
||||
|
||||
if not isinstance(error_data, dict):
|
||||
logger.warning(f"Research task {task_id} error_data not dict: {error_data}")
|
||||
error_data = {'error': str(error_data)}
|
||||
|
||||
# Determine provider and usage info
|
||||
stored_error_message = status.get('error', error_data.get('error'))
|
||||
provider = error_data.get('provider', 'unknown')
|
||||
usage_info = error_data.get('usage_info')
|
||||
|
||||
if not usage_info:
|
||||
usage_info = {
|
||||
'provider': provider,
|
||||
'message': stored_error_message,
|
||||
'error_type': error_data.get('error_type', 'unknown')
|
||||
}
|
||||
# Include any known fields from error_data
|
||||
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
|
||||
if key in error_data:
|
||||
usage_info[key] = error_data[key]
|
||||
|
||||
# Build error message for detail
|
||||
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
|
||||
|
||||
# Log the subscription error with all context
|
||||
logger.warning(f"Research task {task_id} failed with subscription error {error_status}: {error_msg}")
|
||||
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
|
||||
|
||||
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=error_status,
|
||||
content={
|
||||
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
|
||||
'message': error_msg,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Research status request for {task_id}: {status['status']} with {len(status.get('progress_messages', []))} progress messages")
|
||||
return status
|
||||
except HTTPException:
|
||||
@@ -310,20 +377,46 @@ async def hallucination_check(request: HallucinationCheckRequest) -> Hallucinati
|
||||
|
||||
# SEO Endpoints
|
||||
@router.post("/seo/analyze", response_model=BlogSEOAnalyzeResponse)
|
||||
async def seo_analyze(request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse:
|
||||
async def seo_analyze(
|
||||
request: BlogSEOAnalyzeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> BlogSEOAnalyzeResponse:
|
||||
"""Analyze content for SEO optimization opportunities."""
|
||||
try:
|
||||
return await service.seo_analyze(request)
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
return await service.seo_analyze(request, user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform SEO analysis: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/seo/metadata", response_model=BlogSEOMetadataResponse)
|
||||
async def seo_metadata(request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse:
|
||||
async def seo_metadata(
|
||||
request: BlogSEOMetadataRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> BlogSEOMetadataResponse:
|
||||
"""Generate SEO metadata for the blog post."""
|
||||
try:
|
||||
return await service.seo_metadata(request)
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
return await service.seo_metadata(request, user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate SEO metadata: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -85,6 +86,10 @@ class TaskManager:
|
||||
response["result"] = task["result"]
|
||||
elif task["status"] == "failed":
|
||||
response["error"] = task["error"]
|
||||
if "error_status" in task:
|
||||
response["error_status"] = task["error_status"]
|
||||
if "error_data" in task:
|
||||
response["error_data"] = task["error_data"]
|
||||
|
||||
return response
|
||||
|
||||
@@ -109,14 +114,17 @@ class TaskManager:
|
||||
|
||||
logger.info(f"Progress update for task {task_id}: {message}")
|
||||
|
||||
async def start_research_task(self, request: BlogResearchRequest, user_id: str = "anonymous") -> str:
|
||||
async def start_research_task(self, request: BlogResearchRequest, user_id: str) -> str:
|
||||
"""Start a research operation and return a task ID."""
|
||||
if self.use_database:
|
||||
return await self.db_manager.start_research_task(request, user_id)
|
||||
else:
|
||||
task_id = self.create_task("research")
|
||||
# Store user_id in task for subscription checks
|
||||
if task_id in self.task_storage:
|
||||
self.task_storage[task_id]["user_id"] = user_id
|
||||
# Start the research operation in the background
|
||||
asyncio.create_task(self._run_research_task(task_id, request))
|
||||
asyncio.create_task(self._run_research_task(task_id, request, user_id))
|
||||
return task_id
|
||||
|
||||
def start_outline_task(self, request: BlogOutlineRequest) -> str:
|
||||
@@ -144,7 +152,7 @@ class TaskManager:
|
||||
asyncio.create_task(self._run_medium_generation_task(task_id, request))
|
||||
return task_id
|
||||
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest):
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str):
|
||||
"""Background task to run research and update status with progress messages."""
|
||||
try:
|
||||
# Update status to running
|
||||
@@ -157,8 +165,8 @@ class TaskManager:
|
||||
# Check cache first
|
||||
await self.update_progress(task_id, "📋 Checking cache for existing research...")
|
||||
|
||||
# Run the actual research with progress updates
|
||||
result = await self.service.research_with_progress(request, task_id)
|
||||
# Run the actual research with progress updates (pass user_id for subscription checks)
|
||||
result = await self.service.research_with_progress(request, task_id, user_id)
|
||||
|
||||
# Check if research failed gracefully
|
||||
if not result.success:
|
||||
@@ -171,6 +179,16 @@ class TaskManager:
|
||||
self.task_storage[task_id]["status"] = "completed"
|
||||
self.task_storage[task_id]["result"] = result.dict()
|
||||
|
||||
except HTTPException as http_error:
|
||||
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
|
||||
error_detail = http_error.detail
|
||||
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
|
||||
await self.update_progress(task_id, f"❌ {error_message}")
|
||||
self.task_storage[task_id]["status"] = "failed"
|
||||
self.task_storage[task_id]["error"] = error_message
|
||||
# Store HTTP error details for frontend modal
|
||||
self.task_storage[task_id]["error_status"] = http_error.status_code
|
||||
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
|
||||
except Exception as e:
|
||||
await self.update_progress(task_id, f"❌ Research failed with error: {str(e)}")
|
||||
# Update status to failed
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Body
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -64,6 +64,15 @@ async def activate_strategy_with_monitoring(
|
||||
if not monitoring_success:
|
||||
logger.warning(f"Failed to save monitoring data for strategy {strategy_id}")
|
||||
|
||||
# Trigger scheduler interval adjustment (scheduler will check more frequently now)
|
||||
try:
|
||||
from services.scheduler import get_scheduler
|
||||
scheduler = get_scheduler()
|
||||
await scheduler.trigger_interval_adjustment()
|
||||
logger.info(f"Triggered scheduler interval adjustment after strategy {strategy_id} activation")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not trigger scheduler interval adjustment: {e}")
|
||||
|
||||
logger.info(f"Successfully activated strategy {strategy_id} with monitoring")
|
||||
return {
|
||||
"success": True,
|
||||
@@ -396,6 +405,150 @@ async def get_monitoring_tasks(
|
||||
logger.error(f"Error retrieving monitoring tasks: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/user/{user_id}/monitoring-tasks")
async def get_user_monitoring_tasks(
    user_id: int,
    db: Session = Depends(get_db),
    status: Optional[str] = Query(None, description="Filter by task status"),
    limit: int = Query(50, description="Maximum number of tasks to return"),
    offset: int = Query(0, description="Number of tasks to skip")
):
    """
    Get all monitoring tasks for a specific user with their execution status.

    Tasks are selected by joining MonitoringTask to EnhancedContentStrategy and
    keeping only rows whose strategy belongs to ``user_id``. Results are ordered
    newest-first by ``created_at`` and paginated with ``offset``/``limit``.

    Args:
        user_id: Owner of the strategies whose tasks are returned.
        db: Database session injected by FastAPI.
        status: Optional task status; when given, only tasks with this status
            are returned and counted.
        limit: Maximum number of tasks in the page.
        offset: Number of tasks to skip before the page starts.

    Returns:
        A dict with ``success``, ``data`` (one dict per task), ``pagination``
        (``total``, ``limit``, ``offset``, ``has_more``) and ``message``.

    Raises:
        HTTPException: 500 if anything fails while querying or serializing.
    """
    try:
        logger.info(f"Getting monitoring tasks for user {user_id}")

        # NOTE(review): user_id is taken from the URL and this handler declares
        # no authentication dependency - confirm that access control is
        # enforced elsewhere, otherwise any caller can read any user's tasks.

        # Base query: every task whose strategy belongs to this user.
        user_tasks_query = db.query(MonitoringTask).join(
            EnhancedContentStrategy,
            MonitoringTask.strategy_id == EnhancedContentStrategy.id
        ).filter(
            EnhancedContentStrategy.user_id == user_id
        )

        # Apply status filter if provided
        if status:
            user_tasks_query = user_tasks_query.filter(MonitoringTask.status == status)

        # The same filtered query drives both the pagination total and the page,
        # so the two can never drift apart.
        total_count = user_tasks_query.count()
        tasks = (
            user_tasks_query
            .order_by(desc(MonitoringTask.created_at))
            .offset(offset)
            .limit(limit)
            .all()
        )

        # Resolve strategy names for the whole page in one query instead of
        # issuing one strategy lookup per task.
        strategy_ids = list({task.strategy_id for task in tasks})
        strategy_names = {}
        if strategy_ids:
            strategy_rows = db.query(
                EnhancedContentStrategy.id,
                EnhancedContentStrategy.name
            ).filter(
                EnhancedContentStrategy.id.in_(strategy_ids)
            ).all()
            strategy_names = {
                strategy_id: strategy_name
                for strategy_id, strategy_name in strategy_rows
            }

        tasks_data = []
        for task in tasks:
            # One filtered log query per task, reused for "latest" and "count".
            task_logs = db.query(TaskExecutionLog).filter(
                TaskExecutionLog.task_id == task.id
            )
            latest_log = task_logs.order_by(
                desc(TaskExecutionLog.execution_date)
            ).first()

            # Guard both the missing-log case and a log row without a date.
            last_executed = None
            if latest_log and latest_log.execution_date:
                last_executed = latest_log.execution_date.isoformat()

            tasks_data.append({
                "id": task.id,
                "strategy_id": task.strategy_id,
                "strategy_name": strategy_names.get(task.strategy_id),
                "title": task.task_title,
                "description": task.task_description,
                "assignee": task.assignee,
                "frequency": task.frequency,
                "metric": task.metric,
                "measurementMethod": task.measurement_method,
                "successCriteria": task.success_criteria,
                "alertThreshold": task.alert_threshold,
                "status": task.status,
                "lastExecuted": last_executed,
                "nextExecution": task.next_execution.isoformat() if task.next_execution else None,
                "executionCount": task_logs.count(),
                "created_at": task.created_at.isoformat() if task.created_at else None
            })

        return {
            "success": True,
            "data": tasks_data,
            "pagination": {
                "total": total_count,
                "limit": limit,
                "offset": offset,
                "has_more": (offset + len(tasks_data)) < total_count
            },
            "message": f"Retrieved {len(tasks_data)} monitoring tasks for user {user_id}"
        }

    except Exception as e:
        logger.error(f"Error retrieving user monitoring tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve monitoring tasks: {str(e)}")
|
||||
|
||||
@router.get("/user/{user_id}/execution-logs")
async def get_user_execution_logs(
    user_id: int,
    db: Session = Depends(get_db),
    status: Optional[str] = Query(None, description="Filter by execution status"),
    limit: int = Query(50, description="Maximum number of logs to return"),
    offset: int = Query(0, description="Number of logs to skip")
):
    """
    Return one page of task execution logs belonging to a single user.

    The page itself is fetched through MonitoringDataService; the pagination
    total is counted directly on TaskExecutionLog rows matching ``user_id``
    (and ``status`` when supplied).

    Args:
        user_id: User whose execution logs are requested.
        db: Database session injected by FastAPI.
        status: Optional execution status used to narrow both page and total.
        limit: Maximum number of logs in the page.
        offset: Number of logs to skip before the page starts.

    Returns:
        A dict with ``success``, ``data``, ``pagination`` (``total``,
        ``limit``, ``offset``, ``has_more``) and ``message``.

    Raises:
        HTTPException: 500 if fetching or counting the logs fails.
    """
    try:
        logger.info(f"Getting execution logs for user {user_id}")

        service = MonitoringDataService(db)
        page = service.get_user_execution_logs(
            user_id=user_id,
            limit=limit,
            offset=offset,
            status_filter=status
        )

        # Collect the count criteria first, then apply them in one filter call.
        criteria = [TaskExecutionLog.user_id == user_id]
        if status:
            criteria.append(TaskExecutionLog.status == status)
        matching_total = db.query(TaskExecutionLog).filter(*criteria).count()

        page_size = len(page)
        pagination = {
            "total": matching_total,
            "limit": limit,
            "offset": offset,
            "has_more": (offset + page_size) < matching_total
        }

        return {
            "success": True,
            "data": page,
            "pagination": pagination,
            "message": f"Retrieved {page_size} execution logs for user {user_id}"
        }

    except Exception as e:
        logger.error(f"Error retrieving execution logs for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve execution logs: {str(e)}")
|
||||
|
||||
@router.get("/{strategy_id}/data-freshness")
|
||||
async def get_data_freshness(
|
||||
strategy_id: int,
|
||||
|
||||
@@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
@@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
def generate(
|
||||
req: ImageGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImageGenerateResponse:
|
||||
"""Generate image with subscription checking."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Validation is now handled inside generate_image function
|
||||
last_error: Optional[Exception] = None
|
||||
result = None
|
||||
for attempt in range(2): # simple single retry
|
||||
try:
|
||||
result = generate_image(
|
||||
@@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
"steps": req.steps,
|
||||
"seed": req.seed,
|
||||
},
|
||||
user_id=user_id, # Pass user_id for validation inside generate_image
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# TRACK USAGE after successful image generation
|
||||
if result:
|
||||
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Get or create usage summary
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (stability for image generation)
|
||||
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, "stability_calls", new_calls)
|
||||
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: stability
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
width=result.width,
|
||||
@@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
|
||||
def suggest_prompts(
|
||||
req: ImagePromptSuggestRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImagePromptSuggestResponse:
|
||||
try:
|
||||
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
|
||||
section = req.section or {}
|
||||
@@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons
|
||||
If including on-image text, return it in overlay_text (short: <= 8 words).
|
||||
"""
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
|
||||
# Get user_id for llm_text_gen subscription check (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id_for_llm = str(current_user.get('id', ''))
|
||||
if not user_id_for_llm:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
|
||||
data = raw if isinstance(raw, dict) else {}
|
||||
suggestions = data.get("suggestions") or []
|
||||
# basic fallback if provider returns string
|
||||
|
||||
@@ -94,6 +94,7 @@ async def get_subscription_plans(
|
||||
"description": plan.description,
|
||||
"features": plan.features or [],
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
@@ -162,6 +163,7 @@ async def get_user_subscription(
|
||||
},
|
||||
"status": "free",
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": free_plan.gemini_calls_limit,
|
||||
"openai_calls": free_plan.openai_calls_limit,
|
||||
"anthropic_calls": free_plan.anthropic_calls_limit,
|
||||
@@ -200,6 +202,7 @@ async def get_user_subscription(
|
||||
"is_free": False
|
||||
},
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
@@ -252,6 +255,7 @@ async def get_subscription_status(
|
||||
"tier": "free",
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(free_plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": free_plan.gemini_calls_limit,
|
||||
"openai_calls": free_plan.openai_calls_limit,
|
||||
"anthropic_calls": free_plan.anthropic_calls_limit,
|
||||
@@ -309,6 +313,7 @@ async def get_subscription_status(
|
||||
"tier": subscription.plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
@@ -331,9 +336,14 @@ async def get_subscription_status(
|
||||
async def subscribe_to_plan(
|
||||
user_id: str,
|
||||
subscription_data: dict,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create or update a user's subscription."""
|
||||
"""Create or update a user's subscription (renewal)."""
|
||||
|
||||
# Verify user can only subscribe/renew their own subscription
|
||||
if current_user.get('id') != user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
try:
|
||||
plan_id = subscription_data.get('plan_id')
|
||||
@@ -388,12 +398,75 @@ async def subscribe_to_plan(
|
||||
|
||||
db.commit()
|
||||
|
||||
# Get current usage BEFORE reset for logging
|
||||
current_period = datetime.utcnow().strftime("%Y-%m")
|
||||
usage_before = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# Log renewal request details
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})")
|
||||
logger.info(f" ├─ Billing Cycle: {billing_cycle}")
|
||||
logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
if usage_before:
|
||||
logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):")
|
||||
logger.info(f" ├─ Gemini: {usage_before.gemini_tokens or 0} tokens / {usage_before.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_before.mistral_tokens or 0} tokens / {usage_before.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_before.openai_tokens or 0} tokens / {usage_before.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_before.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_before.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_before.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_before.usage_status.value}")
|
||||
else:
|
||||
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
|
||||
|
||||
# Clear subscription limits cache to force refresh on next check
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
# Clear cache for this specific user (class-level cache shared across all instances)
|
||||
cleared_count = PricingService.clear_user_cache(user_id)
|
||||
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
|
||||
except Exception as cache_err:
|
||||
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
|
||||
|
||||
# Reset usage status for current billing period so new plan takes effect immediately
|
||||
reset_result = None
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
await usage_service.reset_current_billing_period(user_id)
|
||||
reset_result = await usage_service.reset_current_billing_period(user_id)
|
||||
|
||||
# Re-query usage summary from DB after reset to get fresh data
|
||||
usage_after = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if reset_result.get('reset'):
|
||||
logger.info(f" ✅ Usage counters RESET successfully")
|
||||
if usage_after:
|
||||
logger.info(f" 📊 New Usage AFTER Reset:")
|
||||
logger.info(f" ├─ Gemini: {usage_after.gemini_tokens or 0} tokens / {usage_after.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_after.mistral_tokens or 0} tokens / {usage_after.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_after.openai_tokens or 0} tokens / {usage_after.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_after.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_after.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_after.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_after.usage_status.value}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Reset returned: {reset_result.get('reason', 'unknown')}")
|
||||
except Exception as reset_err:
|
||||
logger.error(f"Failed to reset usage after subscribe: {reset_err}")
|
||||
logger.error(f" ❌ Failed to reset usage after subscribe: {reset_err}", exc_info=True)
|
||||
|
||||
logger.info(f" ✅ Renewal completed: User {user_id} → {plan.name} ({billing_cycle})")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -404,7 +477,20 @@ async def subscribe_to_plan(
|
||||
"billing_cycle": billing_cycle,
|
||||
"current_period_start": subscription.current_period_start.isoformat(),
|
||||
"current_period_end": subscription.current_period_end.isoformat(),
|
||||
"status": subscription.status.value
|
||||
"status": subscription.status.value,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
"mistral_calls": plan.mistral_calls_limit,
|
||||
"tavily_calls": plan.tavily_calls_limit,
|
||||
"serper_calls": plan.serper_calls_limit,
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -477,6 +477,39 @@ async def test_publish_to_wix(request: WixPublishRequest) -> Dict[str, Any]:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/refresh-token")
async def refresh_wix_token(request: Dict[str, Any]) -> Dict[str, Any]:
    """Exchange a Wix refresh token for a new set of tokens.

    Args:
        request: Request body as a dict; must carry a non-empty
            ``refresh_token`` entry.

    Returns:
        Dict with ``success`` set to True plus the ``access_token``,
        ``refresh_token``, ``expires_in`` and ``token_type`` values read from
        the result of ``wix_service.refresh_access_token`` (``token_type``
        falls back to ``"Bearer"`` when absent).

    Raises:
        HTTPException: 400 when ``refresh_token`` is missing or empty; 500 when
            any other error occurs while refreshing.
    """
    try:
        supplied_token = request.get("refresh_token")
        if supplied_token:
            # Delegate the actual refresh to the Wix service.
            renewed = wix_service.refresh_access_token(supplied_token)

            response: Dict[str, Any] = {"success": True}
            for field in ("access_token", "refresh_token", "expires_in"):
                response[field] = renewed.get(field)
            response["token_type"] = renewed.get("token_type", "Bearer")
            return response

        raise HTTPException(status_code=400, detail="Missing refresh_token")
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) pass through untouched.
        raise
    except Exception as err:
        logger.error(f"Failed to refresh Wix token: {err}")
        raise HTTPException(status_code=500, detail=f"Failed to refresh token: {str(err)}")
|
||||
|
||||
|
||||
@router.post("/test/publish/real")
|
||||
async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -298,6 +298,11 @@ async def startup_event():
|
||||
try:
|
||||
# Initialize database
|
||||
init_database()
|
||||
|
||||
# Start task scheduler
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
|
||||
logger.info("ALwrity backend started successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
@@ -307,6 +312,10 @@ async def startup_event():
|
||||
async def shutdown_event():
|
||||
"""Cleanup on shutdown."""
|
||||
try:
|
||||
# Stop task scheduler
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().stop()
|
||||
|
||||
# Close database connections
|
||||
close_database()
|
||||
logger.info("ALwrity backend shutdown successfully")
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- Migration: Add user_id column to task_execution_logs for user isolation
-- Date: 2025-01-XX
-- Purpose: Enable user isolation tracking in scheduler task execution logs

-- Add user_id column (nullable for backward compatibility with existing records)
-- NOTE(review): unlike the CREATE INDEX statements below, this ALTER TABLE has
-- no existence guard, so re-running the migration presumably fails once the
-- column exists - confirm against the target database before re-applying.
ALTER TABLE task_execution_logs
ADD COLUMN user_id INTEGER NULL;

-- Create index for efficient user filtering and queries
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_id
ON task_execution_logs(user_id);

-- Create composite index for common query patterns (user_id + status + execution_date)
CREATE INDEX IF NOT EXISTS idx_task_execution_logs_user_status_date
ON task_execution_logs(user_id, status, execution_date);

-- Note: Backfilling existing records would require joining with monitoring_tasks
-- and enhanced_content_strategies tables. This can be done in a separate migration
-- or during a maintenance window. For now, existing records will have user_id = NULL.
|
||||
|
||||
@@ -65,7 +65,7 @@ class LinkedInPostRequest(BaseModel):
|
||||
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "AI in healthcare transformation",
|
||||
"industry": "Healthcare",
|
||||
@@ -102,7 +102,7 @@ class LinkedInArticleRequest(BaseModel):
|
||||
persona_override: Optional[Dict[str, Any]] = Field(default=None, description="Session-only persona overrides to apply without saving")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Digital transformation in manufacturing",
|
||||
"industry": "Manufacturing",
|
||||
@@ -135,7 +135,7 @@ class LinkedInCarouselRequest(BaseModel):
|
||||
include_citations: bool = Field(default=True, description="Whether to include inline citations")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Future of remote work",
|
||||
"industry": "Technology",
|
||||
@@ -167,7 +167,7 @@ class LinkedInVideoScriptRequest(BaseModel):
|
||||
include_citations: bool = Field(default=True, description="Whether to include inline citations")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"topic": "Cybersecurity best practices",
|
||||
"industry": "Technology",
|
||||
@@ -197,7 +197,7 @@ class LinkedInCommentResponseRequest(BaseModel):
|
||||
grounding_level: GroundingLevel = Field(default=GroundingLevel.BASIC, description="Level of content grounding")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"original_comment": "Great insights on AI implementation!",
|
||||
"post_context": "Post about AI transformation in healthcare",
|
||||
@@ -353,7 +353,7 @@ class LinkedInPostResponse(BaseModel):
|
||||
grounding_status: Optional[Dict[str, Any]] = Field(None, description="Grounding operation status")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"data": {
|
||||
|
||||
@@ -48,8 +48,9 @@ class TaskExecutionLog(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_id = Column(Integer, ForeignKey("monitoring_tasks.id"), nullable=False)
|
||||
user_id = Column(Integer, nullable=True) # User ID for user isolation (nullable for backward compatibility)
|
||||
execution_date = Column(DateTime, default=datetime.utcnow)
|
||||
status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped'
|
||||
status = Column(String(50), nullable=False) # 'success', 'failed', 'skipped', 'running'
|
||||
result_data = Column(JSON, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
execution_time_ms = Column(Integer, nullable=True)
|
||||
|
||||
@@ -50,16 +50,22 @@ class SubscriptionPlan(Base):
|
||||
price_monthly = Column(Float, nullable=False, default=0.0)
|
||||
price_yearly = Column(Float, nullable=False, default=0.0)
|
||||
|
||||
# API Call Limits
|
||||
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited
|
||||
openai_calls_limit = Column(Integer, default=0)
|
||||
anthropic_calls_limit = Column(Integer, default=0)
|
||||
mistral_calls_limit = Column(Integer, default=0)
|
||||
# Unified AI Text Generation Call Limit (applies to all LLM providers: gemini, openai, anthropic, mistral)
|
||||
# Note: This column may not exist in older databases - use getattr() when accessing
|
||||
ai_text_generation_calls_limit = Column(Integer, default=0, nullable=True) # 0 = unlimited, None if column doesn't exist
|
||||
|
||||
# Legacy per-provider limits (kept for backwards compatibility and analytics)
|
||||
gemini_calls_limit = Column(Integer, default=0) # 0 = unlimited (deprecated, use ai_text_generation_calls_limit)
|
||||
openai_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
anthropic_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
mistral_calls_limit = Column(Integer, default=0) # (deprecated, use ai_text_generation_calls_limit)
|
||||
|
||||
# Other API Call Limits (non-LLM)
|
||||
tavily_calls_limit = Column(Integer, default=0)
|
||||
serper_calls_limit = Column(Integer, default=0)
|
||||
metaphor_calls_limit = Column(Integer, default=0)
|
||||
firecrawl_calls_limit = Column(Integer, default=0)
|
||||
stability_calls_limit = Column(Integer, default=0)
|
||||
stability_calls_limit = Column(Integer, default=0) # Image generation
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
|
||||
@@ -63,6 +63,9 @@ pytest-asyncio>=0.21.0
|
||||
pydantic>=2.5.2,<3.0.0
|
||||
typing-extensions>=4.8.0
|
||||
|
||||
# Task scheduling
|
||||
apscheduler>=3.10.0
|
||||
|
||||
# Optional dependencies (for enhanced features)
|
||||
redis>=5.0.0
|
||||
schedule>=1.2.0
|
||||
146
backend/scripts/add_ai_text_generation_limit_column.py
Normal file
146
backend/scripts/add_ai_text_generation_limit_column.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Migration Script: Add ai_text_generation_calls_limit column to subscription_plans table.
|
||||
|
||||
This adds the unified AI text generation limit column that applies to all LLM providers
|
||||
(gemini, openai, anthropic, mistral) instead of per-provider limits.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine, text, inspect
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def add_ai_text_generation_limit_column():
    """Add ai_text_generation_calls_limit to subscription_plans and backfill it.

    Steps:
      1. Return True immediately when the column is already present.
      2. Run ALTER TABLE to add the column with DEFAULT 0 and commit.
      3. Backfill every plan with the first non-zero value among the unified
         limit and the legacy gemini/openai/anthropic/mistral call limits
         (0 when all are unset); the Basic tier is always forced to 10.
      4. Log a per-plan summary.

    Returns:
        bool: True when the column already existed or the migration completed;
        False when session setup or the migration itself failed. Failures are
        logged once, under the phase in which they happened.
    """
    import traceback

    # Phase 1: engine/session setup. Only errors raised here are reported as
    # connection/setup failures.
    try:
        engine = create_engine(DATABASE_URL, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
    except Exception as e:
        logger.error(f"❌ Failed to connect to database: {e}")
        logger.error(traceback.format_exc())
        return False

    # Phase 2: the migration. Errors are rolled back and logged exactly once.
    try:
        # Check if column already exists
        inspector = inspect(engine)
        columns = [col['name'] for col in inspector.get_columns('subscription_plans')]

        if 'ai_text_generation_calls_limit' in columns:
            logger.info("✅ Column 'ai_text_generation_calls_limit' already exists. Skipping migration.")
            return True

        logger.info("📋 Adding 'ai_text_generation_calls_limit' column to subscription_plans table...")

        # Add the column (SQLite compatible)
        alter_query = text("""
            ALTER TABLE subscription_plans
            ADD COLUMN ai_text_generation_calls_limit INTEGER DEFAULT 0
        """)

        db.execute(alter_query)
        db.commit()

        logger.info("✅ Column added successfully!")

        # Update existing plans with unified limits based on their current limits
        logger.info("\n🔄 Updating existing subscription plans with unified limits...")

        plans = db.query(SubscriptionPlan).all()
        updated_count = 0

        for plan in plans:
            # Use the first non-zero LLM provider limit as the unified limit,
            # falling back to 0 when every limit is unset.
            unified_limit = (
                plan.ai_text_generation_calls_limit or
                plan.gemini_calls_limit or
                plan.openai_calls_limit or
                plan.anthropic_calls_limit or
                plan.mistral_calls_limit or
                0
            )

            # For Basic plan, ensure it's set to 10 (from our recent update)
            if plan.tier == SubscriptionTier.BASIC:
                unified_limit = 10

            if plan.ai_text_generation_calls_limit != unified_limit:
                plan.ai_text_generation_calls_limit = unified_limit
                plan.updated_at = datetime.now(timezone.utc)
                updated_count += 1

                logger.info(f"  ✅ Updated {plan.name} ({plan.tier.value}): ai_text_generation_calls_limit = {unified_limit}")
            else:
                logger.info(f"  ℹ️  {plan.name} ({plan.tier.value}): already set to {unified_limit}")

        if updated_count > 0:
            db.commit()
            logger.info(f"\n✅ Updated {updated_count} subscription plan(s)")
        else:
            logger.info("\nℹ️  No plans needed updating")

        # Display summary
        logger.info("\n" + "="*60)
        logger.info("MIGRATION SUMMARY")
        logger.info("="*60)

        all_plans = db.query(SubscriptionPlan).all()
        for plan in all_plans:
            unified = plan.ai_text_generation_calls_limit
            logger.info(f"\n{plan.name} ({plan.tier.value}):")
            # 0 is a real value (the model documents it as "unlimited"), so
            # only None is reported as "Not set".
            logger.info(f"  Unified AI Text Gen Limit: {unified if unified is not None else 'Not set'}")
            logger.info(f"  Legacy Limits: gemini={plan.gemini_calls_limit}, mistral={plan.mistral_calls_limit}")

        logger.info("\n" + "="*60)
        logger.info("✅ Migration completed successfully!")

        return True

    except Exception as e:
        db.rollback()
        logger.error(f"❌ Error during migration: {e}")
        logger.error(traceback.format_exc())
        return False

    finally:
        db.close()
        # Release pooled connections held by the engine created above.
        engine.dispose()
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: announce the migration, run it, and map the outcome
    # to the process exit code (0 on success, 1 otherwise).
    banner = "="*60
    logger.info("🚀 Starting ai_text_generation_calls_limit column migration...")
    logger.info(banner)
    logger.info("This will add the unified AI text generation limit column")
    logger.info("and update existing plans with appropriate values.")
    logger.info(banner)

    try:
        if add_ai_text_generation_limit_column():
            logger.info("\n✅ Script completed successfully!")
            exit_code = 0
        else:
            logger.error("\n❌ Script failed!")
            exit_code = 1
        sys.exit(exit_code)

    except KeyboardInterrupt:
        logger.info("\n⚠️  Script cancelled by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n❌ Unexpected error: {e}")
        sys.exit(1)
|
||||
|
||||
210
backend/scripts/cap_basic_plan_usage.py
Normal file
210
backend/scripts/cap_basic_plan_usage.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Standalone script to cap usage counters at new Basic plan limits.
|
||||
|
||||
This preserves historical usage data but caps it at the new limits so users
|
||||
can continue making new calls within their limits.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
from services.subscription import PricingService
|
||||
|
||||
def cap_basic_plan_usage():
    """Cap current-period usage counters of active Basic-plan users at the plan limits.

    For every active subscription on the Basic plan, the usage summary of the
    user's current billing period is loaded. When any LLM call counter, any
    LLM token counter or the image counter exceeds its limit, each counter is
    clamped to the limit, the totals are recomputed from the clamped values
    and the usage status is set back to ACTIVE. Each user is committed
    separately; a failure for one user is logged and rolled back without
    stopping the run.

    A limit of 0 or None never triggers or applies a cap. The plan model
    documents 0 as "unlimited" for call limits; token limits are presumed to
    follow the same convention.

    Returns:
        bool: True when the run finished (including when nothing needed
        capping); False when the Basic plan is missing, or when session setup
        or the run as a whole failed. Failures are logged once.
    """
    import traceback

    def _exceeds(value, limit):
        """Return True when *limit* is a real cap (non-zero) and *value* is above it."""
        return bool(limit) and value > limit

    def _capped(value, limit):
        """Clamp *value* to *limit*; a 0/None limit means no cap, so *value* is kept."""
        return min(value, limit) if limit else value

    # Phase 1: engine/session setup. Only errors raised here are reported as
    # connection/setup failures.
    try:
        engine = create_engine(DATABASE_URL, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
    except Exception as e:
        logger.error(f"❌ Failed to connect to database: {e}")
        logger.error(traceback.format_exc())
        return False

    # Phase 2: the capping run.
    try:
        # Find Basic plan
        basic_plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.BASIC
        ).first()

        if not basic_plan:
            logger.error("❌ Basic plan not found in database!")
            return False

        # New limits.
        # NOTE(review): the call cap is read from the legacy gemini_calls_limit;
        # if enforcement has moved to ai_text_generation_calls_limit, this may
        # need to read that column instead - confirm before running.
        new_call_limit = basic_plan.gemini_calls_limit
        new_token_limit = basic_plan.gemini_tokens_limit
        new_image_limit = basic_plan.stability_calls_limit

        logger.info(f"📋 Basic Plan Limits:")
        logger.info(f"   Calls: {new_call_limit}")
        logger.info(f"   Tokens: {new_token_limit}")
        logger.info(f"   Images: {new_image_limit}")

        # Get all Basic plan users
        user_subscriptions = db.query(UserSubscription).filter(
            UserSubscription.plan_id == basic_plan.id,
            UserSubscription.is_active == True
        ).all()

        logger.info(f"\n👥 Found {len(user_subscriptions)} Basic plan user(s)")

        pricing_service = PricingService(db)
        capped_count = 0

        for sub in user_subscriptions:
            try:
                # Get current billing period for this user (fall back to the
                # current UTC year-month when the service returns nothing).
                current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")

                # Find usage summary for current period
                usage_summary = db.query(UsageSummary).filter(
                    UsageSummary.user_id == sub.user_id,
                    UsageSummary.billing_period == current_period
                ).first()

                if not usage_summary:
                    logger.info(f"   ℹ️  No usage summary found for user {sub.user_id} (period {current_period})")
                    continue

                # Store old values for logging
                old_gemini = usage_summary.gemini_calls or 0
                old_mistral = usage_summary.mistral_calls or 0
                old_openai = usage_summary.openai_calls or 0
                old_anthropic = usage_summary.anthropic_calls or 0
                old_tokens = max(
                    usage_summary.gemini_tokens or 0,
                    usage_summary.openai_tokens or 0,
                    usage_summary.anthropic_tokens or 0,
                    usage_summary.mistral_tokens or 0
                )
                old_images = usage_summary.stability_calls or 0

                # Check if capping is needed
                needs_cap = (
                    _exceeds(old_gemini, new_call_limit) or
                    _exceeds(old_mistral, new_call_limit) or
                    _exceeds(old_openai, new_call_limit) or
                    _exceeds(old_anthropic, new_call_limit) or
                    _exceeds(old_images, new_image_limit) or
                    _exceeds(old_tokens, new_token_limit)
                )

                if not needs_cap:
                    logger.info(f"   ℹ️  User {sub.user_id} usage is within limits - no capping needed")
                    continue

                # Cap LLM provider counters at new limits
                usage_summary.gemini_calls = _capped(old_gemini, new_call_limit)
                usage_summary.mistral_calls = _capped(old_mistral, new_call_limit)
                usage_summary.openai_calls = _capped(old_openai, new_call_limit)
                usage_summary.anthropic_calls = _capped(old_anthropic, new_call_limit)

                # Cap token counters at new limits
                usage_summary.gemini_tokens = _capped(usage_summary.gemini_tokens or 0, new_token_limit)
                usage_summary.openai_tokens = _capped(usage_summary.openai_tokens or 0, new_token_limit)
                usage_summary.anthropic_tokens = _capped(usage_summary.anthropic_tokens or 0, new_token_limit)
                usage_summary.mistral_tokens = _capped(usage_summary.mistral_tokens or 0, new_token_limit)

                # Cap image counter at new limit
                usage_summary.stability_calls = _capped(old_images, new_image_limit)

                # Recalculate totals based on capped values
                usage_summary.total_calls = (
                    usage_summary.gemini_calls +
                    usage_summary.mistral_calls +
                    usage_summary.openai_calls +
                    usage_summary.anthropic_calls +
                    usage_summary.stability_calls
                )
                usage_summary.total_tokens = (
                    usage_summary.gemini_tokens +
                    usage_summary.mistral_tokens +
                    usage_summary.openai_tokens +
                    usage_summary.anthropic_tokens
                )

                # Reset status to active to allow new calls
                usage_summary.usage_status = UsageStatus.ACTIVE
                usage_summary.updated_at = datetime.now(timezone.utc)

                db.commit()
                capped_count += 1

                # Report the post-cap token figure over the same four providers
                # that old_tokens was computed from.
                new_tokens = max(
                    usage_summary.gemini_tokens,
                    usage_summary.openai_tokens,
                    usage_summary.anthropic_tokens,
                    usage_summary.mistral_tokens
                )

                logger.info(f"\n✅ Capped usage for user {sub.user_id} (period {current_period}):")
                logger.info(f"   Gemini Calls: {old_gemini} → {usage_summary.gemini_calls} (limit: {new_call_limit})")
                logger.info(f"   Mistral Calls: {old_mistral} → {usage_summary.mistral_calls} (limit: {new_call_limit})")
                logger.info(f"   OpenAI Calls: {old_openai} → {usage_summary.openai_calls} (limit: {new_call_limit})")
                logger.info(f"   Anthropic Calls: {old_anthropic} → {usage_summary.anthropic_calls} (limit: {new_call_limit})")
                logger.info(f"   Tokens: {old_tokens} → {new_tokens} (limit: {new_token_limit})")
                logger.info(f"   Images: {old_images} → {usage_summary.stability_calls} (limit: {new_image_limit})")

            except Exception as cap_error:
                # Best effort per user: log, undo this user's pending changes
                # and carry on with the next subscription.
                logger.error(f"   ❌ Error capping usage for user {sub.user_id}: {cap_error}")
                logger.error(traceback.format_exc())
                db.rollback()

        if capped_count > 0:
            logger.info(f"\n✅ Successfully capped usage for {capped_count} user(s)")
            logger.info("   Historical usage preserved, but capped at new limits")
            logger.info("   Users can now make new calls within their limits")
        else:
            logger.info("\nℹ️  No usage counters needed capping")

        logger.info("\n" + "="*60)
        logger.info("CAPPING COMPLETE")
        logger.info("="*60)

        return True

    except Exception as e:
        db.rollback()
        logger.error(f"❌ Error capping usage: {e}")
        logger.error(traceback.format_exc())
        return False

    finally:
        db.close()
        # Release pooled connections held by the engine created above.
        engine.dispose()
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: announce the capping run, execute it, and map the
    # outcome to the process exit code (0 on success, 1 otherwise).
    banner = "="*60
    logger.info("🚀 Starting Basic plan usage capping...")
    logger.info(banner)
    logger.info("This will cap usage counters at new Basic plan limits")
    logger.info("while preserving historical usage data.")
    logger.info(banner)

    try:
        if cap_basic_plan_usage():
            logger.info("\n✅ Script completed successfully!")
            exit_code = 0
        else:
            logger.error("\n❌ Script failed!")
            exit_code = 1
        sys.exit(exit_code)

    except KeyboardInterrupt:
        logger.info("\n⚠️  Script cancelled by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n❌ Unexpected error: {e}")
        sys.exit(1)
|
||||
|
||||
168
backend/scripts/reset_basic_plan_usage.py
Normal file
168
backend/scripts/reset_basic_plan_usage.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Quick script to reset usage counters for Basic plan users.
|
||||
|
||||
This fixes the issue where plan limits were updated but old usage data remained.
|
||||
Resets all usage counters (calls, tokens, images) to 0 for the current billing period.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageSummary, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
from services.subscription import PricingService
|
||||
|
||||
def reset_basic_plan_usage():
    """Reset current-period usage counters for all active Basic-plan users.

    For every active subscription on the Basic plan, the usage summary of the
    user's current billing period is loaded and its LLM call counters, LLM
    token counters, image counter and totals (calls, tokens, cost) are set to
    zero, with the usage status set back to ACTIVE. Each user is committed
    separately; a failure for one user is logged and rolled back without
    stopping the run.

    Returns:
        bool: True when the run finished (including when nothing needed
        resetting); False when the Basic plan is missing, or when session
        setup or the run as a whole failed. Failures are logged once.
    """
    import traceback

    # Phase 1: engine/session setup. Only errors raised here are reported as
    # connection/setup failures.
    try:
        engine = create_engine(DATABASE_URL, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
    except Exception as e:
        logger.error(f"❌ Failed to connect to database: {e}")
        logger.error(traceback.format_exc())
        return False

    # Phase 2: the reset run.
    try:
        # Find Basic plan
        basic_plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.BASIC
        ).first()

        if not basic_plan:
            logger.error("❌ Basic plan not found in database!")
            return False

        # Get all Basic plan users
        user_subscriptions = db.query(UserSubscription).filter(
            UserSubscription.plan_id == basic_plan.id,
            UserSubscription.is_active == True
        ).all()

        logger.info(f"Found {len(user_subscriptions)} Basic plan user(s)")

        pricing_service = PricingService(db)
        reset_count = 0

        for sub in user_subscriptions:
            try:
                # Get current billing period for this user (fall back to the
                # current UTC year-month when the service returns nothing).
                current_period = pricing_service.get_current_billing_period(sub.user_id) or datetime.now(timezone.utc).strftime("%Y-%m")

                # Find usage summary for current period
                usage_summary = db.query(UsageSummary).filter(
                    UsageSummary.user_id == sub.user_id,
                    UsageSummary.billing_period == current_period
                ).first()

                if not usage_summary:
                    logger.info(f"   ℹ️  No usage summary found for user {sub.user_id} (period {current_period}) - nothing to reset")
                    continue

                # Store old values for logging
                old_gemini = usage_summary.gemini_calls or 0
                old_mistral = usage_summary.mistral_calls or 0
                old_tokens = (usage_summary.mistral_tokens or 0) + (usage_summary.gemini_tokens or 0)
                old_images = usage_summary.stability_calls or 0
                old_total_calls = usage_summary.total_calls or 0
                old_total_tokens = usage_summary.total_tokens or 0

                # Reset all LLM provider counters
                usage_summary.gemini_calls = 0
                usage_summary.openai_calls = 0
                usage_summary.anthropic_calls = 0
                usage_summary.mistral_calls = 0

                # Reset all token counters
                usage_summary.gemini_tokens = 0
                usage_summary.openai_tokens = 0
                usage_summary.anthropic_tokens = 0
                usage_summary.mistral_tokens = 0

                # Reset image counter
                usage_summary.stability_calls = 0

                # Reset totals
                usage_summary.total_calls = 0
                usage_summary.total_tokens = 0
                usage_summary.total_cost = 0.0

                # Reset status to active
                usage_summary.usage_status = UsageStatus.ACTIVE
                usage_summary.updated_at = datetime.now(timezone.utc)

                db.commit()
                reset_count += 1

                logger.info(f"\n✅ Reset usage for user {sub.user_id} (period {current_period}):")
                logger.info(f"   Calls: {old_gemini + old_mistral} (gemini: {old_gemini}, mistral: {old_mistral}) → 0")
                logger.info(f"   Tokens: {old_tokens} → 0")
                logger.info(f"   Images: {old_images} → 0")
                logger.info(f"   Total Calls: {old_total_calls} → 0")
                logger.info(f"   Total Tokens: {old_total_tokens} → 0")

            except Exception as reset_error:
                # Best effort per user: log, undo this user's pending changes
                # and carry on with the next subscription.
                logger.error(f"   ❌ Error resetting usage for user {sub.user_id}: {reset_error}")
                logger.error(traceback.format_exc())
                db.rollback()

        if reset_count > 0:
            logger.info(f"\n✅ Successfully reset usage counters for {reset_count} user(s)")
        else:
            logger.info("\nℹ️  No usage counters to reset")

        logger.info("\n" + "="*60)
        logger.info("RESET COMPLETE")
        logger.info("="*60)
        logger.info("\n💡 Usage counters have been reset. Users can now use their new limits.")
        logger.info("   Next API call will start counting from 0.")

        return True

    except Exception as e:
        db.rollback()
        logger.error(f"❌ Error resetting usage: {e}")
        logger.error(traceback.format_exc())
        return False

    finally:
        db.close()
        # Release pooled connections held by the engine created above.
        engine.dispose()
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: announce the reset run, execute it, and map the
    # outcome to the process exit code (0 on success, 1 otherwise).
    banner = "="*60
    logger.info("🚀 Starting Basic plan usage counter reset...")
    logger.info(banner)
    logger.info("This will reset all usage counters (calls, tokens, images) to 0")
    logger.info("for all Basic plan users in their current billing period.")
    logger.info(banner)

    try:
        if reset_basic_plan_usage():
            logger.info("\n✅ Script completed successfully!")
            exit_code = 0
        else:
            logger.error("\n❌ Script failed!")
            exit_code = 1
        sys.exit(exit_code)

    except KeyboardInterrupt:
        logger.info("\n⚠️  Script cancelled by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n❌ Unexpected error: {e}")
        sys.exit(1)
|
||||
|
||||
279
backend/scripts/update_basic_plan_limits.py
Normal file
279
backend/scripts/update_basic_plan_limits.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Script to update Basic plan subscription limits for testing rate limits and renewal flows.
|
||||
|
||||
Updates:
|
||||
- LLM Calls (all providers): 10 calls (was 500-1000)
|
||||
- LLM Tokens (all providers): 2000 tokens (was 200k-1M)
|
||||
- Images: 5 images (was 50)
|
||||
|
||||
This script updates the SubscriptionPlan table, which automatically applies to all users
|
||||
who have a Basic plan subscription via the plan_id foreign key.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier, UserSubscription, UsageStatus
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def _cap_usage_summary(usage_summary, call_limit, token_limit, image_limit):
    """Clamp one usage summary's per-provider counters to the given limits, in place.

    Counters stored as ``None`` are treated as 0. ``total_calls`` and
    ``total_tokens`` are recomputed as the sums of the capped per-provider
    values (stability calls are included in ``total_calls``).

    Args:
        usage_summary: Object exposing the ``*_calls`` / ``*_tokens`` counters.
        call_limit: Ceiling applied to each LLM provider's call counter.
        token_limit: Ceiling applied to each LLM provider's token counter.
        image_limit: Ceiling applied to ``stability_calls``.

    Returns:
        A ``(before, after)`` pair of dicts keyed by counter attribute name.
    """
    call_fields = ("gemini_calls", "mistral_calls", "openai_calls", "anthropic_calls")
    token_fields = ("gemini_tokens", "mistral_tokens", "openai_tokens", "anthropic_tokens")

    ceilings = {name: call_limit for name in call_fields}
    ceilings.update({name: token_limit for name in token_fields})
    ceilings["stability_calls"] = image_limit

    before = {name: getattr(usage_summary, name) or 0 for name in ceilings}
    after = {name: min(before[name], ceilings[name]) for name in ceilings}
    for name, value in after.items():
        setattr(usage_summary, name, value)

    usage_summary.total_calls = sum(after[name] for name in call_fields) + after["stability_calls"]
    usage_summary.total_tokens = sum(after[name] for name in token_fields)
    return before, after


def update_basic_plan_limits(*, call_limit: int = 10, token_limit: int = 2000, image_limit: int = 5) -> bool:
    """Update Basic plan limits for testing rate limits and renewal.

    Sets the unified AI text generation call limit, the per-provider token
    limits and the image limit on the Basic ``SubscriptionPlan`` row, commits
    that change, then caps the current-period usage counters of every active
    Basic subscriber at the new limits (one commit per user, so a failure for
    one user does not undo the others).

    Args:
        call_limit: New unified AI text generation call limit.
        token_limit: New token limit applied to every LLM provider.
        image_limit: New image (stability) call limit.

    Returns:
        True when the plan row was updated; False when the Basic plan is
        missing or any database error prevented the update.

    Raises:
        ValueError: If any limit is negative.
    """
    import traceback

    if min(call_limit, token_limit, image_limit) < 0:
        raise ValueError("Plan limits must be non-negative")

    try:
        engine = create_engine(DATABASE_URL, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
    except Exception as e:
        logger.error(f"❌ Failed to connect to database: {e}")
        logger.error(traceback.format_exc())
        return False

    # create_engine() does not open a connection by itself, so connectivity
    # problems usually surface at the first query below and are reported by
    # the handler at the bottom of this function.
    try:
        # Find Basic plan
        basic_plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.BASIC
        ).first()

        if not basic_plan:
            logger.error("❌ Basic plan not found in database!")
            return False

        logger.info("📋 Current Basic plan limits:")
        logger.info(f" Gemini Calls: {basic_plan.gemini_calls_limit}")
        logger.info(f" Mistral Calls: {basic_plan.mistral_calls_limit}")
        logger.info(f" Gemini Tokens: {basic_plan.gemini_tokens_limit}")
        logger.info(f" Mistral Tokens: {basic_plan.mistral_tokens_limit}")
        logger.info(f" Images (Stability): {basic_plan.stability_calls_limit}")

        # Unified AI text generation limit
        basic_plan.ai_text_generation_calls_limit = call_limit

        # Legacy per-provider limits (kept for backwards compatibility, but not used for enforcement)
        legacy_call_limits = {"gemini": 1000, "openai": 500, "anthropic": 200, "mistral": 500}
        for provider, legacy_limit in legacy_call_limits.items():
            setattr(basic_plan, f"{provider}_calls_limit", legacy_limit)

        # Same token limit for every LLM provider
        for provider in legacy_call_limits:
            setattr(basic_plan, f"{provider}_tokens_limit", token_limit)

        # Image generation limit
        basic_plan.stability_calls_limit = image_limit

        basic_plan.updated_at = datetime.now(timezone.utc)

        logger.info("\n📝 New Basic plan limits:")
        logger.info(f" LLM Calls (all providers): {call_limit}")
        logger.info(f" LLM Tokens (all providers): {token_limit}")
        logger.info(f" Images: {image_limit}")

        # Collect the affected user IDs up front: every commit expires loaded
        # ORM objects, and plain values avoid re-reading them inside the loop.
        user_subscriptions = db.query(UserSubscription).filter(
            UserSubscription.plan_id == basic_plan.id,
            UserSubscription.is_active == True  # noqa: E712 - builds a SQL comparison
        ).all()
        user_ids = [sub.user_id for sub in user_subscriptions]
        affected_users = len(user_ids)

        logger.info(f"\n👥 Users affected: {affected_users}")

        if affected_users > 0:
            logger.info("\n📋 Affected user IDs:")
            for user_id in user_ids:
                logger.info(f" - {user_id}")
        else:
            logger.info(" (No active Basic plan subscriptions found)")

        # Commit plan limit changes first
        db.commit()
        logger.info("\n✅ Basic plan limits updated successfully!")

        # Cap usage at new limits for all affected users (preserve historical data, but cap enforcement)
        logger.info("\n🔄 Capping usage counters at new limits for Basic plan users...")
        logger.info(" (Historical usage preserved, but capped to allow new calls within limits)")
        from models.subscription_models import UsageSummary
        from services.subscription import PricingService

        pricing_service = PricingService(db)
        capped_count = 0

        # NOTE(review): each provider's call counter is capped at the unified
        # limit separately, so the summed total can still exceed it — confirm
        # this matches how the unified limit is enforced.
        for user_id in user_ids:
            try:
                # Get current billing period for this user
                current_period = (
                    pricing_service.get_current_billing_period(user_id)
                    or datetime.now(timezone.utc).strftime("%Y-%m")
                )

                usage_summary = db.query(UsageSummary).filter(
                    UsageSummary.user_id == user_id,
                    UsageSummary.billing_period == current_period
                ).first()

                if not usage_summary:
                    logger.info(f" ℹ️ No usage summary found for user {user_id} (period {current_period})")
                    continue

                before, after = _cap_usage_summary(usage_summary, call_limit, token_limit, image_limit)

                # Reset status to active to allow new calls
                usage_summary.usage_status = UsageStatus.ACTIVE
                usage_summary.updated_at = datetime.now(timezone.utc)

                db.commit()
                capped_count += 1

                token_fields = ("gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens")
                old_tokens = max(before[name] for name in token_fields)
                new_tokens = max(after[name] for name in token_fields)

                logger.info(f" ✅ Capped usage for user {user_id}:")
                logger.info(f" Gemini Calls: {before['gemini_calls']} → {after['gemini_calls']} (limit: {call_limit})")
                logger.info(f" Mistral Calls: {before['mistral_calls']} → {after['mistral_calls']} (limit: {call_limit})")
                logger.info(f" Tokens: {old_tokens} → {new_tokens} (limit: {token_limit})")
                logger.info(f" Images: {before['stability_calls']} → {after['stability_calls']} (limit: {image_limit})")

            except Exception as cap_error:
                # Roll back first so the session is usable for the next user.
                db.rollback()
                logger.error(f" ❌ Error capping usage for user {user_id}: {cap_error}")
                logger.error(traceback.format_exc())

        if capped_count > 0:
            logger.info(f"\n✅ Capped usage counters for {capped_count} user(s)")
            logger.info(" Historical usage preserved, but capped at new limits to allow new calls")
        else:
            logger.info("\nℹ️ No usage counters to cap")

        # Note about cache clearing
        logger.info("\n🔄 Cache Information:")
        logger.info(" The subscription limits cache is per-instance and will refresh on next request.")
        logger.info(" No manual cache clearing needed - limits will be read from database on next check.")

        # Display final summary
        logger.info("\n" + "="*60)
        logger.info("BASIC PLAN UPDATE SUMMARY")
        logger.info("="*60)
        logger.info(f"\nPlan: {basic_plan.name} ({basic_plan.tier.value})")
        logger.info(f"Price: ${basic_plan.price_monthly}/mo, ${basic_plan.price_yearly}/yr")
        logger.info("\nUpdated Limits:")
        logger.info(f" AI Text Generation Calls (unified): {call_limit}")
        logger.info(
            " Legacy per-provider calls (gemini/openai/anthropic/mistral): "
            f"{basic_plan.gemini_calls_limit}/{basic_plan.openai_calls_limit}/"
            f"{basic_plan.anthropic_calls_limit}/{basic_plan.mistral_calls_limit}"
        )
        logger.info(f" LLM Tokens (gemini/openai/anthropic/mistral): {basic_plan.gemini_tokens_limit}")
        logger.info(f" Images (stability): {basic_plan.stability_calls_limit}")
        logger.info(f"\nUsers Affected: {affected_users}")
        logger.info("\n" + "="*60)
        logger.info("\n💡 Note: These limits apply immediately to all Basic plan users.")
        logger.info(" Historical usage has been preserved but capped at new limits.")
        logger.info(" Users can continue making new calls up to the new limits.")
        logger.info(" Users will hit rate limits faster for testing purposes.")

        return True

    except Exception as e:
        db.rollback()
        logger.error(f"❌ Error updating Basic plan: {e}")
        logger.error(traceback.format_exc())
        return False

    finally:
        db.close()
        # Release pooled connections; this engine exists only for this run.
        engine.dispose()
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: print a banner listing the limits being applied,
    # run the update, and translate the outcome into the process exit status
    # (0 = success, 1 = failure, cancellation, or unexpected error).
    divider = "=" * 60
    for banner_line in (
        "🚀 Starting Basic plan limits update...",
        divider,
        "This will update Basic plan limits for testing rate limits:",
        " - LLM Calls: 10 (all providers)",
        " - LLM Tokens: 2000 (all providers)",
        " - Images: 5",
        divider,
    ):
        logger.info(banner_line)

    # No confirmation prompt: the script proceeds directly so it can run
    # non-interactively. For interactive use, an
    # input("\nPress Enter to continue or Ctrl+C to cancel...") could go here.

    try:
        if update_basic_plan_limits():
            logger.info("\n✅ Script completed successfully!")
            exit_status = 0
        else:
            logger.error("\n❌ Script failed!")
            exit_status = 1
    except KeyboardInterrupt:
        # Ctrl+C while the update is running counts as a failed run.
        logger.info("\n⚠️ Script cancelled by user")
        exit_status = 1
    except Exception as e:
        logger.error(f"\n❌ Unexpected error: {e}")
        exit_status = 1

    sys.exit(exit_status)
|
||||
|
||||
@@ -295,3 +295,55 @@ class ActiveStrategyService:
|
||||
'cached_users': list(self._memory_cache.keys()),
|
||||
'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()}
|
||||
}
|
||||
|
||||
def count_active_strategies_with_tasks(self) -> int:
    """
    Return how many distinct active strategies own an active monitoring task.

    Used for intelligent scheduling: when no active strategy has tasks,
    the scheduler can check less frequently.

    Returns:
        The number of strategies whose activation status is 'active' and
        that have at least one monitoring task with status 'active'.
        0 when no database session is available; 1 when the query fails.
    """
    try:
        session = self.db_session
        if not session:
            logger.warning("Database session not available")
            return 0

        from sqlalchemy import func, and_
        from models.monitoring_models import MonitoringTask

        # Count each strategy once, however many tasks it joins to.
        distinct_strategy_count = func.count(func.distinct(EnhancedContentStrategy.id))

        activation_link = EnhancedContentStrategy.id == StrategyActivationStatus.strategy_id
        task_link = EnhancedContentStrategy.id == MonitoringTask.strategy_id
        both_active = and_(
            StrategyActivationStatus.status == 'active',
            MonitoringTask.status == 'active'
        )

        query = session.query(distinct_strategy_count)
        query = query.join(StrategyActivationStatus, activation_link)
        query = query.join(MonitoringTask, task_link)
        total = query.filter(both_active).scalar()

        return total if total else 0

    except Exception as e:
        logger.error(f"Error counting active strategies with tasks: {e}")
        # On error, assume there are active strategies (safer to check more frequently)
        return 1
|
||||
|
||||
def has_active_strategies_with_tasks(self) -> bool:
    """
    Tell whether at least one active strategy has monitoring tasks.

    Returns:
        True when count_active_strategies_with_tasks() reports a positive
        count, False otherwise.
    """
    active_count = self.count_active_strategies_with_tasks()
    return active_count > 0
|
||||
@@ -96,13 +96,13 @@ class BlogWriterService:
|
||||
self.blog_rewriter = BlogRewriter(self.task_manager)
|
||||
|
||||
# Research Methods
|
||||
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
|
||||
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
|
||||
"""Conduct comprehensive research using Google Search grounding."""
|
||||
return await self.research_service.research(request)
|
||||
return await self.research_service.research(request, user_id)
|
||||
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
|
||||
"""Conduct research with real-time progress updates."""
|
||||
return await self.research_service.research_with_progress(request, task_id)
|
||||
return await self.research_service.research_with_progress(request, task_id, user_id)
|
||||
|
||||
# Outline Methods
|
||||
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
|
||||
@@ -204,11 +204,14 @@ class BlogWriterService:
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def seo_analyze(self, request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse:
|
||||
async def seo_analyze(self, request: BlogSEOAnalyzeRequest, user_id: str = None) -> BlogSEOAnalyzeResponse:
|
||||
"""Analyze content for SEO optimization using comprehensive blog-specific analyzer."""
|
||||
try:
|
||||
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
content = request.content or ""
|
||||
target_keywords = request.keywords or []
|
||||
|
||||
@@ -231,7 +234,7 @@ class BlogWriterService:
|
||||
|
||||
# Use our comprehensive SEO analyzer
|
||||
analyzer = BlogContentSEOAnalyzer()
|
||||
analysis_results = await analyzer.analyze_blog_content(content, research_data)
|
||||
analysis_results = await analyzer.analyze_blog_content(content, research_data, user_id=user_id)
|
||||
|
||||
# Convert results to response format
|
||||
recommendations = analysis_results.get('actionable_recommendations', [])
|
||||
@@ -267,11 +270,14 @@ class BlogWriterService:
|
||||
recommendations=[f"SEO analysis failed: {str(e)}"]
|
||||
)
|
||||
|
||||
async def seo_metadata(self, request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse:
|
||||
async def seo_metadata(self, request: BlogSEOMetadataRequest, user_id: str = None) -> BlogSEOMetadataResponse:
|
||||
"""Generate comprehensive SEO metadata for content."""
|
||||
try:
|
||||
from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
# Initialize metadata generator
|
||||
metadata_generator = BlogSEOMetadataGenerator()
|
||||
|
||||
@@ -285,7 +291,8 @@ class BlogWriterService:
|
||||
blog_title=request.title or "Untitled Blog Post",
|
||||
research_data=request.research_data or {},
|
||||
outline=outline,
|
||||
seo_analysis=seo_analysis
|
||||
seo_analysis=seo_analysis,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Convert to BlogSEOMetadataResponse format
|
||||
|
||||
@@ -163,13 +163,18 @@ class BlogWriterLogger:
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log error with full context."""
|
||||
# Safely format error message to avoid KeyError on format strings in error messages
|
||||
error_str = str(error)
|
||||
# Replace any curly braces that might be in the error message to avoid format string issues
|
||||
safe_error_str = error_str.replace('{', '{{').replace('}', '}}')
|
||||
|
||||
logger.error(
|
||||
f"Error in {operation}: {str(error)}",
|
||||
f"Error in {operation}: {safe_error_str}",
|
||||
extra={
|
||||
"event_type": "error",
|
||||
"operation": operation,
|
||||
"error_type": type(error).__name__,
|
||||
"error_message": str(error),
|
||||
"error_message": error_str, # Keep original in extra, but use safe version in format string
|
||||
"context": context or {}
|
||||
},
|
||||
exc_info=True
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class CompetitorAnalyzer:
|
||||
"""Analyzes competitors and market intelligence from research content."""
|
||||
|
||||
def analyze(self, content: str) -> Dict[str, Any]:
|
||||
def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]:
|
||||
"""Parse comprehensive competitor analysis from the research content using AI."""
|
||||
competitor_prompt = f"""
|
||||
Analyze the following research content and extract competitor insights:
|
||||
@@ -57,7 +57,8 @@ class CompetitorAnalyzer:
|
||||
|
||||
competitor_analysis = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class ContentAngleGenerator:
|
||||
"""Generates strategic content angles from research content."""
|
||||
|
||||
def generate(self, content: str, topic: str, industry: str) -> List[str]:
|
||||
def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]:
|
||||
"""Parse strategic content angles from the research content using AI."""
|
||||
angles_prompt = f"""
|
||||
Analyze the following research content and create strategic content angles for: {topic} in {industry}
|
||||
@@ -65,7 +65,8 @@ class ContentAngleGenerator:
|
||||
|
||||
angles_result = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
class KeywordAnalyzer:
|
||||
"""Analyzes keywords from research content using AI-powered extraction."""
|
||||
|
||||
def analyze(self, content: str, original_keywords: List[str]) -> Dict[str, Any]:
|
||||
def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Parse comprehensive keyword analysis from the research content using AI."""
|
||||
# Use AI to extract and analyze keywords from the rich research content
|
||||
keyword_prompt = f"""
|
||||
@@ -64,7 +64,8 @@ class KeywordAnalyzer:
|
||||
|
||||
keyword_analysis = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
|
||||
@@ -4,7 +4,8 @@ Research Service - Core research functionality for AI Blog Writer.
|
||||
Handles Google Search grounding, caching, and research orchestration.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -17,6 +18,7 @@ from models.blog_models import (
|
||||
Citation,
|
||||
)
|
||||
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
@@ -34,7 +36,7 @@ class ResearchService:
|
||||
self.data_filter = ResearchDataFilter()
|
||||
|
||||
@log_function_call("research_operation")
|
||||
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
|
||||
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
|
||||
"""
|
||||
Stage 1: Research & Strategy (AI Orchestration)
|
||||
Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
|
||||
@@ -71,6 +73,10 @@ class ResearchService:
|
||||
blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True)
|
||||
return BlogResearchResponse(**cached_result)
|
||||
|
||||
# User ID validation (validation logic is now in Google Grounding provider)
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
|
||||
@@ -96,12 +102,15 @@ class ResearchService:
|
||||
"""
|
||||
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
import time
|
||||
api_start_time = time.time()
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
|
||||
@@ -126,9 +135,9 @@ class ResearchService:
|
||||
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
|
||||
|
||||
@@ -179,6 +188,9 @@ class ResearchService:
|
||||
|
||||
return filtered_response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (subscription errors) - let task manager handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Research failed: {error_message}")
|
||||
@@ -244,7 +256,7 @@ class ResearchService:
|
||||
)
|
||||
|
||||
@log_function_call("research_with_progress")
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
|
||||
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
|
||||
"""
|
||||
Research method with progress updates for real-time feedback.
|
||||
"""
|
||||
@@ -281,6 +293,11 @@ class ResearchService:
|
||||
logger.info(f"Returning cached research result for keywords: {request.keywords}")
|
||||
return BlogResearchResponse(**cached_result)
|
||||
|
||||
# User ID validation (validation logic is now in Google Grounding provider)
|
||||
if not user_id:
|
||||
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
@@ -307,11 +324,20 @@ class ResearchService:
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000
|
||||
)
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
try:
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
# Re-raise HTTPException so it can be properly handled by task manager
|
||||
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise # Re-raise HTTPException to preserve status code and error details
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources from grounding metadata
|
||||
@@ -327,9 +353,9 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
await task_manager.update_progress(task_id, "💾 Caching results for future use...")
|
||||
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
|
||||
@@ -373,6 +399,9 @@ class ResearchService:
|
||||
|
||||
return filtered_response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (subscription errors) - let task manager handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Research failed: {error_message}")
|
||||
|
||||
@@ -34,17 +34,21 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
logger.info("BlogContentSEOAnalyzer initialized")
|
||||
|
||||
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None, user_id: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Main analysis method with parallel processing
|
||||
|
||||
Args:
|
||||
blog_content: The blog content to analyze
|
||||
research_data: Research data containing keywords and other insights
|
||||
blog_title: Optional blog title
|
||||
user_id: Clerk user ID for subscription checking (required)
|
||||
|
||||
Returns:
|
||||
Comprehensive SEO analysis results
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
logger.info("Starting blog content SEO analysis")
|
||||
|
||||
@@ -58,7 +62,7 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
# Phase 2: Single AI analysis for structured insights
|
||||
logger.info("Running AI analysis")
|
||||
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results)
|
||||
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results, user_id=user_id)
|
||||
|
||||
# Phase 3: Compile and format results
|
||||
logger.info("Compiling results")
|
||||
@@ -599,8 +603,10 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
return recommendations
|
||||
|
||||
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Run single AI analysis for structured insights (provider-agnostic)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Prepare context for AI analysis
|
||||
context = {
|
||||
@@ -658,7 +664,8 @@ class BlogContentSEOAnalyzer:
|
||||
ai_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
return ai_response
|
||||
|
||||
@@ -28,7 +28,8 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
research_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate comprehensive SEO metadata using maximum 2 AI calls
|
||||
@@ -39,10 +40,13 @@ class BlogSEOMetadataGenerator:
|
||||
research_data: Research data containing keywords and insights
|
||||
outline: Outline structure with sections and headings
|
||||
seo_analysis: SEO analysis results from previous phase
|
||||
user_id: Clerk user ID for subscription checking (required)
|
||||
|
||||
Returns:
|
||||
Comprehensive metadata including all SEO elements
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
logger.info("Starting comprehensive SEO metadata generation")
|
||||
|
||||
@@ -53,13 +57,13 @@ class BlogSEOMetadataGenerator:
|
||||
# Call 1: Generate core SEO metadata (parallel with Call 2)
|
||||
logger.info("Generating core SEO metadata")
|
||||
core_metadata_task = self._generate_core_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
|
||||
)
|
||||
|
||||
# Call 2: Generate social media and structured data (parallel with Call 1)
|
||||
logger.info("Generating social media and structured data")
|
||||
social_metadata_task = self._generate_social_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
|
||||
)
|
||||
|
||||
# Wait for both calls to complete
|
||||
@@ -114,9 +118,12 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate core SEO metadata (Call 1)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Create comprehensive prompt for core metadata
|
||||
prompt = self._create_core_metadata_prompt(
|
||||
@@ -170,7 +177,8 @@ class BlogSEOMetadataGenerator:
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
@@ -215,9 +223,12 @@ class BlogSEOMetadataGenerator:
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
seo_analysis: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate social media and structured data (Call 2)"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
try:
|
||||
# Create comprehensive prompt for social metadata
|
||||
prompt = self._create_social_metadata_prompt(
|
||||
@@ -274,7 +285,8 @@ class BlogSEOMetadataGenerator:
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
|
||||
@@ -20,8 +20,11 @@ class BlogSEORecommendationApplier:
|
||||
def __init__(self):
|
||||
logger.debug("Initialized BlogSEORecommendationApplier")
|
||||
|
||||
async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def apply_recommendations(self, payload: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Apply recommendations and return updated content."""
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
title = payload.get("title", "Untitled Blog")
|
||||
sections: List[Dict[str, Any]] = payload.get("sections", [])
|
||||
@@ -88,6 +91,7 @@ class BlogSEORecommendationApplier:
|
||||
prompt,
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -56,7 +56,9 @@ class GeminiGroundedProvider:
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
urls: Optional[List[str]] = None,
|
||||
mode: str = "polished"
|
||||
mode: str = "polished",
|
||||
user_id: Optional[str] = None,
|
||||
validate_subsequent_operations: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate grounded content using native Google Search grounding.
|
||||
@@ -66,12 +68,49 @@ class GeminiGroundedProvider:
|
||||
content_type: Type of content to generate
|
||||
temperature: Creativity level (0.0-1.0)
|
||||
max_tokens: Maximum tokens in response
|
||||
urls: Optional list of URLs for URL Context tool
|
||||
mode: Content mode ("draft" or "polished")
|
||||
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
|
||||
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
|
||||
|
||||
Returns:
|
||||
Dictionary containing generated content and grounding metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Generating grounded content for {content_type} using native Google Search")
|
||||
# PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if validate_subsequent_operations:
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required when validate_subsequent_operations=True")
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_research_operations
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
|
||||
# Validate ALL research operations before making ANY API calls
|
||||
# This prevents wasteful external API calls if subsequent LLM calls would fail
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_research_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
gpt_provider=gpt_provider
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
|
||||
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
|
||||
|
||||
# Build the grounded prompt
|
||||
grounded_prompt = self._build_grounded_prompt(prompt, content_type)
|
||||
|
||||
@@ -40,7 +40,38 @@ def _get_provider(provider_name: str):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult:
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
prompt: Image generation prompt
|
||||
options: Image generation options (provider, model, width, height, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Generation] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Image Generation] ✅ Pre-flight validation passed - proceeding with image generation")
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
@@ -14,7 +15,7 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
|
||||
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str:
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
|
||||
@@ -22,9 +23,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
prompt (str): The prompt to generate text from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
user_id (str): Clerk user ID for subscription checking (required).
|
||||
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
@@ -93,6 +98,75 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
# Map provider name to APIProvider enum (define at function scope for usage tracking)
|
||||
from models.subscription_models import APIProvider
|
||||
provider_enum = None
|
||||
# Store actual provider name for logging (e.g., "huggingface", "gemini")
|
||||
actual_provider_name = None
|
||||
if gpt_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
|
||||
elif gpt_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
|
||||
if not provider_enum:
|
||||
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
|
||||
|
||||
# SUBSCRIPTION CHECK - Required and strict enforcement
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Estimate tokens from prompt (input tokens)
|
||||
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
|
||||
# This prevents underestimating total token usage
|
||||
input_tokens = int(len(prompt.split()) * 1.3)
|
||||
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
|
||||
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
|
||||
estimated_total_tokens = input_tokens + estimated_output_tokens
|
||||
|
||||
# Check limits using sync method from pricing service (strict enforcement)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider_enum,
|
||||
tokens_requested=estimated_total_tokens,
|
||||
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
raise RuntimeError(f"Subscription limit exceeded: {message}")
|
||||
|
||||
# Get current usage for limit checking only
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# No separate log here - we'll create unified log after API call and usage tracking
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
except RuntimeError:
|
||||
# Re-raise subscription limit errors
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
# STRICT: Fail on subscription check errors
|
||||
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
if system_prompt is None:
|
||||
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
|
||||
@@ -117,10 +191,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_instructions = system_prompt
|
||||
|
||||
# Generate response based on provider
|
||||
response_text = None
|
||||
actual_provider_used = gpt_provider
|
||||
try:
|
||||
if gpt_provider == "google":
|
||||
if json_struct:
|
||||
return gemini_structured_json_response(
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
@@ -130,7 +206,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return gemini_text_response(
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -140,7 +216,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
)
|
||||
elif gpt_provider == "huggingface":
|
||||
if json_struct:
|
||||
return huggingface_structured_json_response(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model,
|
||||
@@ -149,7 +225,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return huggingface_text_response(
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
@@ -160,6 +236,107 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens # Already calculated above
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
summary.total_tokens = old_total_tokens + tokens_total
|
||||
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return response_text
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
|
||||
@@ -171,9 +348,21 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
fallback_provider = fallback_providers[0] # Only try the first available
|
||||
try:
|
||||
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
|
||||
actual_provider_used = fallback_provider
|
||||
|
||||
# Update provider enum for fallback
|
||||
if fallback_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini"
|
||||
fallback_model = "gemini-2.0-flash-lite"
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "openai/gpt-oss-120b:groq"
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
return gemini_structured_json_response(
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
@@ -183,7 +372,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return gemini_text_response(
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -193,7 +382,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
)
|
||||
elif fallback_provider == "huggingface":
|
||||
if json_struct:
|
||||
return huggingface_structured_json_response(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
@@ -202,7 +391,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return huggingface_text_response(
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
temperature=temperature,
|
||||
@@ -210,6 +399,96 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful fallback call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3)
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG for fallback
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation (Fallback)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {fallback_model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
|
||||
|
||||
return response_text
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
|
||||
|
||||
|
||||
@@ -55,6 +55,14 @@ class MonitoringDataService:
|
||||
alert_threshold=task_data.get('alertThreshold', ''),
|
||||
status='active'
|
||||
)
|
||||
|
||||
# Initialize next_execution based on frequency
|
||||
from services.scheduler.utils.frequency_calculator import calculate_next_execution
|
||||
task.next_execution = calculate_next_execution(
|
||||
frequency=task.frequency,
|
||||
base_time=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(task)
|
||||
|
||||
# Save activation status
|
||||
@@ -357,3 +365,80 @@ class MonitoringDataService:
|
||||
logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}")
|
||||
self.db.rollback()
|
||||
return False
|
||||
|
||||
def get_user_execution_logs(
    self,
    user_id: int,
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    status_filter: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Get execution logs for a specific user.

    Args:
        user_id: User ID to filter execution logs
        limit: Maximum number of logs to return (a falsy value means no limit is applied)
        offset: Number of logs to skip, for pagination (a falsy value means no offset is applied)
        status_filter: Optional status filter ('success', 'failed', 'running', 'skipped')

    Returns:
        List of execution log dictionaries, most recent execution first. Each entry
        carries a "task" sub-dictionary with the related monitoring task's details,
        or None when no matching task row exists. An empty list is returned if any
        error occurs while reading the logs (the error is logged, not raised).
    """
    try:
        logger.info(f"Getting execution logs for user {user_id}")

        # Build query for execution logs filtered by user_id
        query = self.db.query(TaskExecutionLog).filter(
            TaskExecutionLog.user_id == user_id
        )

        # Apply status filter if provided
        if status_filter:
            query = query.filter(TaskExecutionLog.status == status_filter)

        # Order by execution date (most recent first)
        query = query.order_by(desc(TaskExecutionLog.execution_date))

        # Apply pagination
        if limit:
            query = query.limit(limit)
        if offset:
            query = query.offset(offset)

        logs = query.all()

        # Load every referenced task up front instead of issuing one query per
        # log row. IDs are de-duplicated (order preserved) and fetched in
        # batches so a large unpaginated result cannot exceed the database's
        # bound-parameter limit.
        task_ids = list(dict.fromkeys(
            log.task_id for log in logs if log.task_id is not None
        ))
        batch_size = 500
        tasks_by_id: Dict[Any, Any] = {}
        for start in range(0, len(task_ids), batch_size):
            batch = task_ids[start:start + batch_size]
            for task in self.db.query(MonitoringTask).filter(
                MonitoringTask.id.in_(batch)
            ).all():
                tasks_by_id[task.id] = task

        # Convert to dictionaries with task details
        logs_data = []
        for log in logs:
            task = tasks_by_id.get(log.task_id)

            task_data = None
            if task:
                task_data = {
                    "title": task.task_title,
                    "description": task.task_description,
                    "assignee": task.assignee,
                    "frequency": task.frequency,
                    "strategy_id": task.strategy_id
                }

            logs_data.append({
                "id": log.id,
                "task_id": log.task_id,
                "user_id": log.user_id,
                "execution_date": log.execution_date.isoformat() if log.execution_date else None,
                "status": log.status,
                "result_data": log.result_data,
                "error_message": log.error_message,
                "execution_time_ms": log.execution_time_ms,
                "created_at": log.created_at.isoformat() if log.created_at else None,
                "task": task_data
            })

        logger.info(f"Retrieved {len(logs_data)} execution logs for user {user_id}")
        return logs_data

    except Exception as e:
        # Best-effort read: callers get an empty list rather than an exception.
        logger.error(f"Error getting execution logs for user {user_id}: {e}")
        return []
|
||||
|
||||
59
backend/services/scheduler/__init__.py
Normal file
59
backend/services/scheduler/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Task Scheduler Package
|
||||
Modular, pluggable scheduler for ALwrity tasks.
|
||||
"""
|
||||
|
||||
from .core.scheduler import TaskScheduler
|
||||
from .core.executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .core.exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, SchedulerErrorType, SchedulerErrorSeverity,
|
||||
TaskExecutionError, DatabaseError, TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
from .executors.monitoring_task_executor import MonitoringTaskExecutor
|
||||
from .utils.task_loader import load_due_monitoring_tasks
|
||||
|
||||
# Global scheduler instance (initialized on first access)
|
||||
_scheduler_instance: TaskScheduler = None
|
||||
|
||||
|
||||
def get_scheduler() -> TaskScheduler:
    """
    Get global scheduler instance (singleton pattern).

    On first access a TaskScheduler is created (15-minute check interval,
    at most 10 concurrent executions) and a MonitoringTaskExecutor is
    registered under the 'monitoring_task' type together with
    load_due_monitoring_tasks. Later calls return the same instance.

    The module-level instance is only assigned once registration has
    succeeded, so a failure during setup propagates to the caller and the
    next call retries instead of returning a scheduler with no executor.

    Returns:
        TaskScheduler instance
    """
    global _scheduler_instance

    if _scheduler_instance is None:
        scheduler = TaskScheduler(
            check_interval_minutes=15,
            max_concurrent_executions=10
        )

        # Register monitoring task executor
        scheduler.register_executor(
            'monitoring_task',
            MonitoringTaskExecutor(),
            load_due_monitoring_tasks
        )

        # Publish only the fully configured scheduler.
        _scheduler_instance = scheduler

    return _scheduler_instance
|
||||
|
||||
|
||||
__all__ = [
|
||||
'TaskScheduler',
|
||||
'TaskExecutor',
|
||||
'TaskExecutionResult',
|
||||
'MonitoringTaskExecutor',
|
||||
'get_scheduler',
|
||||
# Exception handling
|
||||
'SchedulerExceptionHandler',
|
||||
'SchedulerException',
|
||||
'SchedulerErrorType',
|
||||
'SchedulerErrorSeverity',
|
||||
'TaskExecutionError',
|
||||
'DatabaseError',
|
||||
'TaskLoaderError',
|
||||
'SchedulerConfigError'
|
||||
]
|
||||
4
backend/services/scheduler/core/__init__.py
Normal file
4
backend/services/scheduler/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Core scheduler components.
|
||||
"""
|
||||
|
||||
395
backend/services/scheduler/core/exception_handler.py
Normal file
395
backend/services/scheduler/core/exception_handler.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Comprehensive Exception Handling and Logging for Task Scheduler
|
||||
Provides robust error handling, logging, and monitoring for the scheduler system.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from enum import Enum
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Module-level logger; SchedulerExceptionHandler stores it as self.logger.
logger = get_service_logger("scheduler_exception_handler")
|
||||
|
||||
|
||||
class SchedulerErrorType(Enum):
    """Categories used to classify errors raised inside the scheduler."""

    # Persistence layer
    DATABASE_ERROR = "database_error"

    # Task lifecycle
    TASK_EXECUTION_ERROR = "task_execution_error"
    TASK_LOADER_ERROR = "task_loader_error"

    # Scheduler setup
    SCHEDULER_CONFIG_ERROR = "scheduler_config_error"

    # Runtime conditions
    RETRY_ERROR = "retry_error"
    CONCURRENCY_ERROR = "concurrency_error"
    TIMEOUT_ERROR = "timeout_error"
|
||||
|
||||
|
||||
class SchedulerErrorSeverity(Enum):
    """How serious a scheduler error is, listed from least to most severe."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
|
||||
|
||||
|
||||
class SchedulerException(Exception):
    """Base exception for scheduler system errors.

    Carries structured metadata (error type, severity, user/task identifiers,
    free-form context and the wrapped original error) next to the message so
    handlers can log or persist it through ``to_dict()``.
    """

    def __init__(
        self,
        message: str,
        error_type: SchedulerErrorType,
        severity: SchedulerErrorSeverity = SchedulerErrorSeverity.MEDIUM,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        task_type: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable description of the failure.
            error_type: Category of the error.
            severity: Severity level; defaults to MEDIUM.
            user_id: Identifier of the affected user, if known.
            task_id: Identifier of the affected task, if known.
            task_type: Type name of the affected task, if known.
            context: Extra key/value details; stored as ``{}`` when omitted.
            original_error: The underlying exception being wrapped, if any.
        """
        self.message = message
        self.error_type = error_type
        self.severity = severity
        self.user_id = user_id
        self.task_id = task_id
        self.task_type = task_type
        self.context = context or {}
        self.original_error = original_error
        self.timestamp = datetime.utcnow()

        # Always a string (or None), whichever source the traceback came from.
        self.stack_trace = self._format_stack_trace(original_error)

        super().__init__(message)

    @staticmethod
    def _format_stack_trace(original_error: Optional[BaseException]) -> Optional[str]:
        """Render the traceback of ``original_error`` as a single string.

        The traceback attached to the wrapped error itself is preferred, so an
        unrelated exception that happens to be in flight is never reported in
        its place. If the wrapped error has no traceback, the exception
        currently being handled (``sys.exc_info()``) is used; if there is none
        either, only the exception type and message are rendered. Returns
        ``None`` when there is no original error.
        """
        if original_error is None:
            return None

        try:
            own_traceback = original_error.__traceback__
            if own_traceback is not None:
                lines = traceback.format_exception(
                    type(original_error), original_error, own_traceback
                )
            else:
                exc_type, exc_value, exc_traceback = sys.exc_info()
                if exc_traceback is not None:
                    lines = traceback.format_exception(
                        exc_type, exc_value, exc_traceback
                    )
                else:
                    lines = traceback.format_exception(
                        type(original_error), original_error, None
                    )
            return ''.join(lines)
        except Exception:
            # Formatting must never mask the error being reported.
            return str(original_error)

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for logging/storage."""
        return {
            "message": self.message,
            "error_type": self.error_type.value,
            "severity": self.severity.value,
            "user_id": self.user_id,
            "task_id": self.task_id,
            "task_type": self.task_type,
            "context": self.context,
            "timestamp": self.timestamp.isoformat() if isinstance(self.timestamp, datetime) else self.timestamp,
            "original_error": str(self.original_error) if self.original_error else None,
            "stack_trace": self.stack_trace
        }

    def __str__(self):
        """Return the message prefixed with the error type value in brackets."""
        return f"[{self.error_type.value}] {self.message}"
|
||||
|
||||
|
||||
class DatabaseError(SchedulerException):
    """Exception raised for database-related errors.

    Always reported with error type DATABASE_ERROR and CRITICAL severity.
    """

    def __init__(
        self,
        message: str,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None,
        task_type: Optional[str] = None
    ):
        """
        Args:
            message: Human-readable description of the failure.
            user_id: Identifier of the affected user, if known.
            task_id: Identifier of the affected task, if known.
            context: Extra key/value details.
            original_error: The underlying exception being wrapped, if any.
            task_type: Type name of the affected task, if known. Accepted
                last so existing positional callers keep working; the
                scheduler passes it by keyword when it cannot obtain a
                database session for a task.
        """
        super().__init__(
            message=message,
            error_type=SchedulerErrorType.DATABASE_ERROR,
            severity=SchedulerErrorSeverity.CRITICAL,
            user_id=user_id,
            task_id=task_id,
            task_type=task_type,
            context=context or {},
            original_error=original_error
        )
|
||||
|
||||
|
||||
class TaskExecutionError(SchedulerException):
    """Exception raised for task execution failures.

    Always reported with error type TASK_EXECUTION_ERROR and HIGH severity.
    ``retry_count`` and ``execution_time_ms`` are recorded inside the
    exception's context.
    """

    def __init__(
        self,
        message: str,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        task_type: Optional[str] = None,
        retry_count: int = 0,
        execution_time_ms: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable description of the failure.
            user_id: Identifier of the affected user, if known.
            task_id: Identifier of the affected task, if known.
            task_type: Type name of the affected task, if known.
            retry_count: Number of retries recorded for the task.
            execution_time_ms: Execution duration in milliseconds, if known.
            context: Extra key/value details. The dict is copied, never
                modified in place.
            original_error: The underlying exception being wrapped, if any.
        """
        # Build a new dict so the caller's context object is left untouched.
        merged_context = {
            **(context or {}),
            "retry_count": retry_count,
            "execution_time_ms": execution_time_ms,
        }

        super().__init__(
            message=message,
            error_type=SchedulerErrorType.TASK_EXECUTION_ERROR,
            severity=SchedulerErrorSeverity.HIGH,
            user_id=user_id,
            task_id=task_id,
            task_type=task_type,
            context=merged_context,
            original_error=original_error
        )
|
||||
|
||||
|
||||
class TaskLoaderError(SchedulerException):
    """Signals that due tasks could not be loaded.

    Always reported with error type TASK_LOADER_ERROR and HIGH severity.
    """

    def __init__(
        self,
        message: str,
        task_type: Optional[str] = None,
        user_id: Optional[int] = None,
        context: Dict[str, Any] = None,
        original_error: Exception = None
    ):
        """
        Args:
            message: Human-readable description of the failure.
            task_type: Type name of the tasks being loaded, if known.
            user_id: Identifier of the affected user, if known.
            context: Extra key/value details.
            original_error: The underlying exception being wrapped, if any.
        """
        details = context if context else {}
        super().__init__(
            message,
            SchedulerErrorType.TASK_LOADER_ERROR,
            SchedulerErrorSeverity.HIGH,
            user_id=user_id,
            task_type=task_type,
            context=details,
            original_error=original_error,
        )
|
||||
|
||||
|
||||
class SchedulerConfigError(SchedulerException):
    """Exception raised for scheduler configuration errors.

    Always reported with error type SCHEDULER_CONFIG_ERROR and CRITICAL
    severity.
    """

    def __init__(
        self,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        task_type: Optional[str] = None
    ):
        """
        Args:
            message: Human-readable description of the failure.
            context: Extra key/value details.
            original_error: The underlying exception being wrapped, if any.
            user_id: Identifier of the affected user, if known. Accepted
                because the scheduler passes it by keyword when no executor
                can be resolved for a task.
            task_id: Identifier of the affected task, if known.
            task_type: Type name of the affected task, if known.
        """
        super().__init__(
            message=message,
            error_type=SchedulerErrorType.SCHEDULER_CONFIG_ERROR,
            severity=SchedulerErrorSeverity.CRITICAL,
            user_id=user_id,
            task_id=task_id,
            task_type=task_type,
            context=context or {},
            original_error=original_error
        )
|
||||
|
||||
|
||||
class SchedulerExceptionHandler:
    """Comprehensive exception handler for the scheduler system.

    Normalizes arbitrary exceptions into SchedulerException, logs them at a
    level derived from the requested log level and the error severity, stores
    HIGH/CRITICAL errors that carry a task_id as failed TaskExecutionLog rows
    (only when a database session is attached) and returns a formatted error
    dictionary.
    """

    # A context key is treated as sensitive when its lower-cased name contains
    # any of these markers (e.g. "api_key", "access_token", "Password").
    _SENSITIVE_KEY_MARKERS = ("password", "token", "key", "secret", "credential")

    def __init__(self, db: Optional[Session] = None):
        """
        Args:
            db: Optional database session used to persist error alerts. When
                it is missing, errors are only logged.
        """
        self.db = db
        self.logger = logger

    def handle_exception(
        self,
        error: Union[Exception, SchedulerException],
        context: Optional[Dict[str, Any]] = None,
        log_level: str = "error"
    ) -> Dict[str, Any]:
        """Handle and log scheduler exceptions.

        Args:
            error: The exception to handle. Anything that is not already a
                SchedulerException is wrapped in one, with its type and
                severity inferred from the exception.
            context: Extra details merged into the logged error data (and used
                as the context of a newly wrapped exception).
            log_level: Requested log level ("critical", "error", "warning");
                the error's severity can also select the level.

        Returns:
            The formatted error response produced by _format_error_response.
        """

        context = context or {}

        # Convert regular exceptions to SchedulerException
        if not isinstance(error, SchedulerException):
            error = SchedulerException(
                message=str(error),
                error_type=self._classify_error(error),
                severity=self._determine_severity(error),
                context=context,
                original_error=error
            )

        # Log the error
        error_data = error.to_dict()
        error_data.update(context)

        log_message = f"Scheduler Error: {error.message}"

        if log_level == "critical" or error.severity == SchedulerErrorSeverity.CRITICAL:
            self.logger.critical(log_message, extra={"error_data": error_data})
        elif log_level == "error" or error.severity == SchedulerErrorSeverity.HIGH:
            self.logger.error(log_message, extra={"error_data": error_data})
        elif log_level == "warning" or error.severity == SchedulerErrorSeverity.MEDIUM:
            self.logger.warning(log_message, extra={"error_data": error_data})
        else:
            self.logger.info(log_message, extra={"error_data": error_data})

        # Store critical errors in database for alerting
        if error.severity in [SchedulerErrorSeverity.HIGH, SchedulerErrorSeverity.CRITICAL]:
            self._store_error_alert(error)

        # Return formatted error response
        return self._format_error_response(error)

    def _classify_error(self, error: Exception) -> SchedulerErrorType:
        """Classify an exception into a scheduler error type.

        SQLAlchemy exceptions are database errors; everything else is matched
        on keywords in the lower-cased message (and, for "sql", in the
        exception class name). The checks run in a fixed order and the first
        match wins. Unmatched errors default to TASK_EXECUTION_ERROR.
        """

        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        # Database errors
        if isinstance(error, (SQLAlchemyError, OperationalError, IntegrityError)):
            return SchedulerErrorType.DATABASE_ERROR
        if "database" in error_str or "sql" in error_type_name or "connection" in error_str:
            return SchedulerErrorType.DATABASE_ERROR

        # Timeout errors
        if "timeout" in error_str or "timed out" in error_str:
            return SchedulerErrorType.TIMEOUT_ERROR

        # Concurrency errors
        if "concurrent" in error_str or "race" in error_str or "lock" in error_str:
            return SchedulerErrorType.CONCURRENCY_ERROR

        # Task execution errors
        if "task" in error_str and "execut" in error_str:
            return SchedulerErrorType.TASK_EXECUTION_ERROR

        # Task loader errors
        if "load" in error_str and "task" in error_str:
            return SchedulerErrorType.TASK_LOADER_ERROR

        # Retry errors
        if "retry" in error_str:
            return SchedulerErrorType.RETRY_ERROR

        # Config errors ("and" binds tighter than "or"; parenthesised for clarity)
        if "config" in error_str or ("scheduler" in error_str and "init" in error_str):
            return SchedulerErrorType.SCHEDULER_CONFIG_ERROR

        # Default to task execution error for unknown errors
        return SchedulerErrorType.TASK_EXECUTION_ERROR

    def _determine_severity(self, error: Exception) -> SchedulerErrorSeverity:
        """Determine the severity of an error.

        Database and connection problems are CRITICAL; timeouts, concurrency
        problems and config-related KeyError/AttributeError are HIGH; messages
        mentioning "task" or "execution" are MEDIUM; anything else is LOW.
        """

        error_str = str(error).lower()

        # Critical errors
        if isinstance(error, (SQLAlchemyError, OperationalError, ConnectionError)):
            return SchedulerErrorSeverity.CRITICAL
        if "database" in error_str or "connection" in error_str:
            return SchedulerErrorSeverity.CRITICAL

        # High severity errors
        if "timeout" in error_str or "concurrent" in error_str:
            return SchedulerErrorSeverity.HIGH
        if isinstance(error, (KeyError, AttributeError)) and "config" in error_str:
            return SchedulerErrorSeverity.HIGH

        # Medium severity errors
        if "task" in error_str or "execution" in error_str:
            return SchedulerErrorSeverity.MEDIUM

        # Default to low
        return SchedulerErrorSeverity.LOW

    def _store_error_alert(self, error: SchedulerException):
        """Store critical errors in database for alerting.

        Does nothing without a database session. Errors that carry a task_id
        are written as a failed TaskExecutionLog row; a failure to write is
        logged and rolled back and never propagates to the caller.
        """

        if not self.db:
            return

        try:
            # Import here to avoid circular dependencies
            from models.monitoring_models import TaskExecutionLog

            # Store as failed execution log if we have task_id (even without user_id for system errors)
            if error.task_id:
                try:
                    execution_log = TaskExecutionLog(
                        task_id=error.task_id,
                        user_id=error.user_id,  # Can be None for system-level errors
                        execution_date=error.timestamp,
                        status='failed',
                        error_message=error.message,
                        result_data={
                            "error_type": error.error_type.value,
                            "severity": error.severity.value,
                            "context": error.context,
                            "stack_trace": error.stack_trace,
                            "task_type": error.task_type
                        }
                    )
                    self.db.add(execution_log)
                    self.db.commit()
                    self.logger.info(f"Stored error alert in execution log for task {error.task_id}")
                except Exception as e:
                    self.logger.error(f"Failed to store error in execution log: {e}")
                    self.db.rollback()
            # Note: For errors without task_id, we rely on structured logging only
            # Future: Could create a separate scheduler_error_logs table for system-level errors

        except Exception as e:
            self.logger.error(f"Failed to store error alert: {e}")

    @classmethod
    def _is_sensitive_key(cls, key: Any) -> bool:
        """Return True when a context key name looks like it holds a secret."""
        lowered = str(key).lower()
        return any(marker in lowered for marker in cls._SENSITIVE_KEY_MARKERS)

    def _format_error_response(self, error: SchedulerException) -> Dict[str, Any]:
        """Format error for API response or logging.

        Returns:
            A dict with ``success=False`` and an ``error`` mapping holding the
            type, message, severity, timestamp, user/task identifiers, a
            user-friendly message and, when present, the error context with
            sensitive-looking keys removed.
        """

        response = {
            "success": False,
            "error": {
                "type": error.error_type.value,
                "message": error.message,
                "severity": error.severity.value,
                "timestamp": error.timestamp.isoformat() if isinstance(error.timestamp, datetime) else str(error.timestamp),
                "user_id": error.user_id,
                "task_id": error.task_id,
                "task_type": error.task_type
            }
        }

        # Add context for debugging (non-sensitive info only).
        # Keys are matched case-insensitively by substring, so variants such as
        # "api_key" or "Access_Token" are dropped as well; over-filtering a
        # harmless key is preferred to leaking a secret.
        if error.context:
            safe_context = {
                k: v for k, v in error.context.items()
                if not self._is_sensitive_key(k)
            }
            response["error"]["context"] = safe_context

        # Add user-friendly message based on error type
        user_messages = {
            SchedulerErrorType.DATABASE_ERROR:
                "A database error occurred while processing the task. Please try again later.",
            SchedulerErrorType.TASK_EXECUTION_ERROR:
                "The task failed to execute. Please check the task configuration and try again.",
            SchedulerErrorType.TASK_LOADER_ERROR:
                "Failed to load tasks. The scheduler may be experiencing issues.",
            SchedulerErrorType.SCHEDULER_CONFIG_ERROR:
                "The scheduler configuration is invalid. Contact support.",
            SchedulerErrorType.RETRY_ERROR:
                "Task retry failed. The task will be rescheduled.",
            SchedulerErrorType.CONCURRENCY_ERROR:
                "A concurrency issue occurred. The task will be retried.",
            SchedulerErrorType.TIMEOUT_ERROR:
                "The task execution timed out. The task will be retried."
        }

        response["error"]["user_message"] = user_messages.get(
            error.error_type,
            "An error occurred while processing the task."
        )

        return response
|
||||
|
||||
75
backend/services/scheduler/core/executor_interface.py
Normal file
75
backend/services/scheduler/core/executor_interface.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Task Executor Interface
|
||||
Abstract base class for all task executors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@dataclass
class TaskExecutionResult:
    """Outcome reported by an executor after running one task."""
    success: bool
    error_message: Optional[str] = None
    result_data: Optional[Dict[str, Any]] = None
    execution_time_ms: Optional[int] = None
    retryable: bool = True
    retry_delay: int = 300  # seconds

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order.

        Values are passed through as-is (no copying of ``result_data``).
        """
        exported = (
            'success',
            'error_message',
            'result_data',
            'execution_time_ms',
            'retryable',
            'retry_delay',
        )
        return {field_name: getattr(self, field_name) for field_name in exported}
|
||||
|
||||
|
||||
class TaskExecutor(ABC):
    """
    Interface that every schedulable task type has to implement.

    Subclasses provide the execution logic and the rule for computing
    the next run time of a task.
    """

    @abstractmethod
    async def execute_task(self, task: Any, db: Session) -> TaskExecutionResult:
        """
        Run a single task.

        Args:
            task: Task instance from database
            db: Database session

        Returns:
            TaskExecutionResult with execution details
        """
        ...

    @abstractmethod
    def calculate_next_execution(
        self,
        task: Any,
        frequency: str,
        last_execution: Optional[datetime] = None
    ) -> datetime:
        """
        Work out when the task should run next, based on its frequency.

        Args:
            task: Task instance
            frequency: Task frequency (e.g., 'Daily', 'Weekly')
            last_execution: Last execution datetime

        Returns:
            Next execution datetime
        """
        ...
|
||||
|
||||
628
backend/services/scheduler/core/scheduler.py
Normal file
628
backend/services/scheduler/core/scheduler.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Core Task Scheduler Service
|
||||
Pluggable task scheduler that can work with any task model.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable
|
||||
from datetime import datetime
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .task_registry import TaskRegistry
|
||||
from .exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
|
||||
TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
from services.database import get_db_session
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Module-level logger used by the TaskScheduler methods in this file.
logger = get_service_logger("task_scheduler")
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""
|
||||
Pluggable task scheduler that can work with any task model.
|
||||
|
||||
Features:
|
||||
- Async task execution
|
||||
- Plugin-based executor system
|
||||
- Database-backed task persistence
|
||||
- Configurable check intervals
|
||||
- Automatic retry logic
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    check_interval_minutes: int = 15,
    max_concurrent_executions: int = 10,
    enable_retries: bool = True,
    max_retries: int = 3
):
    """
    Set up scheduler state; nothing is scheduled until start() is called.

    Args:
        check_interval_minutes: How often to check for due tasks
        max_concurrent_executions: Maximum concurrent task executions
        enable_retries: Whether to retry failed tasks
        max_retries: Maximum retry attempts
    """
    # Caller-supplied configuration
    self.check_interval_minutes = check_interval_minutes
    self.max_concurrent_executions = max_concurrent_executions
    self.enable_retries = enable_retries
    self.max_retries = max_retries

    # Underlying APScheduler instance
    apscheduler_job_defaults = {
        'coalesce': True,
        'max_instances': 1,
        'misfire_grace_time': 300,  # 5 minutes grace period
    }
    self.scheduler = AsyncIOScheduler(
        timezone='UTC',
        job_defaults=apscheduler_job_defaults,
    )

    # Executors and task loaders, keyed by task type
    self.registry = TaskRegistry()

    # In-flight executions, keyed by "<task_type>_<task id>"
    self.active_executions: Dict[str, asyncio.Task] = {}

    # Central error handling
    self.exception_handler = SchedulerExceptionHandler()

    # Bounds for the adaptive check interval
    self.min_check_interval_minutes = 15  # used while active strategies exist
    self.max_check_interval_minutes = 60  # used when there are none
    self.current_check_interval_minutes = check_interval_minutes

    # Counters and bookkeeping
    self.stats = {
        'total_checks': 0,
        'tasks_found': 0,
        'tasks_executed': 0,
        'tasks_failed': 0,
        'tasks_skipped': 0,
        'last_check': None,
        'per_user_stats': {},  # metrics per user, for user isolation
        'active_strategies_count': 0,  # active strategies that have tasks
        'last_interval_adjustment': None,  # when the interval last changed
    }

    self._running = False
|
||||
|
||||
def _get_trigger_for_interval(self, interval_minutes: int):
    """
    Get the appropriate trigger for the given interval.

    Intervals below 60 minutes that divide an hour evenly (e.g. 15, 30) use
    a CronTrigger with a ``*/N`` minute expression. Every other interval,
    including anything >= 60 minutes, uses an IntervalTrigger: a ``*/N``
    minute expression starts over at minute 0 of each hour, so a value such
    as 45 would fire at :00 and :45 (gaps of 45 and then 15 minutes) instead
    of every 45 minutes.

    Args:
        interval_minutes: Interval in minutes (must be at least 1)

    Returns:
        Appropriate APScheduler trigger

    Raises:
        ValueError: If interval_minutes is less than 1.
    """
    if interval_minutes < 1:
        raise ValueError(
            f"interval_minutes must be at least 1, got {interval_minutes}"
        )

    if interval_minutes < 60 and 60 % interval_minutes == 0:
        # Evenly spaced within the hour (valid cron minute range: 0-59)
        return CronTrigger(minute=f'*/{interval_minutes}')

    # >= 60 minutes, or a value a cron step cannot space evenly
    return IntervalTrigger(minutes=interval_minutes)
|
||||
|
||||
def register_executor(
    self,
    task_type: str,
    executor: TaskExecutor,
    task_loader: Callable[[Session], List[Any]]
):
    """
    Attach an executor and its task loader to a task type.

    Args:
        task_type: Unique identifier for task type (e.g., 'monitoring_task')
        executor: TaskExecutor instance that handles execution
        task_loader: Function that loads due tasks from database
    """
    registry = self.registry
    registry.register(task_type, executor, task_loader)
    logger.info(f"Registered executor for task type: {task_type}")
|
||||
|
||||
async def start(self):
    """Start the scheduler, picking the initial check interval from current activity.

    Calling it while the scheduler is already running only logs a warning.
    Any failure during startup is logged and re-raised.
    """
    if self._running:
        logger.warning("Scheduler is already running")
        return

    try:
        # Initial interval depends on whether active strategies exist
        interval = await self._determine_optimal_interval()
        self.current_check_interval_minutes = interval

        # Periodic job that looks for due tasks
        check_trigger = self._get_trigger_for_interval(interval)
        self.scheduler.add_job(
            self._check_and_execute_due_tasks,
            trigger=check_trigger,
            id='check_due_tasks',
            replace_existing=True,
        )

        self.scheduler.start()
        self._running = True

        registered = self.registry.get_registered_types()
        logger.info(
            f"Task scheduler started | "
            f"check_interval={interval}min | "
            f"registered_types={registered}"
        )

    except Exception as startup_error:
        logger.error(f"Failed to start scheduler: {startup_error}")
        raise
|
||||
|
||||
async def stop(self):
    """Shut the scheduler down: cancel in-flight executions, then stop APScheduler.

    Does nothing when the scheduler is not running. Any failure is logged
    and re-raised.
    """
    if not self._running:
        return

    try:
        in_flight = self.active_executions

        # Request cancellation of everything still running
        for execution in in_flight.values():
            execution.cancel()

        # Give the cancelled executions up to 30 seconds to wind down
        if in_flight:
            await asyncio.wait(in_flight.values(), timeout=30)

        self.scheduler.shutdown(wait=True)
        self._running = False

        logger.info("Task scheduler stopped gracefully")

    except Exception as stop_error:
        logger.error(f"Error stopping scheduler: {stop_error}")
        raise
|
||||
|
||||
async def _check_and_execute_due_tasks(self):
    """
    Periodic scheduler job: find due tasks and run them.

    Each run updates the check statistics, re-evaluates the check interval
    against the currently active strategies and then processes every
    registered task type. Failures are reported through the exception
    handler as a DatabaseError; the session is always closed.
    """
    stats = self.stats
    stats['total_checks'] += 1
    stats['last_check'] = datetime.utcnow().isoformat()

    logger.debug("Checking for due tasks...")

    db = None
    try:
        db = get_db_session()
        if db is None:
            logger.error("Failed to get database session")
            return

        # Re-evaluate the interval first, then handle each task type in turn
        await self._adjust_check_interval_if_needed(db)

        for registered_type in self.registry.get_registered_types():
            await self._process_task_type(registered_type, db)

    except Exception as e:
        self.exception_handler.handle_exception(
            DatabaseError(
                message=f"Error checking for due tasks: {str(e)}",
                original_error=e,
            )
        )
    finally:
        if db:
            db.close()
|
||||
|
||||
async def _determine_optimal_interval(self) -> int:
    """
    Pick the check interval from the number of active strategies with tasks.

    Returns:
        The minimum interval when at least one active strategy has tasks,
        the maximum interval when none does, and the minimum interval when
        no session is available or the lookup fails.
    """
    # Fallback: the shorter interval is the safer choice on error
    chosen_interval = self.min_check_interval_minutes

    session = None
    try:
        session = get_db_session()
        if session:
            from services.active_strategy_service import ActiveStrategyService

            strategy_service = ActiveStrategyService(db_session=session)
            active_count = strategy_service.count_active_strategies_with_tasks()
            self.stats['active_strategies_count'] = active_count

            if active_count > 0:
                logger.info(f"Found {active_count} active strategies with tasks - using {self.min_check_interval_minutes}min interval")
                chosen_interval = self.min_check_interval_minutes
            else:
                logger.info(f"No active strategies with tasks - using {self.max_check_interval_minutes}min interval")
                chosen_interval = self.max_check_interval_minutes
    except Exception as e:
        logger.warning(f"Error determining optimal interval: {e}, using default {self.min_check_interval_minutes}min")
    finally:
        if session:
            session.close()

    return chosen_interval
|
||||
|
||||
async def _adjust_check_interval_if_needed(self, db: Session):
    """
    Intelligently adjust check interval based on active strategies.

    If there are active strategies with tasks, check more frequently.
    If there are no active strategies, check less frequently.

    The 'check_due_tasks' job is rescheduled only when the interval actually
    changes. Errors are logged as warnings and never propagate.

    Args:
        db: Database session
    """
    try:
        from services.active_strategy_service import ActiveStrategyService

        active_strategy_service = ActiveStrategyService(db_session=db)
        active_count = active_strategy_service.count_active_strategies_with_tasks()
        self.stats['active_strategies_count'] = active_count

        # Determine optimal interval
        if active_count > 0:
            optimal_interval = self.min_check_interval_minutes
        else:
            optimal_interval = self.max_check_interval_minutes

        # Only reschedule if interval needs to change
        if optimal_interval != self.current_check_interval_minutes:
            logger.info(
                f"Adjusting scheduler interval: {self.current_check_interval_minutes}min → {optimal_interval}min | "
                f"active_strategies={active_count}"
            )

            # Reschedule the job with the new interval. reschedule_job is
            # used instead of modify_job because it also recalculates the
            # job's next run time from the new trigger; modify_job would
            # leave the run time computed by the old trigger in place, so
            # the new interval would only apply after that run.
            self.scheduler.reschedule_job(
                'check_due_tasks',
                trigger=self._get_trigger_for_interval(optimal_interval)
            )

            self.current_check_interval_minutes = optimal_interval
            self.stats['last_interval_adjustment'] = datetime.utcnow().isoformat()

            logger.info(f"Scheduler interval adjusted to {optimal_interval}min")

    except Exception as e:
        logger.warning(f"Error adjusting check interval: {e}")
|
||||
|
||||
async def trigger_interval_adjustment(self):
    """
    Trigger immediate interval adjustment check.

    This should be called when a strategy is activated or deactivated
    to immediately adjust the scheduler interval based on current active strategies.

    Does nothing when the scheduler is not running. The database session
    opened here is always closed again, and no error propagates to the
    caller; problems are logged as warnings.
    """
    if not self._running:
        logger.debug("Scheduler not running, skipping interval adjustment")
        return

    db = None
    try:
        db = get_db_session()
        if db:
            await self._adjust_check_interval_if_needed(db)
        else:
            logger.warning("Could not get database session for interval adjustment")
    except Exception as e:
        logger.warning(f"Error triggering interval adjustment: {e}")
    finally:
        # Release the session this method opened; it was previously leaked.
        if db:
            try:
                db.close()
            except Exception as close_error:
                logger.warning(
                    f"Error closing database session after interval adjustment: {close_error}"
                )
|
||||
|
||||
async def _process_task_type(self, task_type: str, db: Session):
    """Process due tasks for a specific task type.

    Resolves the task loader registered for ``task_type``, loads the due
    tasks and starts one asyncio task per due task until the concurrency
    limit is reached. Tasks that cannot be started because of that limit
    are counted in ``stats['tasks_skipped']``. The method then waits up to
    300 seconds for the started executions. All failures are reported
    through the exception handler as TaskLoaderError and never propagate.

    Args:
        task_type: Registered task type to process
        db: Database session handed to the task loader
    """
    try:
        # Get task loader for this type
        try:
            task_loader = self.registry.get_task_loader(task_type)
        except Exception as e:
            error = TaskLoaderError(
                message=f"Failed to get task loader for type {task_type}: {str(e)}",
                task_type=task_type,
                original_error=e
            )
            self.exception_handler.handle_exception(error)
            return

        # Load due tasks (with error handling)
        try:
            due_tasks = task_loader(db)
        except Exception as e:
            error = TaskLoaderError(
                message=f"Failed to load due tasks for type {task_type}: {str(e)}",
                task_type=task_type,
                original_error=e
            )
            self.exception_handler.handle_exception(error)
            return

        if not due_tasks:
            return

        self.stats['tasks_found'] += len(due_tasks)
        logger.info(f"Found {len(due_tasks)} due tasks for type: {task_type}")

        # Execute tasks (with concurrency limit)
        execution_tasks = []
        for task in due_tasks:
            if len(self.active_executions) >= self.max_concurrent_executions:
                # Everything not started yet in this batch is skipped;
                # record it so the statistic reflects dropped work.
                skipped_count = len(due_tasks) - len(execution_tasks)
                self.stats['tasks_skipped'] += skipped_count
                logger.warning(
                    f"Max concurrent executions reached ({self.max_concurrent_executions}), "
                    f"skipping {skipped_count} tasks"
                )
                break

            # Execute task asynchronously
            # Note: Each task gets its own database session to prevent concurrent access issues
            execution_task = asyncio.create_task(
                self._execute_task_async(task_type, task)
            )

            # Falls back to the object's id() when the task has no 'id' attribute
            task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
            self.active_executions[task_id] = execution_task

            execution_tasks.append(execution_task)

        # Wait for this batch; the 300 second timeout covers the batch as a whole
        if execution_tasks:
            await asyncio.wait(execution_tasks, timeout=300)

    except Exception as e:
        error = TaskLoaderError(
            message=f"Error processing task type {task_type}: {str(e)}",
            task_type=task_type,
            original_error=e
        )
        self.exception_handler.handle_exception(error)
|
||||
|
||||
async def _execute_task_async(self, task_type: str, task: Any):
    """
    Execute a single task asynchronously with user isolation.

    A dedicated database session is opened for the task (SQLAlchemy sessions
    are not safe to share between concurrently running tasks), the detached
    task is merged into that session, and the executor registered for
    ``task_type`` is awaited. Exceptions raised along the way are passed to
    ``self.exception_handler`` and counted as failures in ``self.stats`` and
    in the per-user statistics.

    The owning user is resolved from ``task.strategy`` when that relationship
    is populated, otherwise by looking the strategy up through
    ``task.strategy_id`` once the session exists.

    Args:
        task_type: Type of task; used to look up the executor in the registry.
        task: Task instance from database (detached from original session).
    """
    task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
    db = None
    user_id = None

    try:
        # Extract user context if available (for user isolation tracking).
        # Only the already-loaded relationship can be used here; the
        # strategy_id fallback needs the session created further down.
        try:
            if hasattr(task, 'strategy') and task.strategy:
                user_id = getattr(task.strategy, 'user_id', None)
        except Exception as e:
            logger.debug(f"Could not extract user_id before execution for task {task_id}: {e}")

        logger.info(f"Executing task: {task_id} | user_id: {user_id}")

        # Create a new database session for this async task
        # SQLAlchemy sessions are not async-safe and cannot be shared across concurrent tasks
        db = get_db_session()
        if db is None:
            error = DatabaseError(
                message=f"Failed to get database session for task {task_id}",
                user_id=user_id,
                task_id=getattr(task, 'id', None),
                task_type=task_type
            )
            self.exception_handler.handle_exception(error, log_level="error")
            self.stats['tasks_failed'] += 1
            self._update_user_stats(user_id, success=False)
            return

        # Set database session for exception handler
        # NOTE(review): the handler is an attribute of this scheduler, so tasks
        # running concurrently overwrite each other's session here - confirm
        # that is acceptable.
        self.exception_handler.db = db

        # Merge the detached task object into this session
        # The task object was loaded in a different session and is now detached
        from sqlalchemy.orm import object_session
        if object_session(task) is None:
            # Task is detached, need to merge it into this session
            task = db.merge(task)

        # Extract user_id after merge if not already available.
        # The lookup is no longer gated on the task having a `strategy`
        # attribute, so tasks that only carry `strategy_id` are resolved too,
        # and no query is issued when there is no strategy id at all.
        if user_id is None:
            try:
                linked_strategy = getattr(task, 'strategy', None)
                if linked_strategy:
                    user_id = getattr(linked_strategy, 'user_id', None)
                elif getattr(task, 'strategy_id', None):
                    # Query strategy if relationship not loaded
                    from models.enhanced_strategy_models import EnhancedContentStrategy
                    strategy = db.query(EnhancedContentStrategy).filter(
                        EnhancedContentStrategy.id == task.strategy_id
                    ).first()
                    if strategy:
                        user_id = strategy.user_id
            except Exception as e:
                logger.debug(f"Could not extract user_id after merge for task {task_id}: {e}")

        # Get executor for this task type
        try:
            executor = self.registry.get_executor(task_type)
        except Exception as e:
            from .exception_handler import SchedulerConfigError
            error = SchedulerConfigError(
                message=f"Failed to get executor for task type {task_type}: {str(e)}",
                user_id=user_id,
                context={
                    "task_id": getattr(task, 'id', None),
                    "task_type": task_type
                },
                original_error=e
            )
            self.exception_handler.handle_exception(error)
            self.stats['tasks_failed'] += 1
            self._update_user_stats(user_id, success=False)
            return

        # Execute task with its own session (with error handling)
        try:
            result = await executor.execute_task(task, db)

            # Handle result and update statistics
            if result.success:
                self.stats['tasks_executed'] += 1
                self._update_user_stats(user_id, success=True)
                logger.info(f"Task executed successfully: {task_id} | user_id: {user_id}")
            else:
                self.stats['tasks_failed'] += 1
                self._update_user_stats(user_id, success=False)

                # Create structured error for failed execution
                error = TaskExecutionError(
                    message=result.error_message or "Task execution failed",
                    user_id=user_id,
                    task_id=getattr(task, 'id', None),
                    task_type=task_type,
                    execution_time_ms=result.execution_time_ms,
                    context={"result_data": result.result_data}
                )
                self.exception_handler.handle_exception(error, log_level="warning")

                # Retry logic if enabled
                if self.enable_retries and result.retryable:
                    await self._schedule_retry(task, result.retry_delay)

        except SchedulerException:
            # Let the outer handler below deal with scheduler exceptions
            raise
        except Exception as e:
            # Wrap unexpected exceptions
            error = TaskExecutionError(
                message=f"Unexpected error during task execution: {str(e)}",
                user_id=user_id,
                task_id=getattr(task, 'id', None),
                task_type=task_type,
                original_error=e
            )
            self.exception_handler.handle_exception(error)
            self.stats['tasks_failed'] += 1
            self._update_user_stats(user_id, success=False)

    except SchedulerException as e:
        # Handle scheduler exceptions
        self.exception_handler.handle_exception(e)
        self.stats['tasks_failed'] += 1
        self._update_user_stats(user_id, success=False)
    except Exception as e:
        # Handle any other unexpected errors
        error = TaskExecutionError(
            message=f"Unexpected error in task execution wrapper: {str(e)}",
            user_id=user_id,
            task_id=getattr(task, 'id', None),
            task_type=task_type,
            original_error=e
        )
        self.exception_handler.handle_exception(error)
        self.stats['tasks_failed'] += 1
        self._update_user_stats(user_id, success=False)
    finally:
        # Clean up database session
        if db:
            try:
                db.close()
            except Exception as e:
                logger.error(f"Error closing database session for task {task_id}: {e}")

        # Remove from active executions (no-op when the id was never tracked)
        self.active_executions.pop(task_id, None)
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[int], success: bool):
    """
    Record one task outcome in the per-user statistics.

    The entry for ``user_id`` under ``self.stats['per_user_stats']`` is
    created on first use, the matching counter is incremented, and the
    success rate is recomputed as a percentage of all recorded outcomes.

    Args:
        user_id: User ID; when None nothing is recorded.
        success: True to count an executed task, False to count a failure.
    """
    if user_id is None:
        return

    entry = self.stats['per_user_stats'].setdefault(
        user_id,
        {'executed': 0, 'failed': 0, 'success_rate': 0.0},
    )

    counter_name = 'executed' if success else 'failed'
    entry[counter_name] += 1

    # Success rate is the share of executed tasks among all outcomes, in percent
    attempts = entry['executed'] + entry['failed']
    if attempts > 0:
        entry['success_rate'] = (entry['executed'] / attempts) * 100.0
|
||||
|
||||
async def _schedule_retry(self, task: Any, delay_seconds: int):
    """
    Placeholder retry hook for a failed task.

    Nothing is rescheduled yet: the requested delay is only written to the
    debug log. A real implementation would update the task's next_execution.

    Args:
        task: The task that failed (currently unused).
        delay_seconds: Delay before the retry, in seconds.
    """
    retry_notice = f"Scheduling retry for task in {delay_seconds}s"
    logger.debug(retry_notice)
|
||||
|
||||
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
    """
    Build a snapshot of scheduler statistics.

    The global counters from ``self.stats`` (everything except the per-user
    table) are combined with the current scheduler state. When ``user_id``
    is given the result carries that user's counters under ``'user_stats'``
    (zeros if the user has no record); otherwise the whole per-user table is
    included under ``'per_user_stats'``.

    Args:
        user_id: Optional user ID to restrict the per-user part to one user.

    Returns:
        Dictionary with scheduler statistics.
    """
    snapshot = {
        name: value
        for name, value in self.stats.items()
        if name != 'per_user_stats'
    }
    snapshot['active_executions'] = len(self.active_executions)
    snapshot['registered_types'] = self.registry.get_registered_types()
    snapshot['running'] = self._running
    snapshot['check_interval_minutes'] = self.current_check_interval_minutes
    snapshot['min_check_interval_minutes'] = self.min_check_interval_minutes
    snapshot['max_check_interval_minutes'] = self.max_check_interval_minutes
    snapshot['intelligent_scheduling'] = True

    per_user = self.stats['per_user_stats']
    if user_id is None:
        # No filter: expose every user's counters (for admin/debugging)
        snapshot['per_user_stats'] = per_user
        return snapshot

    # Filtered view: the stored entry, or an all-zero record for unknown users
    snapshot['user_stats'] = per_user.get(
        user_id,
        {'executed': 0, 'failed': 0, 'success_rate': 0.0},
    )
    return snapshot
|
||||
|
||||
def is_running(self) -> bool:
    """Return the scheduler's running flag (``self._running``)."""
    running_flag = self._running
    return running_flag
|
||||
|
||||
59
backend/services/scheduler/core/task_registry.py
Normal file
59
backend/services/scheduler/core/task_registry.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Task Registry
|
||||
Manages registration of task executors and loaders.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Callable, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .executor_interface import TaskExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskRegistry:
    """Keeps, per task type, the executor that runs tasks and the loader that finds due tasks."""

    def __init__(self):
        # Both maps are keyed by the task type string passed to register()
        self.executors: Dict[str, TaskExecutor] = {}
        self.task_loaders: Dict[str, Callable[[Session], List[Any]]] = {}

    def register(
        self,
        task_type: str,
        executor: TaskExecutor,
        task_loader: Callable[[Session], List[Any]]
    ):
        """
        Store the executor and loader for a task type.

        Registering a type that already has an executor replaces both
        entries after logging a warning.

        Args:
            task_type: Unique identifier for task type
            executor: TaskExecutor instance
            task_loader: Function that loads due tasks from database
        """
        replacing_existing = task_type in self.executors
        if replacing_existing:
            logger.warning(f"Overwriting existing executor for task type: {task_type}")

        self.task_loaders[task_type] = task_loader
        self.executors[task_type] = executor

        logger.info(f"Registered task type: {task_type}")

    def get_executor(self, task_type: str) -> TaskExecutor:
        """Return the executor registered for ``task_type``; raise ValueError if there is none."""
        if task_type in self.executors:
            return self.executors[task_type]
        raise ValueError(f"No executor registered for task type: {task_type}")

    def get_task_loader(self, task_type: str) -> Callable[[Session], List[Any]]:
        """Return the loader registered for ``task_type``; raise ValueError if there is none."""
        if task_type in self.task_loaders:
            return self.task_loaders[task_type]
        raise ValueError(f"No task loader registered for task type: {task_type}")

    def get_registered_types(self) -> List[str]:
        """Return the task types that currently have an executor, in registration order."""
        return [registered_type for registered_type in self.executors]
|
||||
|
||||
4
backend/services/scheduler/executors/__init__.py
Normal file
4
backend/services/scheduler/executors/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Task executor implementations.
|
||||
"""
|
||||
|
||||
266
backend/services/scheduler/executors/monitoring_task_executor.py
Normal file
266
backend/services/scheduler/executors/monitoring_task_executor.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Monitoring Task Executor
|
||||
Handles execution of content strategy monitoring tasks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..core.executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from ..core.exception_handler import TaskExecutionError, DatabaseError, SchedulerExceptionHandler
|
||||
from ..utils.frequency_calculator import calculate_next_execution
|
||||
from models.monitoring_models import MonitoringTask, TaskExecutionLog
|
||||
from models.enhanced_strategy_models import EnhancedContentStrategy
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("monitoring_task_executor")
|
||||
|
||||
|
||||
class MonitoringTaskExecutor(TaskExecutor):
    """
    Executor for content strategy monitoring tasks.

    Handles:
    - ALwrity tasks (automated execution)
    - Human tasks (notifications/queuing)
    """

    def __init__(self):
        self.logger = logger
        self.exception_handler = SchedulerExceptionHandler()

    async def execute_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute a monitoring task with user isolation.

        A TaskExecutionLog row is written for the run, the task is dispatched
        by assignee ('ALwrity' runs the automated path, anything else is
        queued as a human task), and the task's last/next execution times and
        status are updated before committing.

        If anything raises, the error is passed to the exception handler, the
        run is recorded as failed (updating the log row created for this run
        when one exists, instead of adding a second row), and a retryable
        failure result is returned rather than propagating the exception.

        Args:
            task: MonitoringTask instance (with strategy relationship loaded)
            db: Database session

        Returns:
            TaskExecutionResult
        """
        start_time = time.time()

        # Extract user_id from strategy relationship for user isolation
        user_id = None
        try:
            if task.strategy and hasattr(task.strategy, 'user_id'):
                user_id = task.strategy.user_id
            elif task.strategy_id:
                # Fallback: query strategy if relationship not loaded
                strategy = db.query(EnhancedContentStrategy).filter(
                    EnhancedContentStrategy.id == task.strategy_id
                ).first()
                if strategy:
                    user_id = strategy.user_id
        except Exception as e:
            self.logger.warning(f"Could not extract user_id for task {task.id}: {e}")

        # Kept outside the try so the failure path can reuse the row created
        # for this run instead of leaving it behind with status 'running'.
        execution_log = None

        try:
            self.logger.info(
                f"Executing monitoring task: {task.id} | "
                f"user_id: {user_id} | "
                f"assignee: {task.assignee} | "
                f"frequency: {task.frequency}"
            )

            # Create execution log with user_id for user isolation tracking
            execution_log = TaskExecutionLog(
                task_id=task.id,
                user_id=user_id,
                execution_date=datetime.utcnow(),
                status='running'
            )
            db.add(execution_log)
            db.flush()

            # Execute based on assignee
            if task.assignee == 'ALwrity':
                result = await self._execute_alwrity_task(task, db)
            else:
                result = await self._execute_human_task(task, db)

            # Update execution log
            execution_time_ms = int((time.time() - start_time) * 1000)
            execution_log.status = 'success' if result.success else 'failed'
            execution_log.result_data = result.result_data
            execution_log.error_message = result.error_message
            execution_log.execution_time_ms = execution_time_ms

            # Update task
            task.last_executed = datetime.utcnow()
            task.next_execution = self.calculate_next_execution(
                task,
                task.frequency,
                task.last_executed
            )

            # NOTE(review): the task loader in this commit selects tasks with
            # status == 'active'; confirm tasks marked 'completed'/'failed'
            # here are meant to stop recurring.
            task.status = 'completed' if result.success else 'failed'

            db.commit()

            return result

        except Exception as e:
            execution_time_ms = int((time.time() - start_time) * 1000)

            # Set database session for exception handler
            self.exception_handler.db = db

            # Create structured error
            error = TaskExecutionError(
                message=f"Error executing monitoring task {task.id}: {str(e)}",
                user_id=user_id,
                task_id=task.id,
                task_type="monitoring_task",
                execution_time_ms=execution_time_ms,
                context={
                    "assignee": task.assignee,
                    "frequency": task.frequency,
                    "component": task.component_name
                },
                original_error=e
            )

            # Handle exception with structured logging
            self.exception_handler.handle_exception(error)

            # Update execution log with error (include user_id for isolation)
            try:
                failure_details = {
                    "error_type": error.error_type.value,
                    "severity": error.severity.value,
                    "context": error.context
                }

                if execution_log is None:
                    # The run failed before its log row was created
                    execution_log = TaskExecutionLog(
                        task_id=task.id,
                        user_id=user_id,
                        execution_date=datetime.utcnow(),
                        status='failed',
                        error_message=str(e),
                        execution_time_ms=execution_time_ms,
                        result_data=failure_details
                    )
                else:
                    # Turn this run's provisional row into the failure record
                    # so it is not committed as 'running' next to a duplicate.
                    execution_log.status = 'failed'
                    execution_log.error_message = str(e)
                    execution_log.execution_time_ms = execution_time_ms
                    execution_log.result_data = failure_details

                # Harmless when the row is already in the session
                db.add(execution_log)

                task.status = 'failed'
                task.last_executed = datetime.utcnow()

                db.commit()
            except Exception as commit_error:
                db_error = DatabaseError(
                    message=f"Error saving execution log: {str(commit_error)}",
                    user_id=user_id,
                    task_id=task.id,
                    original_error=commit_error
                )
                self.exception_handler.handle_exception(db_error)
                db.rollback()

            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                execution_time_ms=execution_time_ms,
                retryable=True,
                retry_delay=300
            )

    async def _execute_alwrity_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute an ALwrity (automated) monitoring task.

        This is where the actual monitoring logic would go.
        For now this is a placeholder: it logs the task title and returns a
        successful result with a fixed metric value of 0.

        Args:
            task: The monitoring task being executed.
            db: Database session (not used by the placeholder).

        Returns:
            A successful TaskExecutionResult, or a retryable failure result
            if building it raises.
        """
        try:
            self.logger.info(f"Executing ALwrity task: {task.task_title}")

            # TODO: Implement actual monitoring logic based on:
            # - task.metric
            # - task.measurement_method
            # - task.success_criteria
            # - task.alert_threshold

            # Placeholder: Simulate task execution
            result_data = {
                'metric_value': 0,
                'status': 'measured',
                'message': f"Task {task.task_title} executed successfully",
                'timestamp': datetime.utcnow().isoformat()
            }

            return TaskExecutionResult(
                success=True,
                result_data=result_data
            )

        except Exception as e:
            self.logger.error(f"Error in ALwrity task execution: {e}")
            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                retryable=True
            )

    async def _execute_human_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute a Human monitoring task (notification/queuing).

        For human tasks, we don't execute the task directly,
        but rather queue it for human review or send notifications.
        Currently a placeholder: it logs the task title and returns a
        successful result with status 'queued'.

        Args:
            task: The monitoring task being queued.
            db: Database session (not used by the placeholder).

        Returns:
            A successful TaskExecutionResult, or a retryable failure result
            if building it raises.
        """
        try:
            self.logger.info(f"Queuing human task: {task.task_title}")

            # TODO: Implement notification/queuing system:
            # - Send email notification
            # - Add to user's task queue
            # - Create in-app notification

            result_data = {
                'status': 'queued',
                'message': f"Task {task.task_title} queued for human review",
                'timestamp': datetime.utcnow().isoformat()
            }

            return TaskExecutionResult(
                success=True,
                result_data=result_data
            )

        except Exception as e:
            self.logger.error(f"Error queuing human task: {e}")
            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                retryable=True
            )

    def calculate_next_execution(
        self,
        task: MonitoringTask,
        frequency: str,
        last_execution: Optional[datetime] = None
    ) -> datetime:
        """
        Calculate next execution time based on frequency.

        Delegates to the module-level ``calculate_next_execution`` helper
        imported from the frequency calculator utility.

        Args:
            task: MonitoringTask instance (not used in the calculation)
            frequency: Frequency string (Daily, Weekly, Monthly, Quarterly)
            last_execution: Last execution datetime (defaults to now)

        Returns:
            Next execution datetime
        """
        # Inside the method body this name resolves to the imported
        # module-level function, not to this method.
        return calculate_next_execution(
            frequency=frequency,
            base_time=last_execution or datetime.utcnow()
        )
|
||||
|
||||
4
backend/services/scheduler/utils/__init__.py
Normal file
4
backend/services/scheduler/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Scheduler utilities.
|
||||
"""
|
||||
|
||||
33
backend/services/scheduler/utils/frequency_calculator.py
Normal file
33
backend/services/scheduler/utils/frequency_calculator.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Frequency Calculator Utility
|
||||
Calculates next execution time based on frequency string.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def calculate_next_execution(frequency: str, base_time: Optional[datetime] = None) -> datetime:
    """
    Calculate next execution time based on frequency.

    The frequency name is matched ignoring case and surrounding whitespace,
    so 'weekly' and ' Weekly ' behave like 'Weekly'. Months and quarters are
    fixed-length approximations (30 and 90 days), not calendar months.
    Unrecognised or non-string frequencies fall back to one day.

    Args:
        frequency: Frequency string ('Daily', 'Weekly', 'Monthly', 'Quarterly')
        base_time: Base time to calculate from (defaults to now if None)

    Returns:
        Next execution datetime
    """
    if base_time is None:
        base_time = datetime.utcnow()

    default_interval = timedelta(days=1)
    intervals = {
        'daily': timedelta(days=1),
        'weekly': timedelta(weeks=1),
        'monthly': timedelta(days=30),
        'quarterly': timedelta(days=90),
    }

    # Normalise before the lookup so a differently-cased value is not
    # silently scheduled as daily.
    if isinstance(frequency, str):
        interval = intervals.get(frequency.strip().lower(), default_interval)
    else:
        interval = default_interval

    return base_time + interval
|
||||
|
||||
60
backend/services/scheduler/utils/task_loader.py
Normal file
60
backend/services/scheduler/utils/task_loader.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Task Loader Utilities
|
||||
Functions to load due tasks from database.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from models.monitoring_models import MonitoringTask
|
||||
from models.enhanced_strategy_models import EnhancedContentStrategy
|
||||
|
||||
|
||||
def load_due_monitoring_tasks(
    db: Session,
    user_id: Optional[int] = None
) -> List[MonitoringTask]:
    """
    Fetch the monitoring tasks that should run now.

    A task is due when its status is 'active' and its next_execution is
    either not set yet or not later than the current UTC time. Tasks are
    joined to their strategy, and the strategy relationship is eagerly
    loaded so user_id can be read from it during execution.

    Args:
        db: Database session
        user_id: Optional user ID; when given, only tasks whose strategy
            belongs to that user are returned

    Returns:
        List of due MonitoringTask instances with strategy relationship loaded
    """
    current_time = datetime.utcnow()

    # Never scheduled yet, or scheduled for a moment that has already passed
    is_due = or_(
        MonitoringTask.next_execution <= current_time,
        MonitoringTask.next_execution.is_(None)
    )
    is_active_and_due = and_(MonitoringTask.status == 'active', is_due)

    due_tasks = db.query(MonitoringTask)
    due_tasks = due_tasks.join(
        EnhancedContentStrategy,
        MonitoringTask.strategy_id == EnhancedContentStrategy.id
    )
    due_tasks = due_tasks.options(joinedload(MonitoringTask.strategy))
    due_tasks = due_tasks.filter(is_active_and_due)

    # Narrow to a single user's strategies only when a user was requested
    if user_id is not None:
        due_tasks = due_tasks.filter(EnhancedContentStrategy.user_id == user_id)

    return due_tasks.all()
|
||||
|
||||
189
backend/services/subscription/preflight_validator.py
Normal file
189
backend/services/subscription/preflight_validator.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Pre-flight Validation Utility for Multi-Operation Workflows
|
||||
|
||||
Provides transparent validation for operations that involve multiple API calls.
|
||||
Services can use this to validate entire workflows before making any external API calls.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
|
||||
def validate_research_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate all operations for a research workflow before making ANY API calls.

    This prevents wasteful external API calls (like Google Grounding) if subsequent
    LLM calls would fail due to token or call limits.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        gpt_provider: GPT provider from env var (defaults to "google");
            an empty or None value is treated as "google"

    Returns:
        None. Returning normally means every operation passed the limit check.

    Raises:
        HTTPException: status 429 when the comprehensive limit check rejects
            the workflow, status 500 when the validation itself fails.
    """
    try:
        # Determine actual provider for LLM calls based on GPT_PROVIDER env var
        gpt_provider_lower = (gpt_provider or "google").lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Estimate tokens for each operation in research workflow
        # Google Grounding call: ~2000 tokens (input + output)
        # Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
        # Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
        # Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)

        operations_to_validate = [
            {
                'provider': APIProvider.GEMINI,  # Google Grounding uses Gemini
                'tokens_requested': 2000,
                'actual_provider_name': 'gemini',
                'operation_type': 'google_grounding'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'keyword_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'competitor_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'content_angle_generation'
            }
        ]

        logger.info(f"[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
        logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
            operation_type = usage_info.get('operation_type', 'unknown')

            logger.error(f"[Pre-flight Validator] ❌ RESEARCH WORKFLOW BLOCKED")
            logger.error(f" ├─ User: {user_id}")
            logger.error(f" ├─ Blocked at: {operation_type}")
            logger.error(f" ├─ Provider: {provider}")
            logger.error(f" └─ Reason: {message}")

            # Raise HTTPException immediately - frontend gets immediate response, no API calls made
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
        # Validation passed - no return needed (function raises HTTPException if validation fails)

    except HTTPException:
        # The 429 raised above must reach the caller unchanged
        raise
    except Exception as e:
        # The error text is passed as a format argument: loguru runs
        # str.format on the message whenever extra arguments are supplied, so
        # braces inside an already-interpolated message would raise here.
        # opt(exception=...) is loguru's way to attach the traceback.
        logger.opt(exception=e).error(
            "[Pre-flight Validator] Error validating research operations: {}", e
        )
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate operations: {str(e)}",
                'message': f"Failed to validate operations: {str(e)}"
            }
        ) from e
|
||||
|
||||
|
||||
def validate_image_generation_operations(
    pricing_service: PricingService,
    user_id: str
) -> None:
    """
    Validate image generation operation before making API calls.

    A single Stability operation with zero requested tokens is submitted to
    the pricing service's comprehensive limit check.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking

    Returns:
        None. Returning normally means the operation passed the limit check.

    Raises:
        HTTPException: status 429 when the limit check rejects the operation,
            status 500 when the validation itself fails.
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.STABILITY,
                'tokens_requested': 0,
                'actual_provider_name': 'stability',
                'operation_type': 'image_generation'
            }
        ]

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'stability') if usage_info else 'stability'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
        # Validation passed - no return needed (function raises HTTPException if validation fails)

    except HTTPException:
        # The 429 raised above must reach the caller unchanged
        raise
    except Exception as e:
        # The error text is passed as a format argument: loguru runs
        # str.format on the message whenever extra arguments are supplied, so
        # braces inside an already-interpolated message would raise here.
        # opt(exception=...) is loguru's way to attach the traceback.
        logger.opt(exception=e).error(
            "[Pre-flight Validator] Error validating image generation: {}", e
        )
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image generation: {str(e)}",
                'message': f"Failed to validate image generation: {str(e)}"
            }
        ) from e
|
||||
|
||||
@@ -3,10 +3,11 @@ Pricing Service for API Usage Tracking
|
||||
Manages API pricing, cost calculation, and subscription limits.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from typing import Dict, Any, Optional, List, Tuple, Union
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
@@ -17,13 +18,17 @@ from models.subscription_models import (
|
||||
class PricingService:
|
||||
"""Service for managing API pricing and cost calculations."""
|
||||
|
||||
# Class-level cache shared across all instances (critical for cache invalidation on subscription renewal)
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
|
||||
_limits_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._pricing_cache = {}
|
||||
self._plans_cache = {}
|
||||
# Lightweight in-process cache for limit checks
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
|
||||
self._limits_cache: Dict[str, Dict[str, Any]] = {}
|
||||
# Cache for schema feature detection (ai_text_generation_calls_limit column)
|
||||
self._ai_text_gen_col_checked: bool = False
|
||||
self._ai_text_gen_col_available: bool = False
|
||||
|
||||
# ------------------- Billing period helpers -------------------
|
||||
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
|
||||
@@ -68,6 +73,15 @@ class PricingService:
|
||||
self._ensure_subscription_current(subscription)
|
||||
# Continue to use YYYY-MM for summaries
|
||||
return datetime.now().strftime("%Y-%m")
|
||||
|
||||
@classmethod
def clear_user_cache(cls, user_id: str) -> int:
    """
    Drop every cached limit-check entry whose key starts with ``"<user_id>:"``.

    Works on the class-level ``_limits_cache``.

    Args:
        user_id: User whose cached entries are removed.

    Returns:
        Number of cache entries that were removed.
    """
    key_prefix = f"{user_id}:"
    # Collect first, delete afterwards: the dict must not change size while
    # it is being iterated.
    stale_keys = [cache_key for cache_key in cls._limits_cache if cache_key.startswith(key_prefix)]
    for cache_key in stale_keys:
        del cls._limits_cache[cache_key]

    removed = len(stale_keys)
    logger.info(f"Cleared {removed} cache entries for user {user_id}")
    return removed
|
||||
|
||||
def initialize_default_pricing(self):
|
||||
"""Initialize default pricing for all API providers."""
|
||||
@@ -292,7 +306,8 @@ class PricingService:
|
||||
"tier": SubscriptionTier.BASIC,
|
||||
"price_monthly": 29.0,
|
||||
"price_yearly": 290.0,
|
||||
"gemini_calls_limit": 1000,
|
||||
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
|
||||
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
|
||||
"openai_calls_limit": 500,
|
||||
"anthropic_calls_limit": 200,
|
||||
"mistral_calls_limit": 500,
|
||||
@@ -300,11 +315,11 @@ class PricingService:
|
||||
"serper_calls_limit": 200,
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 50,
|
||||
"gemini_tokens_limit": 1000000,
|
||||
"openai_tokens_limit": 500000,
|
||||
"anthropic_tokens_limit": 200000,
|
||||
"mistral_tokens_limit": 500000,
|
||||
"stability_calls_limit": 5,
|
||||
"gemini_tokens_limit": 2000,
|
||||
"openai_tokens_limit": 2000,
|
||||
"anthropic_tokens_limit": 2000,
|
||||
"mistral_tokens_limit": 2000,
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
@@ -426,21 +441,60 @@ class PricingService:
|
||||
self._ensure_subscription_current(subscription)
|
||||
return self._plan_to_limits_dict(subscription.plan)
|
||||
|
||||
def _ensure_ai_text_gen_column_detection(self) -> None:
|
||||
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
|
||||
if self._ai_text_gen_col_checked:
|
||||
return
|
||||
try:
|
||||
# Try to query the column - if it exists, this will work
|
||||
self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
|
||||
self._ai_text_gen_col_available = True
|
||||
except Exception:
|
||||
self._ai_text_gen_col_available = False
|
||||
finally:
|
||||
self._ai_text_gen_col_checked = True
|
||||
|
||||
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""Convert subscription plan to limits dictionary."""
|
||||
# Detect if unified AI text generation limit column exists
|
||||
self._ensure_ai_text_gen_column_detection()
|
||||
|
||||
# Use unified AI text generation limit if column exists and is set
|
||||
ai_text_gen_limit = None
|
||||
if self._ai_text_gen_col_available:
|
||||
try:
|
||||
ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
|
||||
# If 0, treat as not set (unlimited for Enterprise or use fallback)
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = None
|
||||
except (AttributeError, Exception):
|
||||
# Column exists but access failed - use fallback
|
||||
ai_text_gen_limit = None
|
||||
|
||||
return {
|
||||
'plan_name': plan.name,
|
||||
'tier': plan.tier.value,
|
||||
'limits': {
|
||||
# Unified AI text generation limit (applies to all LLM providers)
|
||||
# If not set, fall back to first non-zero legacy limit for backwards compatibility
|
||||
'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
|
||||
plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
|
||||
plan.openai_calls_limit if plan.openai_calls_limit > 0 else
|
||||
plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
|
||||
plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
|
||||
),
|
||||
# Legacy per-provider limits (for backwards compatibility and analytics)
|
||||
'gemini_calls': plan.gemini_calls_limit,
|
||||
'openai_calls': plan.openai_calls_limit,
|
||||
'anthropic_calls': plan.anthropic_calls_limit,
|
||||
'mistral_calls': plan.mistral_calls_limit,
|
||||
# Other API limits
|
||||
'tavily_calls': plan.tavily_calls_limit,
|
||||
'serper_calls': plan.serper_calls_limit,
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
# Token limits
|
||||
'gemini_tokens': plan.gemini_tokens_limit,
|
||||
'openai_tokens': plan.openai_tokens_limit,
|
||||
'anthropic_tokens': plan.anthropic_tokens_limit,
|
||||
@@ -451,101 +505,293 @@ class PricingService:
|
||||
}
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits."""
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user limits
|
||||
limits = self.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return False, "No subscription plan found", {}
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
# Get current usage for this billing period
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
|
||||
# Check call limits
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
result = (False, f"API call limit reached for {provider_name}", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Check token limits for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {provider_name}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
try:
|
||||
limits = self.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
try:
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
# Check cost limits
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, "Monthly cost limit reached", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_calls,
|
||||
'call_limit': call_limit,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
|
||||
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
|
||||
"""Estimate token count for text based on provider."""
|
||||
@@ -581,6 +827,236 @@ class PricingService:
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
def check_comprehensive_limits(
    self,
    user_id: str,
    operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
    """
    Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.

    Usage is projected cumulatively over the whole batch: every operation is
    checked against the stored usage for the billing period plus whatever the
    earlier operations in ``operations`` would already consume. This applies to
    LLM calls (one unified counter across gemini/openai/anthropic/mistral),
    LLM tokens (per provider), image generation calls, and calls to any other
    provider (per provider).

    Args:
        user_id: User ID
        operations: List of operations to validate, each with:
            - 'provider': APIProvider enum
            - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
            - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
            - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")

    Returns:
        (can_proceed, error_message, error_details)
        ``(True, None, None)`` when every operation fits within the limits.
        When a limit would be exceeded, ``can_proceed`` is False,
        ``error_message`` names the limit and ``error_details`` carries
        ``'error_type'`` and ``'usage_info'``. When usage or limits cannot be
        loaded, ``can_proceed`` is False and ``error_details`` is ``{}``.
    """
    try:
        logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
        logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")

        # Get current usage and limits once
        current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
        usage = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == current_period
        ).first()

        if not usage:
            # First usage this period, create summary
            try:
                usage = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                self.db.add(usage)
                self.db.commit()
            except Exception as create_error:
                logger.error(f"Error creating usage summary: {create_error}")
                self.db.rollback()
                return False, f"Failed to create usage summary: {str(create_error)}", {}

        # Get user limits
        limits_dict = self.get_user_limits(user_id)
        if not limits_dict:
            # No subscription found - check for free tier
            free_plan = self.db.query(SubscriptionPlan).filter(
                SubscriptionPlan.tier == SubscriptionTier.FREE,
                # "== True" is intentional: it builds a SQL expression.
                SubscriptionPlan.is_active == True
            ).first()
            if free_plan:
                limits_dict = self._plan_to_limits_dict(free_plan)
            else:
                return False, "No subscription plan found. Please subscribe to a plan.", {}

        limits = limits_dict.get('limits', {})

        # Track cumulative usage across all operations
        total_llm_calls = (
            (usage.gemini_calls or 0) +
            (usage.openai_calls or 0) +
            (usage.anthropic_calls or 0) +
            (usage.mistral_calls or 0)
        )
        total_llm_tokens = {}
        total_images = usage.stability_calls or 0
        # Projected call counts for non-LLM, non-image providers, keyed by
        # "<provider>_calls". Without this, several operations for the same
        # provider would each be compared against the stored usage alone and
        # a batch could slip past the limit.
        projected_provider_calls: Dict[str, int] = {}

        # Providers that share the unified AI text generation call limit
        llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']

        # Log current usage summary
        logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
        logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
        logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
        logger.info(f" └─ Image Calls: {total_images}")

        # Validate each operation
        for op_idx, operation in enumerate(operations):
            provider = operation.get('provider')
            provider_name = provider.value if hasattr(provider, 'value') else str(provider)
            tokens_requested = operation.get('tokens_requested', 0)
            actual_provider_name = operation.get('actual_provider_name')
            operation_type = operation.get('operation_type', 'unknown')

            # Name shown in logs and errors; DB/limit keys always use provider_name
            display_provider_name = actual_provider_name or provider_name

            logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
            logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
            logger.info(f" └─ Estimated Tokens: {tokens_requested}")

            # Check if this is an LLM provider
            is_llm_provider = provider_name in llm_providers

            # Check unified AI text generation limit for LLM providers
            if is_llm_provider:
                ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
                if ai_text_gen_limit == 0:
                    # Fallback to provider-specific limit
                    ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0

                # Count this operation as an LLM call
                projected_total_llm_calls = total_llm_calls + 1

                # A limit of 0 is not enforced
                if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
                    error_info = {
                        'current_calls': total_llm_calls,
                        'limit': ai_text_gen_limit,
                        'provider': display_provider_name,
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
                        'error_type': 'call_limit',
                        'usage_info': error_info
                    }

                # Check token limits for this provider
                # Use cumulative projected tokens from previous operations, or current from DB if first operation
                provider_tokens_key = f"{provider_name}_tokens"
                if provider_tokens_key in total_llm_tokens:
                    # Use cumulative projected tokens from previous operations
                    current_provider_tokens = total_llm_tokens[provider_tokens_key]
                    logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
                else:
                    # First operation for this provider - get current from database
                    current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
                    total_llm_tokens[provider_tokens_key] = current_provider_tokens
                    logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")

                token_limit = limits.get(provider_tokens_key, 0) or 0

                if token_limit > 0 and tokens_requested > 0:
                    projected_tokens = current_provider_tokens + tokens_requested
                    logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")

                    if projected_tokens > token_limit:
                        usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
                        error_info = {
                            'current_tokens': current_provider_tokens,
                            'requested_tokens': tokens_requested,
                            'limit': token_limit,
                            'provider': display_provider_name,
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        error_msg = (
                            f"Token limit exceeded for {display_provider_name} "
                            f"({operation_type}). "
                            f"Current: {current_provider_tokens}/{token_limit}, "
                            f"Requested: {tokens_requested}, "
                            f"Would exceed by: {projected_tokens - token_limit} tokens "
                            f"({usage_percentage:.1f}% of limit)"
                        )
                        logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
                        return False, error_msg, {
                            'error_type': 'token_limit',
                            'usage_info': error_info
                        }
                    else:
                        logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")

                # Update cumulative counts for next operation
                total_llm_calls = projected_total_llm_calls
                total_llm_tokens[provider_tokens_key] += tokens_requested
                logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")

            # Check image generation limits
            elif provider == APIProvider.STABILITY:
                image_limit = limits.get('stability_calls', 0) or 0
                projected_images = total_images + 1

                if image_limit > 0 and projected_images > image_limit:
                    error_info = {
                        'current_images': total_images,
                        'limit': image_limit,
                        'provider': 'stability',
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
                        'error_type': 'image_limit',
                        'usage_info': error_info
                    }

                total_images = projected_images

            # Check other provider-specific limits
            else:
                provider_calls_key = f"{provider_name}_calls"
                if provider_calls_key in projected_provider_calls:
                    # Earlier operations in this batch already target this provider
                    current_provider_calls = projected_provider_calls[provider_calls_key]
                else:
                    # First operation for this provider - get current from database
                    current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
                call_limit = limits.get(provider_calls_key, 0) or 0
                projected_calls = current_provider_calls + 1

                # A limit of 0 is not enforced
                if call_limit > 0 and projected_calls > call_limit:
                    error_info = {
                        'current_calls': current_provider_calls,
                        'limit': call_limit,
                        'provider': display_provider_name,
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
                        'error_type': 'call_limit',
                        'usage_info': error_info
                    }

                # Carry the projected count forward to the next operation
                projected_provider_calls[provider_calls_key] = projected_calls

        # All checks passed
        logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
        logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
        return True, None, None

    except Exception as e:
        # Any unexpected failure denies the batch rather than letting it through
        logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
        return False, f"Failed to validate limits: {str(e)}", {}
|
||||
|
||||
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get pricing configuration for a specific provider and model."""
|
||||
pricing = self.db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.provider == provider,
|
||||
APIProviderPricing.model_name == model_name
|
||||
).first()
|
||||
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
return {
|
||||
'provider': pricing.provider.value,
|
||||
'model_name': pricing.model_name,
|
||||
|
||||
@@ -502,7 +502,7 @@ class UsageTrackingService:
|
||||
return result
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status for the current billing period (after plan change)."""
|
||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||
try:
|
||||
billing_period = datetime.now().strftime("%Y-%m")
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
@@ -514,11 +514,52 @@ class UsageTrackingService:
|
||||
# Nothing to reset
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
# Clear LIMIT_REACHED so the user can resume; keep counters intact
|
||||
# CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
|
||||
# Clear LIMIT_REACHED status
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
|
||||
# Reset all LLM provider call counters
|
||||
summary.gemini_calls = 0
|
||||
summary.openai_calls = 0
|
||||
summary.anthropic_calls = 0
|
||||
summary.mistral_calls = 0
|
||||
|
||||
# Reset all LLM provider token counters
|
||||
summary.gemini_tokens = 0
|
||||
summary.openai_tokens = 0
|
||||
summary.anthropic_tokens = 0
|
||||
summary.mistral_tokens = 0
|
||||
|
||||
# Reset search/research provider counters
|
||||
summary.tavily_calls = 0
|
||||
summary.serper_calls = 0
|
||||
summary.metaphor_calls = 0
|
||||
summary.firecrawl_calls = 0
|
||||
|
||||
# Reset image generation counters
|
||||
summary.stability_calls = 0
|
||||
|
||||
# Reset cost counters
|
||||
summary.gemini_cost = 0.0
|
||||
summary.openai_cost = 0.0
|
||||
summary.anthropic_cost = 0.0
|
||||
summary.mistral_cost = 0.0
|
||||
summary.tavily_cost = 0.0
|
||||
summary.serper_cost = 0.0
|
||||
summary.metaphor_cost = 0.0
|
||||
summary.firecrawl_cost = 0.0
|
||||
summary.stability_cost = 0.0
|
||||
|
||||
# Reset totals
|
||||
summary.total_calls = 0
|
||||
summary.total_tokens = 0
|
||||
summary.total_cost = 0.0
|
||||
|
||||
summary.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"reset": True}
|
||||
|
||||
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
|
||||
return {"reset": True, "counters_reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
|
||||
Reference in New Issue
Block a user