Subscription dashboard improvements, AI text generation limit, and other fixes.

This commit is contained in:
ajaysi
2025-11-01 18:01:14 +05:30
parent cdb41aec1b
commit de4328175d
64 changed files with 5809 additions and 444 deletions

View File

@@ -295,3 +295,55 @@ class ActiveStrategyService:
'cached_users': list(self._memory_cache.keys()),
'last_updates': {k: v.isoformat() for k, v in self._last_cache_update.items()}
}
def count_active_strategies_with_tasks(self) -> int:
    """
    Return the number of distinct strategies that are active and monitored.

    A strategy is counted once when its activation status is 'active' and
    it has at least one monitoring task whose status is also 'active'.
    The result is intended for intelligent scheduling: with nothing to
    monitor, the scheduler can check less frequently.

    Returns:
        Number of active strategies that have at least one active
        monitoring task. Returns 0 when no database session is available,
        and 1 when the query fails.
    """
    try:
        if not self.db_session:
            logger.warning("Database session not available")
            return 0

        # Imported locally, as in the original implementation.
        from sqlalchemy import func, and_
        from models.monitoring_models import MonitoringTask

        # Both conditions must hold for a strategy to be counted.
        strategy_is_active = StrategyActivationStatus.status == 'active'
        task_is_active = MonitoringTask.status == 'active'

        # DISTINCT keeps a strategy with several active tasks from being
        # counted more than once after the joins fan out.
        distinct_strategy_count = func.count(
            func.distinct(EnhancedContentStrategy.id)
        )

        query = self.db_session.query(distinct_strategy_count)
        query = query.join(
            StrategyActivationStatus,
            EnhancedContentStrategy.id == StrategyActivationStatus.strategy_id
        )
        query = query.join(
            MonitoringTask,
            EnhancedContentStrategy.id == MonitoringTask.strategy_id
        )
        query = query.filter(and_(strategy_is_active, task_is_active))

        active_count = query.scalar()
        # scalar() may yield None; normalise any falsy result to 0.
        if not active_count:
            return 0
        return active_count
    except Exception as e:
        logger.error(f"Error counting active strategies with tasks: {e}")
        # On error, assume there are active strategies (safer to check more frequently)
        return 1
def has_active_strategies_with_tasks(self) -> bool:
    """
    Tell whether at least one active strategy has an active monitoring task.

    Delegates to count_active_strategies_with_tasks() and compares the
    result against zero.

    Returns:
        True if the count is greater than zero, False otherwise.
    """
    active_count = self.count_active_strategies_with_tasks()
    return active_count > 0

View File

@@ -96,13 +96,13 @@ class BlogWriterService:
self.blog_rewriter = BlogRewriter(self.task_manager)
# Research Methods
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
"""Conduct comprehensive research using Google Search grounding."""
return await self.research_service.research(request)
return await self.research_service.research(request, user_id)
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
"""Conduct research with real-time progress updates."""
return await self.research_service.research_with_progress(request, task_id)
return await self.research_service.research_with_progress(request, task_id, user_id)
# Outline Methods
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
@@ -204,11 +204,14 @@ class BlogWriterService:
except Exception as e:
return {"success": False, "error": str(e)}
async def seo_analyze(self, request: BlogSEOAnalyzeRequest) -> BlogSEOAnalyzeResponse:
async def seo_analyze(self, request: BlogSEOAnalyzeRequest, user_id: str = None) -> BlogSEOAnalyzeResponse:
"""Analyze content for SEO optimization using comprehensive blog-specific analyzer."""
try:
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
content = request.content or ""
target_keywords = request.keywords or []
@@ -231,7 +234,7 @@ class BlogWriterService:
# Use our comprehensive SEO analyzer
analyzer = BlogContentSEOAnalyzer()
analysis_results = await analyzer.analyze_blog_content(content, research_data)
analysis_results = await analyzer.analyze_blog_content(content, research_data, user_id=user_id)
# Convert results to response format
recommendations = analysis_results.get('actionable_recommendations', [])
@@ -267,11 +270,14 @@ class BlogWriterService:
recommendations=[f"SEO analysis failed: {str(e)}"]
)
async def seo_metadata(self, request: BlogSEOMetadataRequest) -> BlogSEOMetadataResponse:
async def seo_metadata(self, request: BlogSEOMetadataRequest, user_id: str = None) -> BlogSEOMetadataResponse:
"""Generate comprehensive SEO metadata for content."""
try:
from services.blog_writer.seo.blog_seo_metadata_generator import BlogSEOMetadataGenerator
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
# Initialize metadata generator
metadata_generator = BlogSEOMetadataGenerator()
@@ -285,7 +291,8 @@ class BlogWriterService:
blog_title=request.title or "Untitled Blog Post",
research_data=request.research_data or {},
outline=outline,
seo_analysis=seo_analysis
seo_analysis=seo_analysis,
user_id=user_id
)
# Convert to BlogSEOMetadataResponse format

View File

@@ -163,13 +163,18 @@ class BlogWriterLogger:
context: Optional[Dict[str, Any]] = None
):
"""Log error with full context."""
# Safely format error message to avoid KeyError on format strings in error messages
error_str = str(error)
# Replace any curly braces that might be in the error message to avoid format string issues
safe_error_str = error_str.replace('{', '{{').replace('}', '}}')
logger.error(
f"Error in {operation}: {str(error)}",
f"Error in {operation}: {safe_error_str}",
extra={
"event_type": "error",
"operation": operation,
"error_type": type(error).__name__,
"error_message": str(error),
"error_message": error_str, # Keep original in extra, but use safe version in format string
"context": context or {}
},
exc_info=True

View File

@@ -11,7 +11,7 @@ from loguru import logger
class CompetitorAnalyzer:
"""Analyzes competitors and market intelligence from research content."""
def analyze(self, content: str) -> Dict[str, Any]:
def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]:
"""Parse comprehensive competitor analysis from the research content using AI."""
competitor_prompt = f"""
Analyze the following research content and extract competitor insights:
@@ -57,7 +57,8 @@ class CompetitorAnalyzer:
competitor_analysis = llm_text_gen(
prompt=competitor_prompt,
json_struct=competitor_schema
json_struct=competitor_schema,
user_id=user_id
)
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:

View File

@@ -11,7 +11,7 @@ from loguru import logger
class ContentAngleGenerator:
"""Generates strategic content angles from research content."""
def generate(self, content: str, topic: str, industry: str) -> List[str]:
def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]:
"""Parse strategic content angles from the research content using AI."""
angles_prompt = f"""
Analyze the following research content and create strategic content angles for: {topic} in {industry}
@@ -65,7 +65,8 @@ class ContentAngleGenerator:
angles_result = llm_text_gen(
prompt=angles_prompt,
json_struct=angles_schema
json_struct=angles_schema,
user_id=user_id
)
if isinstance(angles_result, dict) and 'content_angles' in angles_result:

View File

@@ -11,7 +11,7 @@ from loguru import logger
class KeywordAnalyzer:
"""Analyzes keywords from research content using AI-powered extraction."""
def analyze(self, content: str, original_keywords: List[str]) -> Dict[str, Any]:
def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]:
"""Parse comprehensive keyword analysis from the research content using AI."""
# Use AI to extract and analyze keywords from the rich research content
keyword_prompt = f"""
@@ -64,7 +64,8 @@ class KeywordAnalyzer:
keyword_analysis = llm_text_gen(
prompt=keyword_prompt,
json_struct=keyword_schema
json_struct=keyword_schema,
user_id=user_id
)
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:

View File

@@ -4,7 +4,8 @@ Research Service - Core research functionality for AI Blog Writer.
Handles Google Search grounding, caching, and research orchestration.
"""
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger
from models.blog_models import (
@@ -17,6 +18,7 @@ from models.blog_models import (
Citation,
)
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
from fastapi import HTTPException
from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
@@ -34,7 +36,7 @@ class ResearchService:
self.data_filter = ResearchDataFilter()
@log_function_call("research_operation")
async def research(self, request: BlogResearchRequest) -> BlogResearchResponse:
async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
"""
Stage 1: Research & Strategy (AI Orchestration)
Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
@@ -71,6 +73,10 @@ class ResearchService:
blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True)
return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider)
if not user_id:
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
@@ -96,12 +102,15 @@ class ResearchService:
"""
# Single Gemini call with native Google Search grounding - no fallbacks
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
import time
api_start_time = time.time()
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
api_duration_ms = (time.time() - api_start_time) * 1000
@@ -126,9 +135,9 @@ class ResearchService:
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
competitor_analysis = self.competitor_analyzer.analyze(content)
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
@@ -179,6 +188,9 @@ class ResearchService:
return filtered_response
except HTTPException:
# Re-raise HTTPException (subscription errors) - let task manager handle it
raise
except Exception as e:
error_message = str(e)
logger.error(f"Research failed: {error_message}")
@@ -244,7 +256,7 @@ class ResearchService:
)
@log_function_call("research_with_progress")
async def research_with_progress(self, request: BlogResearchRequest, task_id: str) -> BlogResearchResponse:
async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
"""
Research method with progress updates for real-time feedback.
"""
@@ -281,6 +293,11 @@ class ResearchService:
logger.info(f"Returning cached research result for keywords: {request.keywords}")
return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider)
if not user_id:
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
@@ -307,11 +324,20 @@ class ResearchService:
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
# Single Gemini call with native Google Search grounding - no fallbacks
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000
)
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
try:
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
except HTTPException as http_error:
# Re-raise HTTPException so it can be properly handled by task manager
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise # Re-raise HTTPException to preserve status code and error details
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
# Extract sources from grounding metadata
@@ -327,9 +353,9 @@ class ResearchService:
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords)
competitor_analysis = self.competitor_analyzer.analyze(content)
suggested_angles = self.content_angle_generator.generate(content, topic, industry)
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
await task_manager.update_progress(task_id, "💾 Caching results for future use...")
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
@@ -373,6 +399,9 @@ class ResearchService:
return filtered_response
except HTTPException:
# Re-raise HTTPException (subscription errors) - let task manager handle it
raise
except Exception as e:
error_message = str(e)
logger.error(f"Research failed: {error_message}")

View File

@@ -34,17 +34,21 @@ class BlogContentSEOAnalyzer:
logger.info("BlogContentSEOAnalyzer initialized")
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None) -> Dict[str, Any]:
async def analyze_blog_content(self, blog_content: str, research_data: Dict[str, Any], blog_title: Optional[str] = None, user_id: str = None) -> Dict[str, Any]:
"""
Main analysis method with parallel processing
Args:
blog_content: The blog content to analyze
research_data: Research data containing keywords and other insights
blog_title: Optional blog title
user_id: Clerk user ID for subscription checking (required)
Returns:
Comprehensive SEO analysis results
"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
logger.info("Starting blog content SEO analysis")
@@ -58,7 +62,7 @@ class BlogContentSEOAnalyzer:
# Phase 2: Single AI analysis for structured insights
logger.info("Running AI analysis")
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results)
ai_insights = await self._run_ai_analysis(blog_content, keywords_data, non_ai_results, user_id=user_id)
# Phase 3: Compile and format results
logger.info("Compiling results")
@@ -599,8 +603,10 @@ class BlogContentSEOAnalyzer:
return recommendations
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]:
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Run single AI analysis for structured insights (provider-agnostic)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
# Prepare context for AI analysis
context = {
@@ -658,7 +664,8 @@ class BlogContentSEOAnalyzer:
ai_response = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=None
system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
)
return ai_response

View File

@@ -28,7 +28,8 @@ class BlogSEOMetadataGenerator:
blog_title: str,
research_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None
seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]:
"""
Generate comprehensive SEO metadata using maximum 2 AI calls
@@ -39,10 +40,13 @@ class BlogSEOMetadataGenerator:
research_data: Research data containing keywords and insights
outline: Outline structure with sections and headings
seo_analysis: SEO analysis results from previous phase
user_id: Clerk user ID for subscription checking (required)
Returns:
Comprehensive metadata including all SEO elements
"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
logger.info("Starting comprehensive SEO metadata generation")
@@ -53,13 +57,13 @@ class BlogSEOMetadataGenerator:
# Call 1: Generate core SEO metadata (parallel with Call 2)
logger.info("Generating core SEO metadata")
core_metadata_task = self._generate_core_metadata(
blog_content, blog_title, keywords_data, outline, seo_analysis
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
)
# Call 2: Generate social media and structured data (parallel with Call 1)
logger.info("Generating social media and structured data")
social_metadata_task = self._generate_social_metadata(
blog_content, blog_title, keywords_data, outline, seo_analysis
blog_content, blog_title, keywords_data, outline, seo_analysis, user_id=user_id
)
# Wait for both calls to complete
@@ -114,9 +118,12 @@ class BlogSEOMetadataGenerator:
blog_title: str,
keywords_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None
seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]:
"""Generate core SEO metadata (Call 1)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
# Create comprehensive prompt for core metadata
prompt = self._create_core_metadata_prompt(
@@ -170,7 +177,8 @@ class BlogSEOMetadataGenerator:
ai_response_raw = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=None
system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
)
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
@@ -215,9 +223,12 @@ class BlogSEOMetadataGenerator:
blog_title: str,
keywords_data: Dict[str, Any],
outline: Optional[List[Dict[str, Any]]] = None,
seo_analysis: Optional[Dict[str, Any]] = None
seo_analysis: Optional[Dict[str, Any]] = None,
user_id: str = None
) -> Dict[str, Any]:
"""Generate social media and structured data (Call 2)"""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
# Create comprehensive prompt for social metadata
prompt = self._create_social_metadata_prompt(
@@ -274,7 +285,8 @@ class BlogSEOMetadataGenerator:
ai_response_raw = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=None
system_prompt=None,
user_id=user_id # Pass user_id for subscription checking
)
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)

View File

@@ -20,8 +20,11 @@ class BlogSEORecommendationApplier:
def __init__(self):
logger.debug("Initialized BlogSEORecommendationApplier")
async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]:
async def apply_recommendations(self, payload: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Apply recommendations and return updated content."""
if not user_id:
raise ValueError("user_id is required for subscription checking. Please provide Clerk user ID.")
title = payload.get("title", "Untitled Blog")
sections: List[Dict[str, Any]] = payload.get("sections", [])
@@ -88,6 +91,7 @@ class BlogSEORecommendationApplier:
prompt,
None,
schema,
user_id, # Pass user_id for subscription checking
)
if not result or result.get("error"):

View File

@@ -56,7 +56,9 @@ class GeminiGroundedProvider:
temperature: float = 0.7,
max_tokens: int = 2048,
urls: Optional[List[str]] = None,
mode: str = "polished"
mode: str = "polished",
user_id: Optional[str] = None,
validate_subsequent_operations: bool = False
) -> Dict[str, Any]:
"""
Generate grounded content using native Google Search grounding.
@@ -66,12 +68,49 @@ class GeminiGroundedProvider:
content_type: Type of content to generate
temperature: Creativity level (0.0-1.0)
max_tokens: Maximum tokens in response
urls: Optional list of URLs for URL Context tool
mode: Content mode ("draft" or "polished")
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
Returns:
Dictionary containing generated content and grounding metadata
"""
try:
logger.info(f"Generating grounded content for {content_type} using native Google Search")
# PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
# MUST happen BEFORE any API calls - return immediately if validation fails
if validate_subsequent_operations:
if not user_id:
raise ValueError("user_id is required when validate_subsequent_operations=True")
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_research_operations
from fastapi import HTTPException
import os
db = next(get_db())
try:
pricing_service = PricingService(db)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
# Validate ALL research operations before making ANY API calls
# This prevents wasteful external API calls if subsequent LLM calls would fail
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_research_operations(
pricing_service=pricing_service,
user_id=user_id,
gpt_provider=gpt_provider
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
# Build the grounded prompt
grounded_prompt = self._build_grounded_prompt(prompt, content_type)

View File

@@ -40,7 +40,38 @@ def _get_provider(provider_name: str):
raise ValueError(f"Unknown image provider: {provider_name}")
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult:
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
"""Generate image with pre-flight validation.
Args:
prompt: Image generation prompt
options: Image generation options (provider, model, width, height, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Generation] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Image Generation] ✅ Pre-flight validation passed - proceeding with image generation")
opts = options or {}
provider_name = _select_provider(opts.get("provider"))

View File

@@ -7,6 +7,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
import os
import json
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from ..onboarding.api_key_manager import APIKeyManager
@@ -14,7 +15,7 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str:
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str:
"""
Generate text using Language Model (LLM) based on the provided prompt.
@@ -22,9 +23,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
prompt (str): The prompt to generate text from.
system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses.
user_id (str): Clerk user ID for subscription checking (required).
Returns:
str: Generated text based on the prompt.
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
"""
try:
logger.info("[llm_text_gen] Starting text generation")
@@ -93,6 +98,75 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
# Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider
provider_enum = None
# Store actual provider name for logging (e.g., "huggingface", "gemini")
actual_provider_name = None
if gpt_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
elif gpt_provider == "huggingface":
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
actual_provider_name = "huggingface" # Keep actual provider name for logs
if not provider_enum:
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
# SUBSCRIPTION CHECK - Required and strict enforcement
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import UsageSummary
db = next(get_db())
try:
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens)
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
# This prevents underestimating total token usage
input_tokens = int(len(prompt.split()) * 1.3)
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
estimated_total_tokens = input_tokens + estimated_output_tokens
# Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider_enum,
tokens_requested=estimated_total_tokens,
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
)
if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
raise RuntimeError(f"Subscription limit exceeded: {message}")
# Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# No separate log here - we'll create unified log after API call and usage tracking
finally:
db.close()
except RuntimeError:
# Re-raise subscription limit errors
raise
except Exception as sub_error:
# STRICT: Fail on subscription check errors
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Construct the system prompt if not provided
if system_prompt is None:
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
@@ -117,10 +191,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_instructions = system_prompt
# Generate response based on provider
response_text = None
actual_provider_used = gpt_provider
try:
if gpt_provider == "google":
if json_struct:
return gemini_structured_json_response(
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
@@ -130,7 +206,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions
)
else:
return gemini_text_response(
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
@@ -140,7 +216,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
)
elif gpt_provider == "huggingface":
if json_struct:
return huggingface_structured_json_response(
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model=model,
@@ -149,7 +225,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions
)
else:
return huggingface_text_response(
response_text = huggingface_text_response(
prompt=prompt,
model=model,
temperature=temperature,
@@ -160,6 +236,107 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
# TRACK USAGE after successful API call
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens # Already calculated above
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens)
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
summary.total_calls = old_total_calls + 1
summary.total_tokens = old_total_tokens + tokens_total
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
@@ -171,9 +348,21 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
fallback_provider = fallback_providers[0] # Only try the first available
try:
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
actual_provider_used = fallback_provider
# Update provider enum for fallback
if fallback_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
fallback_model = "gemini-2.0-flash-lite"
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "openai/gpt-oss-120b:groq"
if fallback_provider == "google":
if json_struct:
return gemini_structured_json_response(
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
@@ -183,7 +372,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions
)
else:
return gemini_text_response(
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
@@ -193,7 +382,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
)
elif fallback_provider == "huggingface":
if json_struct:
return huggingface_structured_json_response(
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model="openai/gpt-oss-120b:groq",
@@ -202,7 +391,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
system_prompt=system_instructions
)
else:
return huggingface_text_response(
response_text = huggingface_text_response(
prompt=prompt,
model="openai/gpt-oss-120b:groq",
temperature=temperature,
@@ -210,6 +399,96 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
top_p=top_p,
system_prompt=system_instructions
)
# TRACK USAGE after successful fallback call
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
return response_text
except Exception as fallback_error:
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")

View File

@@ -55,6 +55,14 @@ class MonitoringDataService:
alert_threshold=task_data.get('alertThreshold', ''),
status='active'
)
# Initialize next_execution based on frequency
from services.scheduler.utils.frequency_calculator import calculate_next_execution
task.next_execution = calculate_next_execution(
frequency=task.frequency,
base_time=datetime.utcnow()
)
self.db.add(task)
# Save activation status
@@ -357,3 +365,80 @@ class MonitoringDataService:
logger.error(f"Error updating performance metrics for strategy {strategy_id}: {e}")
self.db.rollback()
return False
def get_user_execution_logs(
    self,
    user_id: int,
    limit: Optional[int] = 50,
    offset: Optional[int] = 0,
    status_filter: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Get execution logs for a specific user, newest first.

    Args:
        user_id: User ID to filter execution logs
        limit: Maximum number of logs to return; a falsy value (None or 0)
            disables the limit
        offset: Number of logs to skip (for pagination); a falsy value
            disables the offset
        status_filter: Optional status filter ('success', 'failed', 'running', 'skipped')

    Returns:
        List of execution log dictionaries. Each entry carries a nested
        "task" dictionary, or None when no matching MonitoringTask exists.
        An empty list is returned if any error occurs (the error is logged).
    """
    try:
        logger.info(f"Getting execution logs for user {user_id}")

        # Build query for execution logs filtered by user_id
        query = self.db.query(TaskExecutionLog).filter(
            TaskExecutionLog.user_id == user_id
        )

        # Apply status filter if provided
        if status_filter:
            query = query.filter(TaskExecutionLog.status == status_filter)

        # Order by execution date (most recent first)
        query = query.order_by(desc(TaskExecutionLog.execution_date))

        # Apply pagination
        if limit:
            query = query.limit(limit)
        if offset:
            query = query.offset(offset)

        logs = query.all()

        # Load every referenced task with a single IN query instead of one
        # query per log row (avoids the N+1 query pattern).
        task_ids = list({log.task_id for log in logs if log.task_id is not None})
        tasks_by_id = {}
        if task_ids:
            tasks = self.db.query(MonitoringTask).filter(
                MonitoringTask.id.in_(task_ids)
            ).all()
            tasks_by_id = {task.id: task for task in tasks}

        # Convert to dictionaries with task details
        logs_data = []
        for log in logs:
            task = tasks_by_id.get(log.task_id)
            task_data = None
            if task:
                task_data = {
                    "title": task.task_title,
                    "description": task.task_description,
                    "assignee": task.assignee,
                    "frequency": task.frequency,
                    "strategy_id": task.strategy_id
                }

            logs_data.append({
                "id": log.id,
                "task_id": log.task_id,
                "user_id": log.user_id,
                "execution_date": log.execution_date.isoformat() if log.execution_date else None,
                "status": log.status,
                "result_data": log.result_data,
                "error_message": log.error_message,
                "execution_time_ms": log.execution_time_ms,
                "created_at": log.created_at.isoformat() if log.created_at else None,
                "task": task_data
            })

        logger.info(f"Retrieved {len(logs_data)} execution logs for user {user_id}")
        return logs_data

    except Exception as e:
        logger.error(f"Error getting execution logs for user {user_id}: {e}")
        return []

View File

@@ -0,0 +1,59 @@
"""
Task Scheduler Package
Modular, pluggable scheduler for ALwrity tasks.
"""
from .core.scheduler import TaskScheduler
from .core.executor_interface import TaskExecutor, TaskExecutionResult
from .core.exception_handler import (
SchedulerExceptionHandler, SchedulerException, SchedulerErrorType, SchedulerErrorSeverity,
TaskExecutionError, DatabaseError, TaskLoaderError, SchedulerConfigError
)
from .executors.monitoring_task_executor import MonitoringTaskExecutor
from .utils.task_loader import load_due_monitoring_tasks
# Global scheduler instance (initialized on first access).
# Starts as None; get_scheduler() in this module replaces it with the
# TaskScheduler it builds on the first call.
_scheduler_instance: TaskScheduler = None
def get_scheduler() -> TaskScheduler:
    """
    Get global scheduler instance (singleton pattern).

    On the first call a TaskScheduler is created (15 minute check interval,
    at most 10 concurrent executions) and the monitoring task executor is
    registered under the 'monitoring_task' type. Later calls return the
    same instance.

    Returns:
        TaskScheduler instance
    """
    global _scheduler_instance

    if _scheduler_instance is not None:
        return _scheduler_instance

    scheduler = TaskScheduler(
        check_interval_minutes=15,
        max_concurrent_executions=10
    )

    # Register monitoring task executor
    scheduler.register_executor(
        'monitoring_task',
        MonitoringTaskExecutor(),
        load_due_monitoring_tasks
    )

    # Publish the singleton only after registration succeeded, so a failure
    # above cannot leave a cached scheduler that has no executors.
    _scheduler_instance = scheduler
    return _scheduler_instance
# Public API of the scheduler package: every name below is imported or
# defined at the top of this module.
__all__ = [
'TaskScheduler',
'TaskExecutor',
'TaskExecutionResult',
'MonitoringTaskExecutor',
'get_scheduler',
# Exception handling
'SchedulerExceptionHandler',
'SchedulerException',
'SchedulerErrorType',
'SchedulerErrorSeverity',
'TaskExecutionError',
'DatabaseError',
'TaskLoaderError',
'SchedulerConfigError'
]

View File

@@ -0,0 +1,4 @@
"""
Core scheduler components.
"""

View File

@@ -0,0 +1,395 @@
"""
Comprehensive Exception Handling and Logging for Task Scheduler
Provides robust error handling, logging, and monitoring for the scheduler system.
"""
import traceback
import sys
from datetime import datetime
from typing import Dict, Any, Optional, Union
from enum import Enum
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from utils.logger_utils import get_service_logger
logger = get_service_logger("scheduler_exception_handler")
class SchedulerErrorType(Enum):
    """Error types for scheduler system.

    The string values are what SchedulerException.to_dict() and the
    formatted error responses in this module emit as the error type.
    """
    DATABASE_ERROR = "database_error"
    TASK_EXECUTION_ERROR = "task_execution_error"
    TASK_LOADER_ERROR = "task_loader_error"
    SCHEDULER_CONFIG_ERROR = "scheduler_config_error"
    RETRY_ERROR = "retry_error"
    CONCURRENCY_ERROR = "concurrency_error"
    TIMEOUT_ERROR = "timeout_error"
class SchedulerErrorSeverity(Enum):
    """Severity levels for scheduler errors, listed from least to most severe."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class SchedulerException(Exception):
    """Base exception for scheduler system errors.

    Stores structured metadata next to the message (error type, severity,
    user/task identifiers, free-form context, creation timestamp) and, when
    an ``original_error`` is supplied, a formatted stack trace string.
    """

    def __init__(
        self,
        message: str,
        error_type: SchedulerErrorType,
        severity: SchedulerErrorSeverity = SchedulerErrorSeverity.MEDIUM,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        task_type: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable error description.
            error_type: Category of the error.
            severity: Severity level (defaults to MEDIUM).
            user_id: Optional ID of the affected user.
            task_id: Optional ID of the affected task.
            task_type: Optional task type identifier.
            context: Optional extra details; an empty dict is used when omitted.
            original_error: Optional underlying exception being wrapped.
        """
        self.message = message
        self.error_type = error_type
        self.severity = severity
        self.user_id = user_id
        self.task_id = task_id
        self.task_type = task_type
        self.context = context or {}
        self.original_error = original_error
        self.timestamp = datetime.utcnow()

        # Capture stack trace if original error provided.
        # stack_trace is always a str (or None), whichever branch produces it.
        self.stack_trace = None
        if self.original_error:
            try:
                # Prefer the exception currently being handled, if any.
                exc_type, exc_value, exc_traceback = sys.exc_info()
                if not exc_traceback:
                    # No active exception: fall back to the traceback carried
                    # by the wrapped error itself.
                    exc_type = type(self.original_error)
                    exc_value = self.original_error
                    exc_traceback = self.original_error.__traceback__
                # format_exception returns a list of strings; join it so the
                # attribute has the same type in both cases.
                self.stack_trace = ''.join(traceback.format_exception(
                    exc_type, exc_value, exc_traceback
                ))
            except Exception:
                self.stack_trace = str(self.original_error)

        super().__init__(message)

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for logging/storage."""
        return {
            "message": self.message,
            "error_type": self.error_type.value,
            "severity": self.severity.value,
            "user_id": self.user_id,
            "task_id": self.task_id,
            "task_type": self.task_type,
            "context": self.context,
            "timestamp": self.timestamp.isoformat() if isinstance(self.timestamp, datetime) else self.timestamp,
            "original_error": str(self.original_error) if self.original_error else None,
            "stack_trace": self.stack_trace
        }

    def __str__(self):
        """Return the message prefixed with the error type value in brackets."""
        return f"[{self.error_type.value}] {self.message}"
class DatabaseError(SchedulerException):
    """Exception raised for database-related errors.

    Always uses error type DATABASE_ERROR and severity CRITICAL.
    """

    def __init__(
        self,
        message: str,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable error description.
            user_id: Optional ID of the affected user.
            task_id: Optional ID of the affected task.
            context: Optional extra details; an empty dict is used when omitted.
            original_error: Optional underlying exception being wrapped.
        """
        super().__init__(
            message=message,
            error_type=SchedulerErrorType.DATABASE_ERROR,
            severity=SchedulerErrorSeverity.CRITICAL,
            user_id=user_id,
            task_id=task_id,
            context=context or {},
            original_error=original_error
        )
class TaskExecutionError(SchedulerException):
    """Exception raised for task execution failures.

    Always uses error type TASK_EXECUTION_ERROR and severity HIGH, and adds
    ``retry_count`` and ``execution_time_ms`` to the stored context.
    """

    def __init__(
        self,
        message: str,
        user_id: Optional[int] = None,
        task_id: Optional[int] = None,
        task_type: Optional[str] = None,
        retry_count: int = 0,
        execution_time_ms: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable error description.
            user_id: Optional ID of the affected user.
            task_id: Optional ID of the affected task.
            task_type: Optional task type identifier.
            retry_count: Number of retries recorded in the context.
            execution_time_ms: Execution time recorded in the context.
            context: Optional extra details; the caller's dict is not modified.
            original_error: Optional underlying exception being wrapped.
        """
        # Build a new dict rather than updating the caller's in place, so the
        # argument passed in is never mutated. The retry fields still override
        # same-named keys from the caller, as before.
        merged_context = {
            **(context or {}),
            "retry_count": retry_count,
            "execution_time_ms": execution_time_ms
        }

        super().__init__(
            message=message,
            error_type=SchedulerErrorType.TASK_EXECUTION_ERROR,
            severity=SchedulerErrorSeverity.HIGH,
            user_id=user_id,
            task_id=task_id,
            task_type=task_type,
            context=merged_context,
            original_error=original_error
        )
class TaskLoaderError(SchedulerException):
    """Exception raised for task loading failures.

    Always uses error type TASK_LOADER_ERROR and severity HIGH.
    """

    def __init__(
        self,
        message: str,
        task_type: Optional[str] = None,
        user_id: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable error description.
            task_type: Optional task type identifier.
            user_id: Optional ID of the affected user.
            context: Optional extra details; an empty dict is used when omitted.
            original_error: Optional underlying exception being wrapped.
        """
        super().__init__(
            message=message,
            error_type=SchedulerErrorType.TASK_LOADER_ERROR,
            severity=SchedulerErrorSeverity.HIGH,
            user_id=user_id,
            task_type=task_type,
            context=context or {},
            original_error=original_error
        )
class SchedulerConfigError(SchedulerException):
    """Exception raised for scheduler configuration errors.

    Always uses error type SCHEDULER_CONFIG_ERROR and severity CRITICAL.
    """

    def __init__(
        self,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """
        Args:
            message: Human-readable error description.
            context: Optional extra details; an empty dict is used when omitted.
            original_error: Optional underlying exception being wrapped.
        """
        super().__init__(
            message=message,
            error_type=SchedulerErrorType.SCHEDULER_CONFIG_ERROR,
            severity=SchedulerErrorSeverity.CRITICAL,
            context=context or {},
            original_error=original_error
        )
class SchedulerExceptionHandler:
    """Comprehensive exception handler for the scheduler system.

    Wraps arbitrary exceptions in SchedulerException, logs them, stores
    HIGH/CRITICAL errors that carry a task_id as failed TaskExecutionLog
    rows (when a database session was supplied), and returns a formatted
    error response dictionary.
    """

    # Words that mark a context key as sensitive. A key is dropped from error
    # responses when any of its '_', '-', '.' or space separated parts
    # (case-insensitive, optional plural 's') is one of these words.
    _SENSITIVE_KEY_WORDS = frozenset(
        {"password", "token", "key", "secret", "credential"}
    )

    def __init__(self, db: Optional[Session] = None):
        """
        Args:
            db: Optional database session used to persist error alerts.
                Without it, errors are only logged.
        """
        self.db = db
        self.logger = logger

    def handle_exception(
        self,
        error: Union[Exception, SchedulerException],
        context: Optional[Dict[str, Any]] = None,
        log_level: str = "error"
    ) -> Dict[str, Any]:
        """Handle and log scheduler exceptions.

        Args:
            error: The exception to handle. Anything that is not already a
                SchedulerException is wrapped in one, with its type and
                severity inferred from the exception.
            context: Optional extra details merged into the logged data.
            log_level: Requested log level ("critical", "error", "warning";
                anything else falls through to info unless severity matches).

        Returns:
            Formatted error response dictionary with ``success`` set to False.
        """
        context = context or {}

        # Convert regular exceptions to SchedulerException
        if not isinstance(error, SchedulerException):
            error = SchedulerException(
                message=str(error),
                error_type=self._classify_error(error),
                severity=self._determine_severity(error),
                context=context,
                original_error=error
            )

        # Log the error
        error_data = error.to_dict()
        error_data.update(context)

        log_message = f"Scheduler Error: {error.message}"

        # The first branch whose requested level OR error severity matches wins.
        if log_level == "critical" or error.severity == SchedulerErrorSeverity.CRITICAL:
            self.logger.critical(log_message, extra={"error_data": error_data})
        elif log_level == "error" or error.severity == SchedulerErrorSeverity.HIGH:
            self.logger.error(log_message, extra={"error_data": error_data})
        elif log_level == "warning" or error.severity == SchedulerErrorSeverity.MEDIUM:
            self.logger.warning(log_message, extra={"error_data": error_data})
        else:
            self.logger.info(log_message, extra={"error_data": error_data})

        # Store critical errors in database for alerting
        if error.severity in [SchedulerErrorSeverity.HIGH, SchedulerErrorSeverity.CRITICAL]:
            self._store_error_alert(error)

        # Return formatted error response
        return self._format_error_response(error)

    def _classify_error(self, error: Exception) -> SchedulerErrorType:
        """Classify an exception into a scheduler error type.

        Checks the exception class first, then keywords in the lower-cased
        message / class name. Checks run in order and the first match wins;
        unmatched errors are treated as task execution errors.
        """
        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        # Database errors (OperationalError and IntegrityError are
        # SQLAlchemyError subclasses, so one isinstance check covers them)
        if isinstance(error, SQLAlchemyError):
            return SchedulerErrorType.DATABASE_ERROR
        if "database" in error_str or "sql" in error_type_name or "connection" in error_str:
            return SchedulerErrorType.DATABASE_ERROR

        # Timeout errors
        if "timeout" in error_str or "timed out" in error_str:
            return SchedulerErrorType.TIMEOUT_ERROR

        # Concurrency errors
        if "concurrent" in error_str or "race" in error_str or "lock" in error_str:
            return SchedulerErrorType.CONCURRENCY_ERROR

        # Task execution errors
        if "task" in error_str and "execut" in error_str:
            return SchedulerErrorType.TASK_EXECUTION_ERROR

        # Task loader errors
        if "load" in error_str and "task" in error_str:
            return SchedulerErrorType.TASK_LOADER_ERROR

        # Retry errors
        if "retry" in error_str:
            return SchedulerErrorType.RETRY_ERROR

        # Config errors ('and' binds tighter than 'or'; parenthesised to make
        # the existing grouping explicit)
        if "config" in error_str or ("scheduler" in error_str and "init" in error_str):
            return SchedulerErrorType.SCHEDULER_CONFIG_ERROR

        # Default to task execution error for unknown errors
        return SchedulerErrorType.TASK_EXECUTION_ERROR

    def _determine_severity(self, error: Exception) -> SchedulerErrorSeverity:
        """Determine the severity of an error from its class and message keywords."""
        error_str = str(error).lower()

        # Critical errors
        if isinstance(error, (SQLAlchemyError, ConnectionError)):
            return SchedulerErrorSeverity.CRITICAL
        if "database" in error_str or "connection" in error_str:
            return SchedulerErrorSeverity.CRITICAL

        # High severity errors
        if "timeout" in error_str or "concurrent" in error_str:
            return SchedulerErrorSeverity.HIGH
        if isinstance(error, (KeyError, AttributeError)) and "config" in error_str:
            return SchedulerErrorSeverity.HIGH

        # Medium severity errors
        if "task" in error_str or "execution" in error_str:
            return SchedulerErrorSeverity.MEDIUM

        # Default to low
        return SchedulerErrorSeverity.LOW

    def _store_error_alert(self, error: SchedulerException):
        """Store critical errors in database for alerting.

        Does nothing without a database session. Only errors that carry a
        task_id are stored (as a 'failed' TaskExecutionLog); storage failures
        are logged and rolled back, never raised.
        """
        if not self.db:
            return

        try:
            # Import here to avoid circular dependencies
            from models.monitoring_models import TaskExecutionLog

            # Store as failed execution log if we have task_id (even without user_id for system errors)
            if error.task_id:
                try:
                    execution_log = TaskExecutionLog(
                        task_id=error.task_id,
                        user_id=error.user_id,  # Can be None for system-level errors
                        execution_date=error.timestamp,
                        status='failed',
                        error_message=error.message,
                        result_data={
                            "error_type": error.error_type.value,
                            "severity": error.severity.value,
                            "context": error.context,
                            "stack_trace": error.stack_trace,
                            "task_type": error.task_type
                        }
                    )
                    self.db.add(execution_log)
                    self.db.commit()
                    self.logger.info(f"Stored error alert in execution log for task {error.task_id}")
                except Exception as e:
                    self.logger.error(f"Failed to store error in execution log: {e}")
                    self.db.rollback()
            # Note: For errors without task_id, we rely on structured logging only
            # Future: Could create a separate scheduler_error_logs table for system-level errors

        except Exception as e:
            self.logger.error(f"Failed to store error alert: {e}")

    @classmethod
    def _is_sensitive_key(cls, key: Any) -> bool:
        """Return True when a context key looks like it holds a secret."""
        normalized = str(key).lower()
        for separator in ("-", ".", " "):
            normalized = normalized.replace(separator, "_")
        return any(
            part in cls._SENSITIVE_KEY_WORDS or part.rstrip("s") in cls._SENSITIVE_KEY_WORDS
            for part in normalized.split("_")
        )

    def _format_error_response(self, error: SchedulerException) -> Dict[str, Any]:
        """Format error for API response or logging.

        Returns:
            Dictionary with ``success`` False and an ``error`` sub-dictionary
            holding type, message, severity, timestamp, identifiers, the
            context with sensitive keys removed, and a user-friendly message.
        """
        response = {
            "success": False,
            "error": {
                "type": error.error_type.value,
                "message": error.message,
                "severity": error.severity.value,
                "timestamp": error.timestamp.isoformat() if isinstance(error.timestamp, datetime) else str(error.timestamp),
                "user_id": error.user_id,
                "task_id": error.task_id,
                "task_type": error.task_type
            }
        }

        # Add context for debugging (non-sensitive info only). Keys are
        # matched case-insensitively on their word parts, so names such as
        # "api_key" or "Access-Token" are removed as well as the bare words.
        if error.context:
            safe_context = {
                k: v for k, v in error.context.items()
                if not self._is_sensitive_key(k)
            }
            response["error"]["context"] = safe_context

        # Add user-friendly message based on error type
        user_messages = {
            SchedulerErrorType.DATABASE_ERROR:
                "A database error occurred while processing the task. Please try again later.",
            SchedulerErrorType.TASK_EXECUTION_ERROR:
                "The task failed to execute. Please check the task configuration and try again.",
            SchedulerErrorType.TASK_LOADER_ERROR:
                "Failed to load tasks. The scheduler may be experiencing issues.",
            SchedulerErrorType.SCHEDULER_CONFIG_ERROR:
                "The scheduler configuration is invalid. Contact support.",
            SchedulerErrorType.RETRY_ERROR:
                "Task retry failed. The task will be rescheduled.",
            SchedulerErrorType.CONCURRENCY_ERROR:
                "A concurrency issue occurred. The task will be retried.",
            SchedulerErrorType.TIMEOUT_ERROR:
                "The task execution timed out. The task will be retried."
        }

        response["error"]["user_message"] = user_messages.get(
            error.error_type,
            "An error occurred while processing the task."
        )

        return response

View File

@@ -0,0 +1,75 @@
"""
Task Executor Interface
Abstract base class for all task executors.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy.orm import Session
@dataclass
class TaskExecutionResult:
    """Outcome of a single task execution.

    Only ``success`` is required; the remaining fields default to "no
    details", retryable, with a 300 second retry delay.
    """
    success: bool
    error_message: Optional[str] = None
    result_data: Optional[Dict[str, Any]] = None
    execution_time_ms: Optional[int] = None
    retryable: bool = True
    retry_delay: int = 300  # seconds

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary (values are not copied)."""
        return dict(
            success=self.success,
            error_message=self.error_message,
            result_data=self.result_data,
            execution_time_ms=self.execution_time_ms,
            retryable=self.retryable,
            retry_delay=self.retry_delay,
        )
class TaskExecutor(ABC):
    """
    Interface every task executor has to implement.

    A concrete subclass must provide both abstract methods before it can be
    instantiated.
    """

    @abstractmethod
    async def execute_task(self, task: Any, db: Session) -> TaskExecutionResult:
        """
        Run one task.

        Args:
            task: Task instance from database
            db: Database session

        Returns:
            TaskExecutionResult describing how the execution went
        """

    @abstractmethod
    def calculate_next_execution(
        self,
        task: Any,
        frequency: str,
        last_execution: Optional[datetime] = None
    ) -> datetime:
        """
        Work out when the task should run next.

        Args:
            task: Task instance
            frequency: Task frequency (e.g., 'Daily', 'Weekly')
            last_execution: Last execution datetime, if any

        Returns:
            Next execution datetime
        """

View File

@@ -0,0 +1,628 @@
"""
Core Task Scheduler Service
Pluggable task scheduler that can work with any task model.
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List, Callable
from datetime import datetime
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy.orm import Session
from .executor_interface import TaskExecutor, TaskExecutionResult
from .task_registry import TaskRegistry
from .exception_handler import (
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
TaskLoaderError, SchedulerConfigError
)
from services.database import get_db_session
from utils.logger_utils import get_service_logger
logger = get_service_logger("task_scheduler")
class TaskScheduler:
"""
Pluggable task scheduler that can work with any task model.
Features:
- Async task execution
- Plugin-based executor system
- Database-backed task persistence
- Configurable check intervals
- Automatic retry logic
"""
def __init__(
    self,
    check_interval_minutes: int = 15,
    max_concurrent_executions: int = 10,
    enable_retries: bool = True,
    max_retries: int = 3
):
    """
    Initialize the task scheduler.

    Args:
        check_interval_minutes: How often to check for due tasks while
            active strategies exist (values below 1 are treated as 1)
        max_concurrent_executions: Maximum concurrent task executions
        enable_retries: Whether to retry failed tasks
        max_retries: Maximum retry attempts
    """
    self.check_interval_minutes = check_interval_minutes
    self.max_concurrent_executions = max_concurrent_executions
    self.enable_retries = enable_retries
    self.max_retries = max_retries

    # Initialize APScheduler
    self.scheduler = AsyncIOScheduler(
        timezone='UTC',
        job_defaults={
            'coalesce': True,
            'max_instances': 1,
            'misfire_grace_time': 300  # 5 minutes grace period
        }
    )

    # Task executor registry
    self.registry = TaskRegistry()

    # Track running executions
    self.active_executions: Dict[str, asyncio.Task] = {}

    # Exception handler for robust error handling
    self.exception_handler = SchedulerExceptionHandler()

    # Intelligent scheduling configuration.
    # The "active" interval follows the constructor argument (default 15)
    # instead of a hard-coded 15, so check_interval_minutes actually takes
    # effect. The idle interval stays at 60 minutes unless the active
    # interval is already longer than that.
    self.min_check_interval_minutes = max(1, check_interval_minutes)  # Used when active strategies exist
    self.max_check_interval_minutes = max(60, self.min_check_interval_minutes)  # Used when no active strategies
    self.current_check_interval_minutes = self.min_check_interval_minutes  # Current interval

    # Statistics
    self.stats = {
        'total_checks': 0,
        'tasks_found': 0,
        'tasks_executed': 0,
        'tasks_failed': 0,
        'tasks_skipped': 0,
        'last_check': None,
        'per_user_stats': {},  # Track metrics per user for user isolation
        'active_strategies_count': 0,  # Track active strategies with tasks
        'last_interval_adjustment': None  # Track when interval was last adjusted
    }

    self._running = False
def _get_trigger_for_interval(self, interval_minutes: int):
    """
    Get the appropriate trigger for the given interval.

    A cron ``*/N`` minute expression restarts at the top of every hour, so
    it only yields evenly spaced runs when N divides 60 (e.g. 15 or 30).
    For those values a CronTrigger is used; every other interval, including
    anything of 60 minutes or more, uses an IntervalTrigger.

    Args:
        interval_minutes: Interval in minutes (must be positive)

    Returns:
        Appropriate APScheduler trigger

    Raises:
        ValueError: If interval_minutes is not positive
    """
    if interval_minutes <= 0:
        raise ValueError(
            f"interval_minutes must be a positive number of minutes, got {interval_minutes}"
        )

    if interval_minutes < 60 and 60 % interval_minutes == 0:
        # Divides the hour evenly: cron steps give uniform spacing
        return CronTrigger(minute=f'*/{interval_minutes}')

    # >= 60 minutes, or a step that would fire unevenly under cron
    return IntervalTrigger(minutes=interval_minutes)
def register_executor(
    self,
    task_type: str,
    executor: TaskExecutor,
    task_loader: Callable[[Session], List[Any]]
):
    """
    Add an executor and its task loader to the registry under a task type.

    Args:
        task_type: Unique identifier for task type (e.g., 'monitoring_task')
        executor: TaskExecutor instance that handles execution
        task_loader: Function that loads due tasks from database
    """
    registry = self.registry
    registry.register(task_type, executor, task_loader)
    logger.info(f"Registered executor for task type: {task_type}")
async def start(self):
    """
    Start the scheduler with intelligent interval adjustment.

    Does nothing (beyond a warning) when already running. Otherwise picks
    the initial check interval, installs the periodic 'check_due_tasks' job
    and starts APScheduler. Any failure is logged and re-raised.
    """
    if self._running:
        logger.warning("Scheduler is already running")
        return

    try:
        # Initial interval depends on whether active strategies exist
        initial_interval = await self._determine_optimal_interval()
        self.current_check_interval_minutes = initial_interval

        # Periodic job that looks for due tasks
        check_trigger = self._get_trigger_for_interval(initial_interval)
        self.scheduler.add_job(
            self._check_and_execute_due_tasks,
            trigger=check_trigger,
            id='check_due_tasks',
            replace_existing=True
        )

        self.scheduler.start()
        self._running = True
        logger.info(
            f"Task scheduler started | "
            f"check_interval={initial_interval}min | "
            f"registered_types={self.registry.get_registered_types()}"
        )
    except Exception as e:
        logger.error(f"Failed to start scheduler: {e}")
        raise
async def stop(self):
    """
    Stop the scheduler gracefully.

    No-op when the scheduler is not running. Cancels every tracked
    execution, waits up to 30 seconds for them to finish, then shuts down
    APScheduler. Any failure is logged and re-raised.
    """
    if not self._running:
        return

    try:
        # Ask every tracked execution to cancel
        tracked = self.active_executions.values()
        for execution_task in tracked:
            execution_task.cancel()

        # Give the cancelled executions a bounded amount of time to wind down
        if self.active_executions:
            await asyncio.wait(tracked, timeout=30)

        # Shutdown scheduler
        self.scheduler.shutdown(wait=True)
        self._running = False
        logger.info("Task scheduler stopped gracefully")
    except Exception as e:
        logger.error(f"Error stopping scheduler: {e}")
        raise
async def _check_and_execute_due_tasks(self):
    """
    One scheduler pass: find due tasks and run them.

    Updates the check statistics, opens a database session, lets the
    interval adjustment run, then processes each registered task type in
    turn. Any exception is wrapped in a DatabaseError and passed to the
    exception handler; the session is always closed.
    """
    stats = self.stats
    stats['total_checks'] += 1
    stats['last_check'] = datetime.utcnow().isoformat()

    logger.debug("Checking for due tasks...")

    session = None
    try:
        session = get_db_session()
        if session is None:
            logger.error("Failed to get database session")
            return

        # Re-evaluate the check interval before doing the actual work
        await self._adjust_check_interval_if_needed(session)

        # Handle every registered task type, one after another
        for task_type in self.registry.get_registered_types():
            await self._process_task_type(task_type, session)
    except Exception as e:
        self.exception_handler.handle_exception(
            DatabaseError(
                message=f"Error checking for due tasks: {str(e)}",
                original_error=e
            )
        )
    finally:
        if session:
            session.close()
async def _determine_optimal_interval(self) -> int:
    """
    Pick the check interval to use, based on active strategies with tasks.

    Returns:
        ``min_check_interval_minutes`` when at least one active strategy has
        tasks, ``max_check_interval_minutes`` when there are none, and
        ``min_check_interval_minutes`` again whenever the count cannot be
        obtained (no session, or an error while counting).
    """
    session = None
    try:
        session = get_db_session()
        if session:
            from services.active_strategy_service import ActiveStrategyService

            strategy_service = ActiveStrategyService(db_session=session)
            active_count = strategy_service.count_active_strategies_with_tasks()
            self.stats['active_strategies_count'] = active_count

            if active_count > 0:
                logger.info(f"Found {active_count} active strategies with tasks - using {self.min_check_interval_minutes}min interval")
                return self.min_check_interval_minutes

            logger.info(f"No active strategies with tasks - using {self.max_check_interval_minutes}min interval")
            return self.max_check_interval_minutes
    except Exception as exc:
        logger.warning(f"Error determining optimal interval: {exc}, using default {self.min_check_interval_minutes}min")
    finally:
        if session:
            session.close()

    # No session or an error: fall back to the shorter interval (safer).
    return self.min_check_interval_minutes
async def _adjust_check_interval_if_needed(self, db: Session):
    """
    Intelligently adjust check interval based on active strategies.

    If there are active strategies with tasks, the minimum interval is used
    (check more frequently); otherwise the maximum interval is used. The
    job is only touched when the interval actually changes.

    Errors are logged as warnings and never propagated, so a failed
    adjustment does not abort the surrounding due-task check.

    Args:
        db: Database session used to count active strategies with tasks
    """
    try:
        from services.active_strategy_service import ActiveStrategyService

        active_strategy_service = ActiveStrategyService(db_session=db)
        active_count = active_strategy_service.count_active_strategies_with_tasks()
        self.stats['active_strategies_count'] = active_count

        # Determine optimal interval
        if active_count > 0:
            optimal_interval = self.min_check_interval_minutes
        else:
            optimal_interval = self.max_check_interval_minutes

        # Only reschedule if interval needs to change
        if optimal_interval == self.current_check_interval_minutes:
            return

        logger.info(
            f"Adjusting scheduler interval: {self.current_check_interval_minutes}min → {optimal_interval}min | "
            f"active_strategies={active_count}"
        )

        # Use reschedule_job rather than modify_job(trigger=...): modify_job
        # only swaps the trigger object and keeps the next run time that was
        # computed from the OLD trigger, so a shortened interval would not
        # take effect until the old (longer) wait had elapsed.
        # reschedule_job also recomputes the next run time from the new trigger.
        # NOTE(review): assumes self.scheduler is an APScheduler 3.x scheduler
        # (consistent with the add_job/modify_job/shutdown calls used here).
        self.scheduler.reschedule_job(
            'check_due_tasks',
            trigger=self._get_trigger_for_interval(optimal_interval)
        )

        self.current_check_interval_minutes = optimal_interval
        self.stats['last_interval_adjustment'] = datetime.utcnow().isoformat()
        logger.info(f"Scheduler interval adjusted to {optimal_interval}min")
    except Exception as e:
        logger.warning(f"Error adjusting check interval: {e}")
async def trigger_interval_adjustment(self):
    """
    Trigger immediate interval adjustment check.

    This should be called when a strategy is activated or deactivated
    to immediately adjust the scheduler interval based on current active strategies.

    Does nothing when the scheduler is not running. Errors are logged as
    warnings and never propagated. The database session opened here is
    always closed before returning.
    """
    if not self._running:
        logger.debug("Scheduler not running, skipping interval adjustment")
        return

    db = None
    try:
        db = get_db_session()
        if db:
            await self._adjust_check_interval_if_needed(db)
        else:
            logger.warning("Could not get database session for interval adjustment")
    except Exception as e:
        logger.warning(f"Error triggering interval adjustment: {e}")
    finally:
        # This method owns the session it opened; release it so repeated
        # activation/deactivation calls do not leak sessions.
        if db:
            try:
                db.close()
            except Exception as close_error:
                logger.warning(f"Error closing database session after interval adjustment: {close_error}")
async def _process_task_type(self, task_type: str, db: Session):
    """Load the due tasks of one task type and launch their executions.

    The loader registered for ``task_type`` is called with ``db``. Each due
    task is started as its own asyncio task and tracked in
    ``active_executions``; launching stops once the number of tracked
    executions reaches ``max_concurrent_executions``. The launched
    executions are then awaited for at most 300 seconds in total.

    Failures are reported to the exception handler as ``TaskLoaderError``
    and never propagated.

    Args:
        task_type: Registered task type identifier.
        db: Database session handed to the task loader.
    """
    try:
        # Resolve the loader registered for this type.
        try:
            load_due = self.registry.get_task_loader(task_type)
        except Exception as exc:
            self.exception_handler.handle_exception(
                TaskLoaderError(
                    message=f"Failed to get task loader for type {task_type}: {str(exc)}",
                    task_type=task_type,
                    original_error=exc
                )
            )
            return

        # Ask the loader for everything that is due right now.
        try:
            pending = load_due(db)
        except Exception as exc:
            self.exception_handler.handle_exception(
                TaskLoaderError(
                    message=f"Failed to load due tasks for type {task_type}: {str(exc)}",
                    task_type=task_type,
                    original_error=exc
                )
            )
            return

        if not pending:
            return

        self.stats['tasks_found'] += len(pending)
        logger.info(f"Found {len(pending)} due tasks for type: {task_type}")

        launched = []
        for due_task in pending:
            # Respect the global concurrency cap; the rest waits for a later check.
            if len(self.active_executions) >= self.max_concurrent_executions:
                logger.warning(
                    f"Max concurrent executions reached ({self.max_concurrent_executions}), "
                    f"skipping {len(pending) - len(launched)} tasks"
                )
                break

            # The execution opens its own database session, so ``db`` is not shared.
            running = asyncio.create_task(
                self._execute_task_async(task_type, due_task)
            )
            execution_key = f"{task_type}_{getattr(due_task, 'id', id(due_task))}"
            self.active_executions[execution_key] = running
            launched.append(running)

        # Bounded wait for the batch that was just launched.
        if launched:
            await asyncio.wait(launched, timeout=300)
    except Exception as exc:
        self.exception_handler.handle_exception(
            TaskLoaderError(
                message=f"Error processing task type {task_type}: {str(exc)}",
                task_type=task_type,
                original_error=exc
            )
        )
async def _execute_task_async(self, task_type: str, task: Any):
    """
    Run one task in its own database session and record the outcome.

    Steps, in order: try to read the owning user's id from the task's
    strategy, open a fresh session, merge the task into that session when
    it is detached, retry the user lookup, resolve the executor for
    ``task_type`` and await ``executor.execute_task``. Global and per-user
    statistics are updated for success and failure. Failures are reported
    to the exception handler; nothing is raised to the caller except
    non-``Exception`` errors such as cancellation.

    The session is always closed and the task is always removed from
    ``active_executions`` when the method finishes.

    Args:
        task_type: Type of task
        task: Task instance loaded by the task loader (may be detached)
    """
    execution_id = f"{task_type}_{getattr(task, 'id', id(task))}"
    db = None
    user_id = None

    def record_failure():
        # Shared bookkeeping for every failure path below; reads the
        # current value of ``user_id`` at call time.
        self.stats['tasks_failed'] += 1
        self._update_user_stats(user_id, success=False)

    try:
        # First attempt at user context, before any session exists.
        try:
            if hasattr(task, 'strategy') and task.strategy:
                user_id = getattr(task.strategy, 'user_id', None)
            elif hasattr(task, 'strategy_id') and task.strategy_id:
                # Resolved later, once a database session is available.
                pass
        except Exception as exc:
            logger.debug(f"Could not extract user_id before execution for task {execution_id}: {exc}")

        logger.info(f"Executing task: {execution_id} | user_id: {user_id}")

        # Every execution gets a private session; sessions are not shared
        # between concurrently running tasks.
        db = get_db_session()
        if db is None:
            session_failure = DatabaseError(
                message=f"Failed to get database session for task {execution_id}",
                user_id=user_id,
                task_id=getattr(task, 'id', None),
                task_type=task_type
            )
            self.exception_handler.handle_exception(session_failure, log_level="error")
            record_failure()
            return

        # Let the exception handler use this execution's session.
        self.exception_handler.db = db

        # A task loaded by another session is detached; attach a copy here.
        from sqlalchemy.orm import object_session
        if object_session(task) is None:
            task = db.merge(task)

        # Second attempt at user context, now that relationships can load.
        if user_id is None and hasattr(task, 'strategy'):
            try:
                if task.strategy:
                    user_id = getattr(task.strategy, 'user_id', None)
                elif hasattr(task, 'strategy_id'):
                    from models.enhanced_strategy_models import EnhancedContentStrategy
                    owning_strategy = db.query(EnhancedContentStrategy).filter(
                        EnhancedContentStrategy.id == task.strategy_id
                    ).first()
                    if owning_strategy:
                        user_id = owning_strategy.user_id
            except Exception as exc:
                logger.debug(f"Could not extract user_id after merge for task {execution_id}: {exc}")

        # Resolve the executor registered for this task type.
        try:
            executor = self.registry.get_executor(task_type)
        except Exception as exc:
            from .exception_handler import SchedulerConfigError
            config_failure = SchedulerConfigError(
                message=f"Failed to get executor for task type {task_type}: {str(exc)}",
                user_id=user_id,
                context={
                    "task_id": getattr(task, 'id', None),
                    "task_type": task_type
                },
                original_error=exc
            )
            self.exception_handler.handle_exception(config_failure)
            record_failure()
            return

        # Run the task and account for the outcome.
        try:
            outcome = await executor.execute_task(task, db)

            if outcome.success:
                self.stats['tasks_executed'] += 1
                self._update_user_stats(user_id, success=True)
                logger.info(f"Task executed successfully: {execution_id} | user_id: {user_id}")
            else:
                record_failure()

                # Report the unsuccessful result as a structured error.
                reported = TaskExecutionError(
                    message=outcome.error_message or "Task execution failed",
                    user_id=user_id,
                    task_id=getattr(task, 'id', None),
                    task_type=task_type,
                    execution_time_ms=outcome.execution_time_ms,
                    context={"result_data": outcome.result_data}
                )
                self.exception_handler.handle_exception(reported, log_level="warning")

                if self.enable_retries and outcome.retryable:
                    await self._schedule_retry(task, outcome.retry_delay)

        except SchedulerException:
            # Handled by the outer SchedulerException clause.
            raise
        except Exception as exc:
            unexpected = TaskExecutionError(
                message=f"Unexpected error during task execution: {str(exc)}",
                user_id=user_id,
                task_id=getattr(task, 'id', None),
                task_type=task_type,
                original_error=exc
            )
            self.exception_handler.handle_exception(unexpected)
            record_failure()

    except SchedulerException as scheduler_exc:
        self.exception_handler.handle_exception(scheduler_exc)
        record_failure()
    except Exception as exc:
        wrapper_failure = TaskExecutionError(
            message=f"Unexpected error in task execution wrapper: {str(exc)}",
            user_id=user_id,
            task_id=getattr(task, 'id', None),
            task_type=task_type,
            original_error=exc
        )
        self.exception_handler.handle_exception(wrapper_failure)
        record_failure()
    finally:
        # Release the private session.
        if db:
            try:
                db.close()
            except Exception as exc:
                logger.error(f"Error closing database session for task {execution_id}: {exc}")

        # Stop tracking this execution.
        self.active_executions.pop(execution_id, None)
def _update_user_stats(self, user_id: Optional[int], success: bool):
    """
    Record one task outcome in the per-user statistics.

    Creates the user's entry on first use, increments ``executed`` or
    ``failed`` and recomputes ``success_rate`` as a percentage.

    Args:
        user_id: User ID; when None nothing is recorded
        success: True for a successful execution, False for a failure
    """
    if user_id is None:
        return

    entry = self.stats['per_user_stats'].setdefault(
        user_id,
        {
            'executed': 0,
            'failed': 0,
            'success_rate': 0.0
        }
    )

    counter = 'executed' if success else 'failed'
    entry[counter] += 1

    # Percentage of successful executions over all recorded outcomes.
    outcomes = entry['executed'] + entry['failed']
    if outcomes > 0:
        entry['success_rate'] = (entry['executed'] / outcomes) * 100.0
async def _schedule_retry(self, task: Any, delay_seconds: int):
    """Placeholder retry hook: currently only logs the requested delay.

    Args:
        task: The task that failed (not used yet).
        delay_seconds: Delay, in seconds, after which a retry was requested.
    """
    # No rescheduling happens here yet; a future version could update the
    # task's next_execution instead of just logging.
    retry_notice = f"Scheduling retry for task in {delay_seconds}s"
    logger.debug(retry_notice)
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
    """
    Build a snapshot of scheduler statistics.

    The snapshot contains every global counter from ``self.stats`` (except
    the per-user table) plus the live scheduler state.

    Args:
        user_id: When given, the result carries that user's counters under
            ``'user_stats'`` (zeros if the user has none). When omitted, the
            whole per-user table is included under ``'per_user_stats'``.

    Returns:
        Dictionary with scheduler statistics
    """
    snapshot = {
        name: value
        for name, value in self.stats.items()
        if name != 'per_user_stats'
    }
    snapshot.update({
        'active_executions': len(self.active_executions),
        'registered_types': self.registry.get_registered_types(),
        'running': self._running,
        'check_interval_minutes': self.current_check_interval_minutes,
        'min_check_interval_minutes': self.min_check_interval_minutes,
        'max_check_interval_minutes': self.max_check_interval_minutes,
        'intelligent_scheduling': True
    })

    per_user = self.stats['per_user_stats']

    # Unfiltered view (admin/debugging): expose the whole per-user table.
    if user_id is None:
        snapshot['per_user_stats'] = per_user
        return snapshot

    # Filtered view: a single user's counters, zeroed when unknown.
    if user_id in per_user:
        snapshot['user_stats'] = per_user[user_id]
    else:
        snapshot['user_stats'] = {
            'executed': 0,
            'failed': 0,
            'success_rate': 0.0
        }
    return snapshot
def is_running(self) -> bool:
    """Return the scheduler's running flag (``self._running``)."""
    running_flag = self._running
    return running_flag

View File

@@ -0,0 +1,59 @@
"""
Task Registry
Manages registration of task executors and loaders.
"""
import logging
from typing import Dict, Callable, List, Any
from sqlalchemy.orm import Session
from .executor_interface import TaskExecutor
logger = logging.getLogger(__name__)
class TaskRegistry:
    """Keeps, per task type, the executor and the loader of due tasks."""

    def __init__(self):
        # Both mappings are keyed by the task type identifier.
        self.executors: Dict[str, TaskExecutor] = {}
        self.task_loaders: Dict[str, Callable[[Session], List[Any]]] = {}

    def register(
        self,
        task_type: str,
        executor: TaskExecutor,
        task_loader: Callable[[Session], List[Any]]
    ):
        """
        Store the executor and loader for a task type.

        Registering an already-known task type replaces both entries and
        logs a warning.

        Args:
            task_type: Unique identifier for task type
            executor: TaskExecutor instance
            task_loader: Callable that receives a session and returns due tasks
        """
        already_registered = task_type in self.executors
        if already_registered:
            logger.warning(f"Overwriting existing executor for task type: {task_type}")

        self.executors[task_type] = executor
        self.task_loaders[task_type] = task_loader
        logger.info(f"Registered task type: {task_type}")

    def get_executor(self, task_type: str) -> TaskExecutor:
        """Return the executor for ``task_type``; raise ValueError if unknown."""
        if task_type in self.executors:
            return self.executors[task_type]
        raise ValueError(f"No executor registered for task type: {task_type}")

    def get_task_loader(self, task_type: str) -> Callable[[Session], List[Any]]:
        """Return the loader for ``task_type``; raise ValueError if unknown."""
        if task_type in self.task_loaders:
            return self.task_loaders[task_type]
        raise ValueError(f"No task loader registered for task type: {task_type}")

    def get_registered_types(self) -> List[str]:
        """Return the registered task type identifiers as a new list."""
        return list(self.executors)

View File

@@ -0,0 +1,4 @@
"""
Task executor implementations.
"""

View File

@@ -0,0 +1,266 @@
"""
Monitoring Task Executor
Handles execution of content strategy monitoring tasks.
"""
import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
from ..core.executor_interface import TaskExecutor, TaskExecutionResult
from ..core.exception_handler import TaskExecutionError, DatabaseError, SchedulerExceptionHandler
from ..utils.frequency_calculator import calculate_next_execution
from models.monitoring_models import MonitoringTask, TaskExecutionLog
from models.enhanced_strategy_models import EnhancedContentStrategy
from utils.logger_utils import get_service_logger
logger = get_service_logger("monitoring_task_executor")
class MonitoringTaskExecutor(TaskExecutor):
    """
    Executor for content strategy monitoring tasks.

    Handles:
    - ALwrity tasks (automated execution)
    - Human tasks (notifications/queuing)
    """

    def __init__(self):
        self.logger = logger
        self.exception_handler = SchedulerExceptionHandler()

    async def execute_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute a monitoring task with user isolation.

        An execution log row is written with status 'running', the task is
        dispatched by assignee, and the log and the task (last_executed,
        next_execution, status) are updated and committed. On any exception
        the session is rolled back, a single 'failed' log row is written and
        a retryable failure result is returned.

        Args:
            task: MonitoringTask instance (with strategy relationship loaded)
            db: Database session

        Returns:
            TaskExecutionResult
        """
        start_time = time.time()

        # Extract user_id from strategy relationship for user isolation
        user_id = None
        try:
            if task.strategy and hasattr(task.strategy, 'user_id'):
                user_id = task.strategy.user_id
            elif task.strategy_id:
                # Fallback: query strategy if relationship not loaded
                strategy = db.query(EnhancedContentStrategy).filter(
                    EnhancedContentStrategy.id == task.strategy_id
                ).first()
                if strategy:
                    user_id = strategy.user_id
        except Exception as e:
            self.logger.warning(f"Could not extract user_id for task {task.id}: {e}")

        try:
            self.logger.info(
                f"Executing monitoring task: {task.id} | "
                f"user_id: {user_id} | "
                f"assignee: {task.assignee} | "
                f"frequency: {task.frequency}"
            )

            # Create execution log with user_id for user isolation tracking
            execution_log = TaskExecutionLog(
                task_id=task.id,
                user_id=user_id,
                execution_date=datetime.utcnow(),
                status='running'
            )
            db.add(execution_log)
            db.flush()

            # Execute based on assignee
            if task.assignee == 'ALwrity':
                result = await self._execute_alwrity_task(task, db)
            else:
                result = await self._execute_human_task(task, db)

            # Update execution log
            execution_time_ms = int((time.time() - start_time) * 1000)
            execution_log.status = 'success' if result.success else 'failed'
            execution_log.result_data = result.result_data
            execution_log.error_message = result.error_message
            execution_log.execution_time_ms = execution_time_ms

            # Update task
            task.last_executed = datetime.utcnow()
            task.next_execution = self.calculate_next_execution(
                task,
                task.frequency,
                task.last_executed
            )

            # NOTE(review): a next_execution is computed above, but the status is
            # moved away from 'active' here. If due-task loading only selects
            # 'active' tasks, a recurring task would run once only; confirm intent.
            if result.success:
                task.status = 'completed'
            else:
                task.status = 'failed'

            db.commit()

            return result

        except Exception as e:
            execution_time_ms = int((time.time() - start_time) * 1000)

            # Snapshot what the error report needs BEFORE rolling back:
            # rollback expires loaded attributes, and re-reading them would
            # require a working database round trip.
            failed_task_id = task.id
            error_context = {
                "assignee": task.assignee,
                "frequency": task.frequency,
                "component": task.component_name
            }

            # Reset the session first. After a failed flush/commit the session
            # refuses further work until rollback() is called, so without this
            # the failure log below could never be saved. Rolling back also
            # discards the half-written 'running' log row, so exactly one
            # 'failed' row is recorded instead of a stale 'running' row plus
            # a second 'failed' row.
            try:
                db.rollback()
            except Exception as rollback_error:
                self.logger.warning(
                    f"Could not roll back session for task {failed_task_id}: {rollback_error}"
                )

            # Set database session for exception handler
            self.exception_handler.db = db

            # Create structured error
            error = TaskExecutionError(
                message=f"Error executing monitoring task {failed_task_id}: {str(e)}",
                user_id=user_id,
                task_id=failed_task_id,
                task_type="monitoring_task",
                execution_time_ms=execution_time_ms,
                context=error_context,
                original_error=e
            )

            # Handle exception with structured logging
            self.exception_handler.handle_exception(error)

            # Record the failure (include user_id for isolation)
            try:
                execution_log = TaskExecutionLog(
                    task_id=failed_task_id,
                    user_id=user_id,
                    execution_date=datetime.utcnow(),
                    status='failed',
                    error_message=str(e),
                    execution_time_ms=execution_time_ms,
                    result_data={
                        "error_type": error.error_type.value,
                        "severity": error.severity.value,
                        "context": error.context
                    }
                )
                db.add(execution_log)

                task.status = 'failed'
                task.last_executed = datetime.utcnow()

                db.commit()
            except Exception as commit_error:
                db_error = DatabaseError(
                    message=f"Error saving execution log: {str(commit_error)}",
                    user_id=user_id,
                    task_id=failed_task_id,
                    original_error=commit_error
                )
                self.exception_handler.handle_exception(db_error)
                db.rollback()

            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                execution_time_ms=execution_time_ms,
                retryable=True,
                retry_delay=300
            )

    async def _execute_alwrity_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute an ALwrity (automated) monitoring task.

        Currently a placeholder: no measurement is performed and a fixed
        'measured' result with ``metric_value`` 0 is returned. Exceptions are
        converted into a retryable failure result.
        """
        try:
            self.logger.info(f"Executing ALwrity task: {task.task_title}")

            # TODO: Implement actual monitoring logic based on:
            # - task.metric
            # - task.measurement_method
            # - task.success_criteria
            # - task.alert_threshold

            # Placeholder: Simulate task execution
            result_data = {
                'metric_value': 0,
                'status': 'measured',
                'message': f"Task {task.task_title} executed successfully",
                'timestamp': datetime.utcnow().isoformat()
            }

            return TaskExecutionResult(
                success=True,
                result_data=result_data
            )

        except Exception as e:
            self.logger.error(f"Error in ALwrity task execution: {e}")
            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                retryable=True
            )

    async def _execute_human_task(self, task: MonitoringTask, db: Session) -> TaskExecutionResult:
        """
        Execute a Human monitoring task (notification/queuing).

        Human tasks are not executed directly. Currently a placeholder: no
        notification is sent and a fixed 'queued' result is returned.
        Exceptions are converted into a retryable failure result.
        """
        try:
            self.logger.info(f"Queuing human task: {task.task_title}")

            # TODO: Implement notification/queuing system:
            # - Send email notification
            # - Add to user's task queue
            # - Create in-app notification

            result_data = {
                'status': 'queued',
                'message': f"Task {task.task_title} queued for human review",
                'timestamp': datetime.utcnow().isoformat()
            }

            return TaskExecutionResult(
                success=True,
                result_data=result_data
            )

        except Exception as e:
            self.logger.error(f"Error queuing human task: {e}")
            return TaskExecutionResult(
                success=False,
                error_message=str(e),
                retryable=True
            )

    def calculate_next_execution(
        self,
        task: MonitoringTask,
        frequency: str,
        last_execution: Optional[datetime] = None
    ) -> datetime:
        """
        Calculate next execution time based on frequency.

        Delegates to the module-level ``calculate_next_execution`` helper;
        ``task`` itself is not consulted.

        Args:
            task: MonitoringTask instance
            frequency: Frequency string (Daily, Weekly, Monthly, Quarterly)
            last_execution: Last execution datetime (defaults to now)

        Returns:
            Next execution datetime
        """
        return calculate_next_execution(
            frequency=frequency,
            base_time=last_execution or datetime.utcnow()
        )

View File

@@ -0,0 +1,4 @@
"""
Scheduler utilities.
"""

View File

@@ -0,0 +1,33 @@
"""
Frequency Calculator Utility
Calculates next execution time based on frequency string.
"""
from datetime import datetime, timedelta
from typing import Optional
def calculate_next_execution(frequency: str, base_time: Optional[datetime] = None) -> datetime:
    """
    Return the next execution time for a frequency label.

    Args:
        frequency: One of 'Daily', 'Weekly', 'Monthly' or 'Quarterly'.
            'Monthly' and 'Quarterly' are fixed 30- and 90-day steps; any
            other value is treated like 'Daily'.
        base_time: Moment to count from (current UTC time when None)

    Returns:
        ``base_time`` advanced by the step for ``frequency``
    """
    days_per_frequency = {
        'Daily': 1,
        'Weekly': 7,
        'Monthly': 30,
        'Quarterly': 90
    }

    start = datetime.utcnow() if base_time is None else base_time
    # Unknown labels fall back to a one-day step.
    step_days = days_per_frequency.get(frequency, 1)
    return start + timedelta(days=step_days)

View File

@@ -0,0 +1,60 @@
"""
Task Loader Utilities
Functions to load due tasks from database.
"""
from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_
from models.monitoring_models import MonitoringTask
from models.enhanced_strategy_models import EnhancedContentStrategy
def load_due_monitoring_tasks(
    db: Session,
    user_id: Optional[int] = None
) -> List[MonitoringTask]:
    """
    Return the monitoring tasks that should run now.

    A task is due when its status is 'active' and its ``next_execution`` is
    either unset or not later than the current UTC time. The query joins
    the owning strategy (needed for the optional user filter) and eagerly
    loads the ``strategy`` relationship on each returned task.

    Args:
        db: Database session
        user_id: When given, only tasks whose strategy belongs to this user

    Returns:
        List of due MonitoringTask instances
    """
    now = datetime.utcnow()

    # Never executed (NULL) or already past its scheduled time.
    is_due = or_(
        MonitoringTask.next_execution <= now,
        MonitoringTask.next_execution.is_(None)
    )
    is_active_and_due = and_(
        MonitoringTask.status == 'active',
        is_due
    )

    due_query = (
        db.query(MonitoringTask)
        .join(
            EnhancedContentStrategy,
            MonitoringTask.strategy_id == EnhancedContentStrategy.id
        )
        .options(joinedload(MonitoringTask.strategy))
        .filter(is_active_and_due)
    )

    # Optional restriction to a single user's strategies.
    if user_id is not None:
        due_query = due_query.filter(EnhancedContentStrategy.user_id == user_id)

    return due_query.all()

View File

@@ -0,0 +1,189 @@
"""
Pre-flight Validation Utility for Multi-Operation Workflows
Provides transparent validation for operations that involve multiple API calls.
Services can use this to validate entire workflows before making any external API calls.
"""
from typing import Dict, Any, List, Optional, Tuple
from fastapi import HTTPException
from loguru import logger
from services.subscription.pricing_service import PricingService
from models.subscription_models import APIProvider
def validate_research_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate all operations for a research workflow before making ANY API calls.

    This prevents wasteful external API calls (like Google Grounding) if subsequent
    LLM calls would fail due to token or call limits.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        gpt_provider: GPT provider from env var (defaults to "google")

    Returns:
        None. Returning normally means every operation passed validation.

    Raises:
        HTTPException: 429 when the limit check rejects the workflow,
            500 when the validation itself fails unexpectedly.
    """
    try:
        # Determine actual provider for LLM calls based on GPT_PROVIDER env var
        gpt_provider_lower = gpt_provider.lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Estimate tokens for each operation in research workflow
        # Google Grounding call: ~2000 tokens (input + output)
        # Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
        # Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
        # Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
        operations_to_validate = [
            {
                'provider': APIProvider.GEMINI,  # Google Grounding uses Gemini
                'tokens_requested': 2000,
                'actual_provider_name': 'gemini',
                'operation_type': 'google_grounding'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'keyword_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'competitor_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'content_angle_generation'
            }
        ]

        logger.info("[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
        logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            # Normalise to a dict: error_details may be empty, and its
            # 'usage_info' entry may be missing or None. Without this a None
            # value would crash the lookups below and turn the intended 429
            # into a 500.
            usage_info = (error_details or {}).get('usage_info') or {}
            provider = usage_info.get('provider', llm_provider_name)
            operation_type = usage_info.get('operation_type', 'unknown')

            logger.error("[Pre-flight Validator] ❌ RESEARCH WORKFLOW BLOCKED")
            logger.error(f" ├─ User: {user_id}")
            logger.error(f" ├─ Blocked at: {operation_type}")
            logger.error(f" ├─ Provider: {provider}")
            logger.error(f" └─ Reason: {message}")

            # Raise HTTPException immediately - frontend gets immediate response, no API calls made
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info("[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
        # Validation passed - no return needed (function raises HTTPException if validation fails)
    except HTTPException:
        raise
    except Exception as e:
        # loguru has no ``exc_info`` option: passing it as a keyword makes
        # loguru run str.format() on the message, which can itself raise when
        # the exception text contains braces, and no traceback is recorded.
        # ``opt(exception=...)`` attaches the traceback, and the "{}"
        # placeholder inserts the text without re-parsing it.
        logger.opt(exception=e).error(
            "[Pre-flight Validator] Error validating research operations: {}", e
        )
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate operations: {str(e)}",
                'message': f"Failed to validate operations: {str(e)}"
            }
        ) from e
def validate_image_generation_operations(
    pricing_service: PricingService,
    user_id: str
) -> None:
    """
    Validate image generation operation before making API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking

    Returns:
        None. Returning normally means the operation passed validation.

    Raises:
        HTTPException: 429 when the limit check rejects the operation,
            500 when the validation itself fails unexpectedly.
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.STABILITY,
                'tokens_requested': 0,
                'actual_provider_name': 'stability',
                'operation_type': 'image_generation'
            }
        ]

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")

            # Normalise to a dict: 'usage_info' may be missing or None.
            usage_info = (error_details or {}).get('usage_info') or {}
            provider = usage_info.get('provider', 'stability')

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
        # Validation passed - no return needed (function raises HTTPException if validation fails)
    except HTTPException:
        raise
    except Exception as e:
        # loguru has no ``exc_info`` option: passing it as a keyword makes
        # loguru run str.format() on the message, which can itself raise when
        # the exception text contains braces, and no traceback is recorded.
        # ``opt(exception=...)`` attaches the traceback, and the "{}"
        # placeholder inserts the text without re-parsing it.
        logger.opt(exception=e).error(
            "[Pre-flight Validator] Error validating image generation: {}", e
        )
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image generation: {str(e)}",
                'message': f"Failed to validate image generation: {str(e)}"
            }
        ) from e

View File

@@ -3,10 +3,11 @@ Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits.
"""
from typing import Dict, Any, Optional, List, Tuple
from typing import Dict, Any, Optional, List, Tuple, Union
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger
from models.subscription_models import (
@@ -17,13 +18,17 @@ from models.subscription_models import (
class PricingService:
"""Service for managing API pricing and cost calculations."""
# Class-level cache shared across all instances (critical for cache invalidation on subscription renewal)
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
_limits_cache: Dict[str, Dict[str, Any]] = {}
def __init__(self, db: Session):
    """Create a pricing service bound to a database session.

    Args:
        db: SQLAlchemy session used for pricing and subscription queries.
    """
    self.db = db
    self._pricing_cache = {}
    self._plans_cache = {}
    # The limits cache is deliberately NOT assigned here. It is the
    # class-level ``_limits_cache`` shared by all instances; assigning a
    # fresh dict on the instance would shadow it, so ``clear_user_cache``
    # (which operates on the class attribute) would no longer clear the
    # cache this instance reads and writes.
    # Cache for schema feature detection (ai_text_generation_calls_limit column)
    self._ai_text_gen_col_checked: bool = False
    self._ai_text_gen_col_available: bool = False
# ------------------- Billing period helpers -------------------
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
@@ -68,6 +73,15 @@ class PricingService:
self._ensure_subscription_current(subscription)
# Continue to use YYYY-MM for summaries
return datetime.now().strftime("%Y-%m")
@classmethod
def clear_user_cache(cls, user_id: str) -> int:
    """Drop every cached limit check belonging to one user.

    Cache keys have the form ``"{user_id}:{provider}"``; every key starting
    with ``"{user_id}:"`` is removed from the class-level cache.

    Returns:
        Number of cache entries removed.
    """
    prefix = f"{user_id}:"
    # Collect first, then delete: the dict must not change size while iterated.
    stale_keys = [cache_key for cache_key in cls._limits_cache if cache_key.startswith(prefix)]
    for cache_key in stale_keys:
        cls._limits_cache.pop(cache_key)

    removed = len(stale_keys)
    logger.info(f"Cleared {removed} cache entries for user {user_id}")
    return removed
def initialize_default_pricing(self):
"""Initialize default pricing for all API providers."""
@@ -292,7 +306,8 @@ class PricingService:
"tier": SubscriptionTier.BASIC,
"price_monthly": 29.0,
"price_yearly": 290.0,
"gemini_calls_limit": 1000,
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
"openai_calls_limit": 500,
"anthropic_calls_limit": 200,
"mistral_calls_limit": 500,
@@ -300,11 +315,11 @@ class PricingService:
"serper_calls_limit": 200,
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 50,
"gemini_tokens_limit": 1000000,
"openai_tokens_limit": 500000,
"anthropic_tokens_limit": 200000,
"mistral_tokens_limit": 500000,
"stability_calls_limit": 5,
"gemini_tokens_limit": 2000,
"openai_tokens_limit": 2000,
"anthropic_tokens_limit": 2000,
"mistral_tokens_limit": 2000,
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
@@ -426,21 +441,60 @@ class PricingService:
self._ensure_subscription_current(subscription)
return self._plan_to_limits_dict(subscription.plan)
def _ensure_ai_text_gen_column_detection(self) -> None:
    """Detect once whether ``subscription_plans.ai_text_generation_calls_limit`` exists.

    Issues a zero-row probe ``SELECT`` through ``self.db`` and records the outcome
    in ``self._ai_text_gen_col_available``. ``self._ai_text_gen_col_checked`` is
    set afterwards (whether the probe succeeded or failed), so subsequent calls
    return immediately without touching the database.
    """
    if self._ai_text_gen_col_checked:
        return

    try:
        # LIMIT 0 fetches no rows; the statement only succeeds if the column resolves.
        self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
        self._ai_text_gen_col_available = True
    except Exception as probe_error:
        # Any failure is treated as "column not available", but the reason is
        # logged so a transient DB error is not silently mistaken for a missing column.
        logger.debug(f"ai_text_generation_calls_limit column probe failed: {probe_error}")
        self._ai_text_gen_col_available = False
        # NOTE(review): on backends that abort the surrounding transaction after a
        # failed statement, self.db may need a rollback or savepoint here before it
        # can be reused - confirm against the deployed database.
    finally:
        # The result is cached for the lifetime of this instance either way.
        self._ai_text_gen_col_checked = True
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
"""Convert subscription plan to limits dictionary."""
# Detect if unified AI text generation limit column exists
self._ensure_ai_text_gen_column_detection()
# Use unified AI text generation limit if column exists and is set
ai_text_gen_limit = None
if self._ai_text_gen_col_available:
try:
ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
# If 0, treat as not set (unlimited for Enterprise or use fallback)
if ai_text_gen_limit == 0:
ai_text_gen_limit = None
except (AttributeError, Exception):
# Column exists but access failed - use fallback
ai_text_gen_limit = None
return {
'plan_name': plan.name,
'tier': plan.tier.value,
'limits': {
# Unified AI text generation limit (applies to all LLM providers)
# If not set, fall back to first non-zero legacy limit for backwards compatibility
'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
plan.openai_calls_limit if plan.openai_calls_limit > 0 else
plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
),
# Legacy per-provider limits (for backwards compatibility and analytics)
'gemini_calls': plan.gemini_calls_limit,
'openai_calls': plan.openai_calls_limit,
'anthropic_calls': plan.anthropic_calls_limit,
'mistral_calls': plan.mistral_calls_limit,
# Other API limits
'tavily_calls': plan.tavily_calls_limit,
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit,
# Token limits
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,
'anthropic_tokens': plan.anthropic_tokens_limit,
@@ -451,101 +505,293 @@ class PricingService:
}
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits."""
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
# Get user limits
limits = self.get_user_limits(user_id)
if not limits:
return False, "No subscription plan found", {}
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.
# Get current usage for this billing period
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
# Check call limits
provider_name = provider.value
current_calls = getattr(usage, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0 and current_calls >= call_limit:
result = (False, f"API call limit reached for {provider_name}", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check token limits for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
Args:
user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
"""
try:
# Use actual_provider_name if provided, otherwise use enum value
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
display_provider_name = actual_provider_name or provider.value
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {provider_name}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
try:
limits = self.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
try:
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
# Check cost limits
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, "Monthly cost limit reached", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Calculate usage percentages for warnings
call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_calls,
'call_limit': call_limit,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
# STRICT: Fail closed - deny requests if subscription system fails
return False, f"Subscription check error: {str(e)}", {}
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider."""
@@ -581,6 +827,236 @@ class PricingService:
if not pricing:
return None
def check_comprehensive_limits(
    self,
    user_id: str,
    operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
    """
    Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.

    The whole batch of planned operations is validated against the user's current
    usage so that no external API call is made if a later operation in the same
    batch would hit a limit. Usage is projected cumulatively: each validated
    operation is counted against the remaining budget of the operations after it
    (LLM calls, per-provider LLM tokens, image calls and per-provider calls for
    all other providers).

    Args:
        user_id: User ID
        operations: List of operations to validate, each with:
            - 'provider': APIProvider enum
            - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
            - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
            - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")

    Returns:
        (can_proceed, error_message, error_details)
        On success this is ``(True, None, None)``. If can_proceed is False,
        error_message explains which limit would be exceeded and error_details
        carries 'error_type' and 'usage_info' (or is an empty dict when the
        check itself failed).
    """
    try:
        logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
        logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")

        # Get current usage and limits once
        current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
        usage = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == current_period
        ).first()

        if not usage:
            # First usage this period, create summary
            try:
                usage = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                self.db.add(usage)
                self.db.commit()
            except Exception as create_error:
                logger.error(f"Error creating usage summary: {create_error}")
                self.db.rollback()
                return False, f"Failed to create usage summary: {str(create_error)}", {}

        # Get user limits
        limits_dict = self.get_user_limits(user_id)
        if not limits_dict:
            # No subscription found - check for free tier
            free_plan = self.db.query(SubscriptionPlan).filter(
                SubscriptionPlan.tier == SubscriptionTier.FREE,
                SubscriptionPlan.is_active == True
            ).first()
            if free_plan:
                limits_dict = self._plan_to_limits_dict(free_plan)
            else:
                return False, "No subscription plan found. Please subscribe to a plan.", {}

        limits = limits_dict.get('limits', {})

        # Track cumulative usage across all operations
        total_llm_calls = (
            (usage.gemini_calls or 0) +
            (usage.openai_calls or 0) +
            (usage.anthropic_calls or 0) +
            (usage.mistral_calls or 0)
        )
        total_llm_tokens = {}
        total_images = usage.stability_calls or 0
        # Projected call counts for every other provider, keyed by "<provider>_calls".
        # Seeded lazily from the usage row so repeated operations for the same
        # provider are counted against each other instead of against the DB value.
        projected_provider_calls = {}

        # Providers whose calls count against the unified AI text generation limit
        llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']

        # Log current usage summary
        logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
        logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
        logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
        logger.info(f" └─ Image Calls: {total_images}")

        # Validate each operation
        for op_idx, operation in enumerate(operations):
            provider = operation.get('provider')
            provider_name = provider.value if hasattr(provider, 'value') else str(provider)
            tokens_requested = operation.get('tokens_requested', 0)
            actual_provider_name = operation.get('actual_provider_name')
            operation_type = operation.get('operation_type', 'unknown')
            display_provider_name = actual_provider_name or provider_name

            logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
            logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
            logger.info(f" └─ Estimated Tokens: {tokens_requested}")

            # Check if this is an LLM provider
            is_llm_provider = provider_name in llm_providers

            # Check unified AI text generation limit for LLM providers
            if is_llm_provider:
                ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
                if ai_text_gen_limit == 0:
                    # Fallback to provider-specific limit
                    ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0

                # Count this operation as an LLM call
                projected_total_llm_calls = total_llm_calls + 1
                if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
                    error_info = {
                        'current_calls': total_llm_calls,
                        'limit': ai_text_gen_limit,
                        'provider': display_provider_name,
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
                        'error_type': 'call_limit',
                        'usage_info': error_info
                    }

                # Check token limits for this provider
                # Use cumulative projected tokens from previous operations, or current from DB if first operation
                provider_tokens_key = f"{provider_name}_tokens"
                if provider_tokens_key in total_llm_tokens:
                    # Use cumulative projected tokens from previous operations
                    current_provider_tokens = total_llm_tokens[provider_tokens_key]
                    logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
                else:
                    # First operation for this provider - get current from database
                    current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
                    total_llm_tokens[provider_tokens_key] = current_provider_tokens
                    logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")

                token_limit = limits.get(provider_tokens_key, 0) or 0
                if token_limit > 0 and tokens_requested > 0:
                    projected_tokens = current_provider_tokens + tokens_requested
                    logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
                    if projected_tokens > token_limit:
                        usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
                        error_info = {
                            'current_tokens': current_provider_tokens,
                            'requested_tokens': tokens_requested,
                            'limit': token_limit,
                            'provider': display_provider_name,
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        error_msg = (
                            f"Token limit exceeded for {display_provider_name} "
                            f"({operation_type}). "
                            f"Current: {current_provider_tokens}/{token_limit}, "
                            f"Requested: {tokens_requested}, "
                            f"Would exceed by: {projected_tokens - token_limit} tokens "
                            f"({usage_percentage:.1f}% of limit)"
                        )
                        logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
                        return False, error_msg, {
                            'error_type': 'token_limit',
                            'usage_info': error_info
                        }
                    else:
                        logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")

                # Update cumulative counts for next operation
                total_llm_calls = projected_total_llm_calls
                total_llm_tokens[provider_tokens_key] += tokens_requested
                logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")

            # Check image generation limits
            elif provider == APIProvider.STABILITY:
                image_limit = limits.get('stability_calls', 0) or 0
                projected_images = total_images + 1
                if image_limit > 0 and projected_images > image_limit:
                    error_info = {
                        'current_images': total_images,
                        'limit': image_limit,
                        'provider': 'stability',
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
                        'error_type': 'image_limit',
                        'usage_info': error_info
                    }
                total_images = projected_images

            # Check other provider-specific limits
            else:
                provider_calls_key = f"{provider_name}_calls"
                if provider_calls_key not in projected_provider_calls:
                    # First operation for this provider - start from the stored usage
                    projected_provider_calls[provider_calls_key] = getattr(usage, provider_calls_key, 0) or 0
                current_provider_calls = projected_provider_calls[provider_calls_key]
                call_limit = limits.get(provider_calls_key, 0) or 0
                projected_calls = current_provider_calls + 1

                # A limit of 0 is not enforced
                if call_limit > 0 and projected_calls > call_limit:
                    error_info = {
                        'current_calls': current_provider_calls,
                        'limit': call_limit,
                        'provider': display_provider_name,
                        'operation_type': operation_type,
                        'operation_index': op_idx
                    }
                    return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
                        'error_type': 'call_limit',
                        'usage_info': error_info
                    }

                # Count this operation against later operations for the same provider
                projected_provider_calls[provider_calls_key] = projected_calls

        # All checks passed
        logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
        logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
        return True, None, None

    except Exception as e:
        # Fail closed: any unexpected error during validation denies the batch
        logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
        return False, f"Failed to validate limits: {str(e)}", {}
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
"""Get pricing configuration for a specific provider and model."""
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name
).first()
if not pricing:
return None
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,

View File

@@ -502,7 +502,7 @@ class UsageTrackingService:
return result
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
"""Reset usage status for the current billing period (after plan change)."""
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
try:
billing_period = datetime.now().strftime("%Y-%m")
summary = self.db.query(UsageSummary).filter(
@@ -514,11 +514,52 @@ class UsageTrackingService:
# Nothing to reset
return {"reset": False, "reason": "no_summary"}
# Clear LIMIT_REACHED so the user can resume; keep counters intact
# CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
# Clear LIMIT_REACHED status
summary.usage_status = UsageStatus.ACTIVE
# Reset all LLM provider call counters
summary.gemini_calls = 0
summary.openai_calls = 0
summary.anthropic_calls = 0
summary.mistral_calls = 0
# Reset all LLM provider token counters
summary.gemini_tokens = 0
summary.openai_tokens = 0
summary.anthropic_tokens = 0
summary.mistral_tokens = 0
# Reset search/research provider counters
summary.tavily_calls = 0
summary.serper_calls = 0
summary.metaphor_calls = 0
summary.firecrawl_calls = 0
# Reset image generation counters
summary.stability_calls = 0
# Reset cost counters
summary.gemini_cost = 0.0
summary.openai_cost = 0.0
summary.anthropic_cost = 0.0
summary.mistral_cost = 0.0
summary.tavily_cost = 0.0
summary.serper_cost = 0.0
summary.metaphor_cost = 0.0
summary.firecrawl_cost = 0.0
summary.stability_cost = 0.0
# Reset totals
summary.total_calls = 0
summary.total_tokens = 0
summary.total_cost = 0.0
summary.updated_at = datetime.utcnow()
self.db.commit()
return {"reset": True}
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
return {"reset": True, "counters_reset": True}
except Exception as e:
self.db.rollback()
logger.error(f"Error resetting usage status: {e}")