fix: WYSIWYG editor, content generation, and writing assistant bug fixes
- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField - Fix blog title not truncating: add min-w-0 for flex item overflow - Fix outline generation 500: escape curly braces in f-string prompt template - Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager - Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient - Fix hallucination detector 404: explicitly include router in main.py and app.py - Fix missing error_data in task failure responses - Hide CopilotKit web inspector button - Remove hardcoded fallback suggestions from SmartTypingAssist - Fix stale closure refs in SmartTypingAssist handleTypingChange - Add two-column editor layout, stats bar, section hover menu - Various subscription, billing, and research module improvements
This commit is contained in:
@@ -9,6 +9,7 @@ import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -26,7 +27,7 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -499,7 +499,7 @@ class DatabaseTaskManager:
|
||||
)
|
||||
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
|
||||
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
|
||||
"""Background task to generate a medium blog using a single structured JSON call."""
|
||||
try:
|
||||
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
|
||||
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
task_id,
|
||||
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
|
||||
user_id,
|
||||
db=self.db
|
||||
)
|
||||
|
||||
|
||||
@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
|
||||
- Ensure engaging, actionable content throughout
|
||||
|
||||
Return JSON format:
|
||||
{
|
||||
{{
|
||||
"title_options": [
|
||||
"Title option 1",
|
||||
"Title option 2",
|
||||
"Title option 3"
|
||||
],
|
||||
"outline": [
|
||||
{
|
||||
{{
|
||||
"heading": "Section heading with primary keyword",
|
||||
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
|
||||
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"target_words": 300,
|
||||
"keywords": ["primary keyword", "secondary keyword"]
|
||||
}
|
||||
}}
|
||||
]
|
||||
}"""
|
||||
}}"""
|
||||
|
||||
def get_outline_schema(self) -> Dict[str, Any]:
|
||||
"""Get the structured JSON schema for outline generation."""
|
||||
|
||||
@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogOutlineSection
|
||||
import json
|
||||
|
||||
|
||||
class SectionEnhancer:
|
||||
@@ -73,14 +73,45 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
enhanced_data = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
enhanced_data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
enhanced_data = json.loads(json_match.group(0))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Section enhancement returned invalid JSON: {e}")
|
||||
return section
|
||||
else:
|
||||
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
|
||||
return section
|
||||
elif isinstance(raw, dict):
|
||||
enhanced_data = raw
|
||||
else:
|
||||
logger.warning(f"Unexpected LLM response type: {type(raw)}")
|
||||
return section
|
||||
|
||||
if 'error' in enhanced_data:
|
||||
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
|
||||
else:
|
||||
return BlogOutlineSection(
|
||||
id=section.id,
|
||||
heading=enhanced_data.get('heading', section.heading),
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts competitor insights and market intelligence from research content.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CompetitorAnalyzer:
|
||||
@@ -22,7 +23,7 @@ class CompetitorAnalyzer:
|
||||
Extract and analyze:
|
||||
1. Top competitors mentioned (companies, brands, platforms)
|
||||
2. Content gaps (what competitors are missing)
|
||||
3. Market opportunities (untapped areas)
|
||||
3. Opportunities (untapped areas)
|
||||
4. Competitive advantages (what makes content unique)
|
||||
5. Market positioning insights
|
||||
6. Industry leaders and their strategies
|
||||
@@ -55,18 +56,38 @@ class CompetitorAnalyzer:
|
||||
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
|
||||
}
|
||||
|
||||
competitor_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
competitor_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
competitor_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
competitor_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
|
||||
logger.error(f"AI competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in competitor_analysis:
|
||||
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
|
||||
|
||||
@@ -63,18 +63,41 @@ class ContentAngleGenerator:
|
||||
"required": ["content_angles"]
|
||||
}
|
||||
|
||||
angles_result = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import json, re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
angles_result = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
angles_result = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
angles_result = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
|
||||
logger.error(f"AI content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in angles_result:
|
||||
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
|
||||
|
||||
if 'content_angles' not in angles_result:
|
||||
raise ValueError(f"Content angles missing from response")
|
||||
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts and analyzes keywords from research content using structured AI respons
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class KeywordAnalyzer:
|
||||
@@ -62,18 +63,38 @@ class KeywordAnalyzer:
|
||||
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
|
||||
}
|
||||
|
||||
keyword_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
keyword_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
keyword_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
keyword_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
|
||||
logger.error(f"AI keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in keyword_analysis:
|
||||
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
|
||||
|
||||
@@ -111,19 +111,22 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
@@ -162,13 +165,15 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (similar to Exa)
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -429,14 +434,16 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
@@ -446,7 +453,8 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
@@ -485,14 +493,16 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -529,7 +539,8 @@ class ResearchService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
|
||||
@@ -135,11 +135,14 @@ class TavilyResearchProvider(BaseProvider):
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -92,6 +92,7 @@ class BlogSEORecommendationApplier:
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -233,7 +233,7 @@ def create_blog_post(
|
||||
|
||||
# BACK TO BASICS MODE: Try simplest possible structure FIRST
|
||||
# Since posting worked before Ricos/SEO, let's test with absolute minimum
|
||||
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
|
||||
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
|
||||
|
||||
wix_logger.reset()
|
||||
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
|
||||
@@ -257,8 +257,7 @@ def create_blog_post(
|
||||
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}]
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -256,17 +256,16 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
@@ -332,7 +331,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -345,7 +343,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
@@ -373,7 +370,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -386,7 +382,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
@@ -442,7 +437,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
@@ -461,7 +455,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
|
||||
@@ -20,13 +20,14 @@ class SemanticHarvesterService:
|
||||
"last_harvest_time": None
|
||||
}
|
||||
|
||||
async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deep crawl a website using Exa AI.
|
||||
|
||||
Args:
|
||||
website_url: The root URL to crawl.
|
||||
limit: Maximum number of pages to retrieve.
|
||||
user_id: Optional user ID for usage tracking and preflight checks.
|
||||
|
||||
Returns:
|
||||
List of pages with content and metadata.
|
||||
@@ -59,6 +60,30 @@ class SemanticHarvesterService:
|
||||
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
|
||||
return self._get_placeholder_data(website_url)
|
||||
|
||||
# Preflight subscription check if user_id provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
|
||||
|
||||
# Use Exa to search for all pages in this domain
|
||||
search_response = self.exa_service.exa.search_and_contents(
|
||||
query=f"site:{website_url}",
|
||||
@@ -82,6 +107,38 @@ class SemanticHarvesterService:
|
||||
})
|
||||
|
||||
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
|
||||
|
||||
# Track Exa usage if user_id provided
|
||||
if user_id and results:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Exa search cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -133,9 +133,9 @@ def edit_image(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
# In podcast-only mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES") == "podcast":
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in podcast mode - allowing operation to continue")
|
||||
# In feature-limited mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
|
||||
@@ -45,6 +45,7 @@ def llm_text_gen(
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -75,7 +76,8 @@ def llm_text_gen(
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
if max_tokens is None:
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
@@ -371,16 +373,27 @@ def llm_text_gen(
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
llm_start = time.time()
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
if json_struct:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
|
||||
@@ -179,6 +179,43 @@ def get_wavespeed_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def _retry_with_increased_tokens(
|
||||
client: "OpenAI",
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
fallback_models: Optional[List[str]],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> Optional[str]:
|
||||
"""Retry the API call with increased max_tokens when JSON parsing fails due to truncation."""
|
||||
max_tokens = min(max_tokens, 16384)
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
text = text.strip() if text else ""
|
||||
if text.startswith("```json"):
|
||||
text = text[7:]
|
||||
if text.startswith("```"):
|
||||
text = text[3:]
|
||||
if text.endswith("```"):
|
||||
text = text[:-3]
|
||||
return text.strip()
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if last_error:
|
||||
logger.error(f"All fallback models failed on retry with increased tokens: {last_error}")
|
||||
return None
|
||||
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_wavespeed_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
@@ -446,24 +483,69 @@ def wavespeed_structured_json_response(
|
||||
raise last_error or Exception("WaveSpeed structured generation failed: all fallback models failed")
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# If response_format returned empty content, retry without it
|
||||
if not response_text:
|
||||
logger.warning("WaveSpeed structured call returned empty content with response_format, retrying without it...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if response is not None:
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# Clean up response text if needed
|
||||
response_text = response_text.strip()
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
parsed_json = json.loads(response_text) if response_text else None
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"❌ JSON parsing failed: {json_err}")
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
# Retry once with increased max_tokens — likely a truncation issue
|
||||
if max_tokens < 16384:
|
||||
logger.warning(f"Retrying with increased max_tokens ({max_tokens} → {max_tokens * 2}) due to JSON parse failure")
|
||||
response_text = _retry_with_increased_tokens(
|
||||
client=client,
|
||||
messages=messages,
|
||||
model=model,
|
||||
fallback_models=fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens * 2,
|
||||
)
|
||||
if response_text:
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON parsed successfully after max_tokens increase")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError:
|
||||
logger.error("❌ JSON parsing failed even after max_tokens increase")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
if response_text:
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -472,8 +554,8 @@ def wavespeed_structured_json_response(
|
||||
return extracted_json
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WaveSpeed API call failed: {e}")
|
||||
@@ -501,14 +583,24 @@ def wavespeed_structured_json_response(
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
# Parse JSON with robust cleaning
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except:
|
||||
# Regex fallback
|
||||
return json.loads(response_text) if response_text else {"error": "Empty response"}
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
raise e
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ from services.database import get_db_session
|
||||
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences
|
||||
from models.persona_models import WritingPersona, PlatformPersona, PersonaAnalysisResult
|
||||
|
||||
def _get_podcast_mode():
|
||||
"""Check if running in podcast-only mode to skip heavy initialization."""
|
||||
def _is_feature_limited_mode():
|
||||
"""Check if running in feature-limited mode to skip heavy initialization."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower()
|
||||
return env_val == "podcast"
|
||||
return env_val not in ("", "all")
|
||||
|
||||
class PersonaAnalysisService:
|
||||
"""Service for analyzing onboarding data and generating writing personas using Gemini AI."""
|
||||
@@ -40,9 +40,9 @@ class PersonaAnalysisService:
|
||||
def __init__(self):
|
||||
"""Initialize the persona analysis service (only once)."""
|
||||
if not self._initialized:
|
||||
# Skip heavy initialization in podcast-only mode
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
# Skip heavy initialization in feature-limited mode
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug(f"PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
@@ -55,8 +55,8 @@ class PersonaAnalysisService:
|
||||
return
|
||||
|
||||
# Check again in case mode changed
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
self._heavy_init_done = True
|
||||
return
|
||||
|
||||
@@ -89,9 +89,9 @@ class PersonaAnalysisService:
|
||||
# Ensure heavy services are initialized
|
||||
self._ensure_heavy_init()
|
||||
|
||||
# Check if heavy init failed (podcast mode)
|
||||
# Check if heavy init failed (feature-limited mode)
|
||||
if not getattr(self, '_heavy_init_done', False):
|
||||
return {"error": "Persona service unavailable in podcast-only mode"}
|
||||
return {"error": "Persona service unavailable in feature-limited mode"}
|
||||
|
||||
try:
|
||||
logger.info(f"Generating persona for user {user_id}")
|
||||
|
||||
@@ -296,6 +296,33 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Exa preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Exa preflight check failed: {e}")
|
||||
|
||||
# Execute Exa search
|
||||
try:
|
||||
@@ -341,6 +368,33 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.TAVILY,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="tavily",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'tavily', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Tavily preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Tavily preflight check failed: {e}")
|
||||
|
||||
# Execute Tavily search
|
||||
try:
|
||||
|
||||
@@ -83,6 +83,30 @@ class DeepCrawlService:
|
||||
tavily_results.append(res)
|
||||
|
||||
logger.info(f"Found {len(tavily_urls)} URLs from Tavily")
|
||||
|
||||
# Track Tavily usage
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Tavily crawl cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET tavily_calls = COALESCE(tavily_calls, 0) + 1,
|
||||
tavily_cost = COALESCE(tavily_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[DeepCrawl] Tracked Tavily crawl usage: user={user_id}, cost=${cost}")
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[DeepCrawl] Failed to track Tavily usage: {track_err}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily crawl failed: {e}")
|
||||
|
||||
|
||||
@@ -49,9 +49,11 @@ except Exception as _patch_err:
|
||||
# Now safe to import pytrends
|
||||
try:
|
||||
from pytrends.request import TrendReq as _TrendReq
|
||||
from pytrends.exceptions import TooManyRequestsError as _TooManyRequestsError
|
||||
PYTrends_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTrends_AVAILABLE = False
|
||||
_TooManyRequestsError = None
|
||||
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
|
||||
|
||||
# Patch 2: pytrends related_topics() and related_queries() use keyword[0]
|
||||
@@ -139,6 +141,8 @@ class GoogleTrendsService:
|
||||
Uses TrendReq with no retries (fail-fast) to avoid hitting CAPTCHA on blocks.
|
||||
429 retry handling (1s, 2s, 4s backoff). Random user-agent is set
|
||||
per instance to reduce fingerprinting.
|
||||
|
||||
Rate limiter is shared across all instances to enforce global rate limiting.
|
||||
"""
|
||||
|
||||
USER_AGENTS = [
|
||||
@@ -150,15 +154,28 @@ class GoogleTrendsService:
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0",
|
||||
]
|
||||
|
||||
# Class-level shared resources (shared across all instances)
|
||||
_shared_rate_limiter = None
|
||||
_shared_cache = None
|
||||
_cache_ttl = timedelta(hours=24)
|
||||
_last_429_time = 0 # Timestamp of last 429 error (Unix epoch)
|
||||
_429_cooldown_period = 1800 # 30 minutes cooldown after 429
|
||||
|
||||
def __init__(self):
|
||||
if not PYTrends_AVAILABLE:
|
||||
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
|
||||
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0)
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self.cache_ttl = timedelta(hours=24)
|
||||
# Initialize shared rate limiter at class level (lazy init)
|
||||
if self.__class__._shared_rate_limiter is None:
|
||||
self.__class__._shared_rate_limiter = RateLimiter(max_calls=1, period=3.0) # 1 call per 3 seconds
|
||||
if self.__class__._shared_cache is None:
|
||||
self.__class__._shared_cache = {}
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, fail-fast, 2s delays)")
|
||||
self.rate_limiter = self.__class__._shared_rate_limiter
|
||||
self.cache = self.__class__._shared_cache
|
||||
self.cache_ttl = self._cache_ttl
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, shared rate limiter, 3s period, shared cache, 30min 429 cooldown)")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
@@ -173,7 +190,7 @@ class GoogleTrendsService:
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Comprehensive trends analysis.
|
||||
Comprehensive trends analysis with retry logic for 429 errors.
|
||||
|
||||
Args:
|
||||
keywords: List of keywords to analyze (1-5)
|
||||
@@ -193,11 +210,97 @@ class GoogleTrendsService:
|
||||
keywords = keywords[:5]
|
||||
|
||||
cache_key = self._build_cache_key(keywords, timeframe, geo)
|
||||
|
||||
# Check if we're in a 429 cooldown period
|
||||
now = time.time()
|
||||
if now - self.__class__._last_429_time < self.__class__._429_cooldown_period:
|
||||
remaining_cooldown = int(self.__class__._429_cooldown_period - (now - self.__class__._last_429_time))
|
||||
logger.warning(
|
||||
f"[Trends] In 429 cooldown period. {remaining_cooldown}s remaining. "
|
||||
f"Returning cached data if available."
|
||||
)
|
||||
cached_data = self._get_from_cache(cache_key, ignore_ttl=True) # Use stale cache
|
||||
if cached_data:
|
||||
logger.info(f"[Trends] Returning stale cached data for {keywords} during cooldown")
|
||||
return {**cached_data, "cached": True, "cooldown_active": True}
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Rate limited by Google. Cooldown active for {remaining_cooldown}s. Try again later."
|
||||
)
|
||||
|
||||
# Check fresh cache
|
||||
cached_data = self._get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached trends data for: {keywords}")
|
||||
return {**cached_data, "cached": True}
|
||||
|
||||
# Retry logic for 429 errors
|
||||
max_retries = 3
|
||||
retry_delays = [30, 60, 120] # Longer delays: 30s, 60s, 120s
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await self._do_analyze_trends(
|
||||
keywords, timeframe, geo, gprop, cache_key, attempt, max_retries
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if this is a 429 error (pytrends raises TooManyRequestsError)
|
||||
is_429 = False
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
is_429 = True
|
||||
else:
|
||||
error_str = str(e).lower()
|
||||
is_429 = "429" in error_str or "rate limit" in error_str or "too many requests" in error_str
|
||||
|
||||
if is_429:
|
||||
# Update the last 429 time for cooldown
|
||||
self.__class__._last_429_time = time.time()
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = retry_delays[attempt]
|
||||
logger.warning(
|
||||
f"[Trends] 429 rate limit hit (attempt {attempt + 1}/{max_retries + 1}), "
|
||||
f"retrying in {delay}s..."
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
# Out of retries - enter cooldown
|
||||
logger.error(
|
||||
f"[Trends] 429 rate limit persisted after {max_retries + 1} attempts. "
|
||||
f"Entering {self.__class__._429_cooldown_period}s cooldown period."
|
||||
)
|
||||
# Try to return stale cache
|
||||
stale_cache = self._get_from_cache(cache_key, ignore_ttl=True)
|
||||
if stale_cache:
|
||||
logger.info(f"[Trends] Returning stale cache after 429 exhaustion for {keywords}")
|
||||
result = {**stale_cache}
|
||||
result["cached"] = True
|
||||
result["cooldown_active"] = True
|
||||
return result
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Google is rate limiting requests. Cooldown active for {self.__class__._429_cooldown_period}s. Try again later."
|
||||
)
|
||||
else:
|
||||
# Non-429 error
|
||||
logger.error(f"Google Trends analysis failed after {attempt + 1} attempts: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
|
||||
# Should not reach here, but just in case
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, "Max retries exceeded")
|
||||
|
||||
async def _do_analyze_trends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str,
|
||||
gprop: str,
|
||||
cache_key: str,
|
||||
attempt: int,
|
||||
max_retries: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal method to perform the actual trends analysis."""
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
total_start = time.monotonic()
|
||||
@@ -207,95 +310,63 @@ class GoogleTrendsService:
|
||||
related_topics: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
related_queries: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
|
||||
try:
|
||||
logger.info(f"[Trends] ===== START analyze_trends ===== keywords={keywords} timeframe={timeframe} geo={geo}")
|
||||
logger.info(
|
||||
f"[Trends] ===== START analyze_trends (attempt {attempt + 1}/{max_retries + 1}) ===== "
|
||||
f"keywords={keywords} timeframe={timeframe} geo={geo}"
|
||||
)
|
||||
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
|
||||
# --- Interest Over Time ---
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
# --- Interest Over Time ONLY (skip others to avoid 429) ---
|
||||
await self.rate_limiter.acquire() # Rate limit check BEFORE each request
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
# Skip other requests to avoid 429 - only fetch interest_over_time for now
|
||||
logger.info(f"[Trends] Skipping other requests to avoid 429 (interest_by_region, related_topics, related_queries)")
|
||||
|
||||
# --- Interest By Region ---
|
||||
ibr_start = time.monotonic()
|
||||
interest_by_region = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_by_region(pytrends)
|
||||
)
|
||||
ibr_ms = int((time.monotonic() - ibr_start) * 1000)
|
||||
logger.info(f"[Trends] interest_by_region took {ibr_ms}ms, returned {len(interest_by_region)} regions")
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
# --- Related Topics ---
|
||||
rt_start = time.monotonic()
|
||||
related_topics = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_topics(pytrends)
|
||||
)
|
||||
rt_ms = int((time.monotonic() - rt_start) * 1000)
|
||||
rt_top = len(related_topics.get("top", []))
|
||||
rt_rising = len(related_topics.get("rising", []))
|
||||
logger.info(f"[Trends] related_topics took {rt_ms}ms, top={rt_top} rising={rt_rising}")
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
# --- Related Queries ---
|
||||
rq_start = time.monotonic()
|
||||
related_queries = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_queries(pytrends)
|
||||
)
|
||||
rq_ms = int((time.monotonic() - rq_start) * 1000)
|
||||
rq_top = len(related_queries.get("top", []))
|
||||
rq_rising = len(related_queries.get("rising", []))
|
||||
logger.info(f"[Trends] related_queries took {rq_ms}ms, top={rq_top} rising={rq_rising}")
|
||||
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
return result
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# TrendReq factory
|
||||
@@ -346,6 +417,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_over_time failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -363,6 +440,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_by_region failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -409,6 +492,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_topics failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -452,6 +541,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_queries failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -503,14 +598,18 @@ class GoogleTrendsService:
|
||||
keywords_str = ":".join(sorted(keywords))
|
||||
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
def _get_from_cache(self, cache_key: str, ignore_ttl: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached data. If ignore_ttl=True, return stale data too (for 429 cooldown)."""
|
||||
if cache_key not in self.cache:
|
||||
return None
|
||||
cached_entry = self.cache[cache_key]
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
if not ignore_ttl:
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
result = {**cached_entry}
|
||||
result.pop("cached", None)
|
||||
return result
|
||||
|
||||
@@ -157,10 +157,10 @@ def _check_production_api_key_loading(
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
|
||||
return
|
||||
|
||||
# Also skip in podcast-only mode (no production API keys needed)
|
||||
# Skip when in feature-limited mode (no production API keys needed)
|
||||
enabled_features = os.getenv("ALWRITY_ENABLED_FEATURES", "all").strip().lower()
|
||||
if enabled_features == "podcast":
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in podcast-only mode")
|
||||
if enabled_features and enabled_features not in ("", "all"):
|
||||
_record_check(checks, "production_api_key_loading", True, f"skipped in feature-limited mode: {enabled_features}")
|
||||
return
|
||||
|
||||
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()
|
||||
|
||||
@@ -12,7 +12,7 @@ from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.subscription_models import APIProvider, UsageAlert
|
||||
from models.subscription_models import APIProvider, UsageAlert, UserSubscription
|
||||
|
||||
class SubscriptionErrorType(Enum):
|
||||
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
|
||||
@@ -248,6 +248,18 @@ class SubscriptionExceptionHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get billing period from subscription, fallback to calendar month
|
||||
billing_period = datetime.now().strftime("%Y-%m") # default
|
||||
try:
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == error.user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
if subscription and subscription.current_period_start:
|
||||
billing_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
except:
|
||||
pass # Use default calendar period
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=error.user_id,
|
||||
alert_type="system_error",
|
||||
@@ -256,7 +268,7 @@ class SubscriptionExceptionHandler:
|
||||
title=f"System Error: {error.error_type.value}",
|
||||
message=error.message,
|
||||
severity=error.severity.value,
|
||||
billing_period=datetime.now().strftime("%Y-%m")
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
|
||||
@@ -157,39 +157,38 @@ class LimitValidator:
|
||||
user_tier = limits.get('tier', 'free') if limits else 'free'
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
|
||||
# Use subscription period, not calendar month
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
@@ -448,7 +447,7 @@ class LimitValidator:
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
|
||||
|
||||
|
||||
@@ -67,15 +67,56 @@ class PricingService:
|
||||
self.db.rollback()
|
||||
return True
|
||||
|
||||
def get_current_billing_period(self, user_id: str) -> Optional[str]:
|
||||
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
|
||||
def get_current_billing_period(self, user_id: str) -> str:
|
||||
"""Return current billing period key (YYYY-MM) based on subscription, not calendar.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Ensure subscription is current (advance if auto_renew)
|
||||
self._ensure_subscription_current(subscription)
|
||||
# Continue to use YYYY-MM for summaries
|
||||
|
||||
# Use subscription's billing period, NOT calendar month
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Check if usage data exists for this subscription period
|
||||
from models.subscription_models import UsageSummary
|
||||
usage_exists = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
if usage_exists:
|
||||
return sub_period
|
||||
|
||||
# If no data for subscription period, check for calendar month data
|
||||
# This handles backward compatibility for existing users
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
if calendar_period != sub_period:
|
||||
calendar_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
if calendar_usage:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {sub_period} has no data)")
|
||||
return calendar_period
|
||||
|
||||
return sub_period
|
||||
|
||||
# Fallback: Check if user has any usage summary and use that period
|
||||
from models.subscription_models import UsageSummary
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period}")
|
||||
return latest_summary.billing_period
|
||||
|
||||
# Last fallback to calendar month for free tier / no data
|
||||
return datetime.now().strftime("%Y-%m")
|
||||
|
||||
@classmethod
|
||||
@@ -830,6 +871,7 @@ class PricingService:
|
||||
'serper_calls': plan.serper_calls_limit,
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'exa_calls': getattr(plan, 'exa_calls_limit', 0), # Exa research API
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
REQUIRED_STRIPE_PLAN_KEYS = {
|
||||
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
|
||||
@@ -421,10 +421,6 @@ class StripeService:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
price_id = sub['items']['data'][0]['price']['id']
|
||||
# Map price_id to internal plan_id
|
||||
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
|
||||
# For now, we'll assume the metadata or a lookup.
|
||||
# Ideally, store price_id in SubscriptionPlan table or config.
|
||||
|
||||
# Update DB
|
||||
self._update_user_subscription(
|
||||
@@ -434,6 +430,24 @@ class StripeService:
|
||||
status="active",
|
||||
price_id=price_id
|
||||
)
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")
|
||||
|
||||
# Expire all SQLAlchemy objects to force fresh reads
|
||||
self.db.expire_all()
|
||||
logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkout subscription: {e}")
|
||||
|
||||
@@ -457,11 +471,28 @@ class StripeService:
|
||||
logger.info(f"Payment succeeded for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
# Update period end based on invoice lines period
|
||||
subscription.auto_renew = True
|
||||
# Update period start/end based on invoice lines period
|
||||
if invoice.get('lines'):
|
||||
period_start = invoice['lines']['data'][0]['period']['start']
|
||||
period_end = invoice['lines']['data'][0]['period']['end']
|
||||
subscription.current_period_start = datetime.fromtimestamp(period_start)
|
||||
subscription.current_period_end = datetime.fromtimestamp(period_end)
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after payment success")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after payment success for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after payment success for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
|
||||
subscription_id = invoice.get("subscription")
|
||||
@@ -497,6 +528,12 @@ class StripeService:
|
||||
if status in ["active", "trialing"]:
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
subscription.auto_renew = True
|
||||
# Update period boundaries from Stripe event
|
||||
current_period = subscription_obj.get("current_period", {})
|
||||
if current_period:
|
||||
subscription.current_period_start = datetime.fromtimestamp(current_period.get("start", 0))
|
||||
subscription.current_period_end = datetime.fromtimestamp(current_period.get("end", 0))
|
||||
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
@@ -506,6 +543,20 @@ class StripeService:
|
||||
subscription.auto_renew = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after subscription update")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after subscription update for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after subscription update for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
@@ -610,6 +661,11 @@ class StripeService:
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
# Calculate billing period end based on cycle
|
||||
if billing_cycle == BillingCycle.YEARLY:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
if not subscription:
|
||||
subscription = UserSubscription(
|
||||
@@ -617,7 +673,7 @@ class StripeService:
|
||||
plan_id=plan.id,
|
||||
billing_cycle=billing_cycle,
|
||||
current_period_start=now,
|
||||
current_period_end=now,
|
||||
current_period_end=period_end,
|
||||
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
|
||||
is_active=status == "active",
|
||||
auto_renew=True,
|
||||
@@ -627,6 +683,11 @@ class StripeService:
|
||||
subscription.plan_id = plan.id
|
||||
subscription.billing_cycle = billing_cycle
|
||||
subscription.is_active = status == "active"
|
||||
subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
|
||||
# Reset billing period on upgrade/plan change
|
||||
subscription.current_period_start = now
|
||||
subscription.current_period_end = period_end
|
||||
subscription.auto_renew = True
|
||||
|
||||
subscription.stripe_customer_id = stripe_customer_id
|
||||
subscription.stripe_subscription_id = stripe_subscription_id
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Usage tracking modules package.
|
||||
Split from the monolithic usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from .historical_usage import get_all_historical_usage, get_current_period_usage, get_usage_for_period
|
||||
from .usage_stats import get_user_usage_stats
|
||||
from .usage_trends import get_usage_trends
|
||||
from .limits_enforcement import enforce_usage_limits
|
||||
from .alerts import check_usage_alerts, create_usage_alert
|
||||
|
||||
__all__ = [
|
||||
'get_all_historical_usage',
|
||||
'get_current_period_usage',
|
||||
'get_usage_for_period',
|
||||
'get_user_usage_stats',
|
||||
'get_usage_trends',
|
||||
'enforce_usage_limits',
|
||||
'check_usage_alerts',
|
||||
'create_usage_alert',
|
||||
]
|
||||
101
backend/services/subscription/usage_tracking_modules/alerts.py
Normal file
101
backend/services/subscription/usage_tracking_modules/alerts.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Usage alert functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import UsageAlert, UsageSummary, APIProvider, UsageStatus
|
||||
|
||||
|
||||
def check_usage_alerts(user_id: str, provider: APIProvider,
|
||||
billing_period: str, db: Session, pricing_service):
|
||||
"""Check if usage alerts should be sent."""
|
||||
# Get current usage
|
||||
period_keys = {'billing_period': billing_period, 'lookup_periods': [billing_period]}
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
def create_usage_alert(user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str, db: Session):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Historical usage aggregation functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus
|
||||
|
||||
|
||||
# Shared provider mapping: DB column → frontend key
|
||||
PROVIDER_MAPPING = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface', # HuggingFace stored as mistral
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'tavily_calls': 'tavily',
|
||||
'serper_calls': 'serper',
|
||||
'firecrawl_calls': 'firecrawl',
|
||||
'metaphor_calls': 'metaphor',
|
||||
'stability_calls': 'stability',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
|
||||
def _build_provider_breakdown(summaries: list, mapping: dict) -> dict:
|
||||
"""Build provider_breakdown dict from a list of UsageSummary records."""
|
||||
breakdown = {}
|
||||
for db_col, frontend_key in mapping.items():
|
||||
total = sum(getattr(s, db_col, 0) or 0 for s in summaries)
|
||||
breakdown[frontend_key] = {'calls': total, 'cost': 0, 'tokens': 0}
|
||||
return breakdown
|
||||
|
||||
|
||||
def _build_usage_percentages(provider_breakdown: dict, limits: dict) -> dict:
|
||||
"""Build usage_percentages dict from provider_breakdown and per-period limits."""
|
||||
pcts = {}
|
||||
if not limits or not limits.get('limits'):
|
||||
return pcts
|
||||
|
||||
limit_map = {
|
||||
'gemini_calls': ('gemini', 'gemini_calls'),
|
||||
'huggingface_calls': ('huggingface', 'mistral_calls'),
|
||||
'stability_calls': ('stability', 'stability_calls'),
|
||||
'video_calls': ('video', 'video_calls'),
|
||||
'audio_calls': ('audio', 'audio_calls'),
|
||||
'image_edit_calls': ('image_edit', 'image_edit_calls'),
|
||||
'wavespeed_calls': ('wavespeed', 'wavespeed_calls'),
|
||||
'tavily_calls': ('tavily', 'tavily_calls'),
|
||||
'serper_calls': ('serper', 'serper_calls'),
|
||||
'firecrawl_calls': ('firecrawl', 'firecrawl_calls'),
|
||||
'metaphor_calls': ('metaphor', 'metaphor_calls'),
|
||||
'exa_calls': ('exa', 'exa_calls'),
|
||||
}
|
||||
|
||||
for pct_key, (bk_key, limit_key) in limit_map.items():
|
||||
used = provider_breakdown.get(bk_key, {}).get('calls', 0)
|
||||
limit_val = limits.get('limits', {}).get(limit_key, 0) or 0
|
||||
if limit_val > 0:
|
||||
pcts[pct_key] = (used / limit_val) * 100
|
||||
|
||||
# Cost percentage
|
||||
total_cost = provider_breakdown.get('total_cost', 0)
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
pcts['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
return pcts
|
||||
|
||||
|
||||
def _summaries_usage_status(summaries: list) -> str:
|
||||
"""Derive overall usage_status from a list of summaries."""
|
||||
status = 'active'
|
||||
for s in summaries:
|
||||
try:
|
||||
st = s.usage_status.value
|
||||
except Exception:
|
||||
st = str(s.usage_status)
|
||||
if st == 'limit_reached':
|
||||
return 'limit_reached'
|
||||
if st == 'warning' and status != 'limit_reached':
|
||||
status = 'warning'
|
||||
return status
|
||||
|
||||
|
||||
def _empty_usage_response(billing_period: str, limits: dict) -> Dict[str, Any]:
|
||||
"""Return a zeroed UsageStats-shaped response."""
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_all_historical_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
all_summaries = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
if not all_summaries:
|
||||
return _empty_usage_response('all', limits)
|
||||
|
||||
# Aggregate
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
provider_breakdown = _build_provider_breakdown(all_summaries, PROVIDER_MAPPING)
|
||||
|
||||
# Historical breakdown per period
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': _summaries_usage_status(all_summaries),
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': {}, # misleading for all-time vs per-period limits
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_current_period_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get current billing period usage data with correct per-period limit percentages.
|
||||
|
||||
Returns a UsageStats-shaped dict with provider_breakdown and usage_percentages
|
||||
computed against the plan's per-period limits.
|
||||
"""
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(current_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': current_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_usage_for_period(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get usage data for a specific billing period.
|
||||
|
||||
Returns a UsageStats-shaped dict with that period's provider_breakdown
|
||||
and usage_percentages computed against plan limits.
|
||||
"""
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(billing_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Usage limit enforcement functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import APIProvider
|
||||
from services.subscription.pricing_service import PricingService
|
||||
|
||||
|
||||
def enforce_usage_limits(user_id: str, provider: APIProvider,
|
||||
tokens_requested: int, db: Session,
|
||||
pricing_service: PricingService) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
|
||||
# This would need access to self._enforce_cache
|
||||
# For now, keeping the structure
|
||||
|
||||
result = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
# self._enforce_cache[cache_key] = {
|
||||
# 'result': result,
|
||||
# 'expires_at': now + timedelta(seconds=30)
|
||||
# }
|
||||
|
||||
return tuple(result)
|
||||
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Usage statistics functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus, APIProvider
|
||||
from services.subscription.usage_tracking_modules.historical_usage import get_all_historical_usage, get_usage_for_period
|
||||
|
||||
|
||||
def get_user_usage_stats(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user.
|
||||
When no billing_period is specified, returns ALL historical usage data.
|
||||
When a specific period is given, returns only that period's data."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
# If no billing_period requested, return ALL historical data
|
||||
if not billing_period:
|
||||
return get_all_historical_usage(user_id, db, pricing_service)
|
||||
|
||||
# Return data for the specific billing period
|
||||
return get_usage_for_period(user_id, billing_period, db, pricing_service)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Usage trends functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_usage_trends(user_id: str, months: int, db: Session) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
from services.subscription.usage_tracking_helpers import build_billing_periods, query_usage_summaries, self_heal_summaries_from_logs, build_usage_trends_response
|
||||
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(db, user_id, periods)
|
||||
self_heal_summaries_from_logs(db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
@@ -1,41 +1,60 @@
|
||||
"""
|
||||
Usage Tracking Service
|
||||
Comprehensive tracking of API usage, costs, and subscription limits.
|
||||
Usage Tracking Service - Refactored into modular components.
|
||||
|
||||
This file now serves as a facade that delegates to specialized modules
|
||||
in the usage_tracking_modules package.
|
||||
|
||||
Modules:
|
||||
- historical_usage: Functions for aggregating historical usage data
|
||||
- usage_stats: Functions for getting user usage statistics
|
||||
- usage_trends: Functions for usage trend analysis
|
||||
- limit_enforcement: Functions for enforcing usage limits
|
||||
- alerts: Functions for usage alerts
|
||||
"""
|
||||
|
||||
# Ensure Optional is available in global scope for dynamic imports
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
import json
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
|
||||
from models.subscription_models import (
|
||||
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
|
||||
UserSubscription, UsageStatus
|
||||
APIProvider, UsageStatus, UserSubscription,
|
||||
UsageSummary, APIUsageLog, UsageAlert
|
||||
)
|
||||
from .pricing_service import PricingService
|
||||
from .provider_detection import detect_actual_provider
|
||||
from .usage_tracking_helpers import (
|
||||
build_billing_periods,
|
||||
build_default_usage_percentages,
|
||||
build_empty_usage_response,
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription.usage_tracking_helpers import (
|
||||
build_provider_breakdown,
|
||||
build_usage_trends_response,
|
||||
build_default_usage_percentages,
|
||||
calculate_final_total_cost,
|
||||
maybe_persist_reconciled_costs,
|
||||
build_usage_trends_response,
|
||||
build_billing_periods,
|
||||
query_usage_summaries,
|
||||
reset_usage_summary_counters,
|
||||
self_heal_summaries_from_logs,
|
||||
reset_usage_summary_counters,
|
||||
)
|
||||
# Import clear_dashboard_cache lazily to avoid circular import
|
||||
def _clear_dashboard_cache_for_user(user_id: str):
|
||||
from api.subscription.cache import clear_dashboard_cache as _clear
|
||||
return _clear(user_id)
|
||||
|
||||
from .usage_tracking_modules import (
|
||||
get_all_historical_usage,
|
||||
get_current_period_usage,
|
||||
get_usage_for_period,
|
||||
get_user_usage_stats,
|
||||
get_usage_trends,
|
||||
enforce_usage_limits,
|
||||
check_usage_alerts,
|
||||
create_usage_alert,
|
||||
)
|
||||
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
"""Service for tracking API usage and managing billing information."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
@@ -43,13 +62,14 @@ class UsageTrackingService:
|
||||
# TTL cache (30s) for enforcement results to cut DB chatter
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
||||
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return authoritative billing period lookup keys. Always uses calendar month for consistency."""
|
||||
"""Return authoritative billing period lookup keys. Always uses subscription period for consistency.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
|
||||
# If caller explicitly requested a billing period, use it
|
||||
if billing_period:
|
||||
return {
|
||||
@@ -58,26 +78,125 @@ class UsageTrackingService:
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
}
|
||||
|
||||
# ALWAYS use current calendar month for billing period to ensure consistency
|
||||
# This prevents data loss when subscription spans month boundaries
|
||||
current_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get subscription period if available
|
||||
subscription_period = None
|
||||
if subscription and subscription.current_period_start:
|
||||
subscription_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Get calendar period
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check which period has usage data
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
if subscription_period:
|
||||
# Check if data exists for subscription period
|
||||
sub_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == subscription_period
|
||||
).first()
|
||||
|
||||
if sub_data:
|
||||
# Use subscription period (has data)
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No data for subscription period, check calendar period (backward compatibility)
|
||||
if calendar_period != subscription_period:
|
||||
cal_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
|
||||
if cal_data:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
|
||||
return {
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# No data in either period, use subscription period
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No subscription, check for any existing data
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
|
||||
return {
|
||||
"billing_period": latest_summary.billing_period,
|
||||
"lookup_periods": [latest_summary.billing_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# Last fallback to calendar month for free tier / no subscription
|
||||
return {
|
||||
"billing_period": current_period,
|
||||
"lookup_periods": [current_period],
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# Delegate to modular functions
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
return get_all_historical_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get current billing period usage with correct per-period limit percentages."""
|
||||
return get_current_period_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
|
||||
"""Get usage for a specific billing period."""
|
||||
return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
return get_usage_trends(user_id, months, self.db)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)
|
||||
|
||||
# Keep the track_api_usage method here as it's the core functionality
|
||||
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Track an API usage event and update billing information."""
|
||||
|
||||
try:
|
||||
@@ -165,394 +284,81 @@ class UsageTrackingService:
|
||||
|
||||
# Invalidate dashboard cache so header stats update immediately
|
||||
try:
|
||||
clear_dashboard_cache(user_id)
|
||||
_clear_dashboard_cache_for_user(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
|
||||
logger.warning(f"Failed to clear dashboard cache: {cache_err}")
|
||||
|
||||
return {
|
||||
'usage_logged': True,
|
||||
'cost': cost_data['cost_total'],
|
||||
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
|
||||
'billing_period': billing_period
|
||||
"success": True,
|
||||
"cost": cost_data['cost_total'],
|
||||
"tokens": (tokens_input or 0) + (tokens_output or 0),
|
||||
"billing_period": billing_period
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error tracking API usage: {str(e)}")
|
||||
logger.error(f"Failed to track API usage: {e}")
|
||||
self.db.rollback()
|
||||
return {
|
||||
'usage_logged': False,
|
||||
'error': str(e)
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
|
||||
tokens_used: int, cost: float, billing_period: str,
|
||||
response_time: float, is_error: bool):
|
||||
"""Update the usage summary for a user."""
|
||||
tokens_used: int, cost: float,
|
||||
billing_period: str,
|
||||
response_time: float = 0.0,
|
||||
is_error: bool = False):
|
||||
"""Update or create usage summary for the billing period."""
|
||||
|
||||
# Get or create usage summary
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
# Get or create summary
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=period_keys["billing_period"]
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
)
|
||||
self.db.add(summary)
|
||||
else:
|
||||
logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}")
|
||||
|
||||
# Update provider-specific counters
|
||||
# Update counts
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
|
||||
# Update provider-specific counts
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
setattr(summary, f"{provider_name}_calls", current_calls + 1)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
|
||||
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
|
||||
# Update provider-specific tokens
|
||||
tokens_attr = f"{provider_name}_tokens"
|
||||
if hasattr(summary, tokens_attr):
|
||||
current_tokens = getattr(summary, tokens_attr, 0) or 0
|
||||
setattr(summary, tokens_attr, current_tokens + tokens_used)
|
||||
|
||||
# Update cost
|
||||
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
|
||||
setattr(summary, f"{provider_name}_cost", current_cost + cost)
|
||||
# Update provider-specific cost
|
||||
cost_attr = f"{provider_name}_cost"
|
||||
if hasattr(summary, cost_attr):
|
||||
current_cost = getattr(summary, cost_attr, 0.0) or 0.0
|
||||
setattr(summary, cost_attr, current_cost + cost)
|
||||
|
||||
# Update totals
|
||||
summary.total_calls += 1
|
||||
summary.total_tokens += tokens_used
|
||||
summary.total_cost += cost
|
||||
# Update response time (rolling average)
|
||||
if response_time > 0:
|
||||
current_avg = summary.avg_response_time or 0.0
|
||||
current_calls = summary.total_calls or 1
|
||||
summary.avg_response_time = ((current_avg * (current_calls - 1)) + response_time) / current_calls
|
||||
|
||||
# Update performance metrics
|
||||
if summary.total_calls > 0:
|
||||
# Update average response time
|
||||
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
|
||||
summary.avg_response_time = total_response_time / summary.total_calls
|
||||
|
||||
# Update error rate
|
||||
if is_error:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
else:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
|
||||
# Update usage status based on limits
|
||||
await self._update_usage_status(summary)
|
||||
# Update error rate
|
||||
if is_error:
|
||||
summary.error_count = (summary.error_count or 0) + 1
|
||||
total_calls = summary.total_calls or 1
|
||||
summary.error_rate = (summary.error_count / total_calls) * 100
|
||||
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
async def _update_usage_status(self, summary: UsageSummary):
|
||||
"""Update usage status based on subscription limits."""
|
||||
|
||||
limits = self.pricing_service.get_user_limits(summary.user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check various limits and determine status
|
||||
max_usage_percentage = 0.0
|
||||
|
||||
# Check cost limit
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0)
|
||||
if cost_limit > 0:
|
||||
cost_usage_pct = (summary.total_cost / cost_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
|
||||
|
||||
# Check call limits for each provider
|
||||
for provider in APIProvider:
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
call_usage_pct = (current_calls / call_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
|
||||
|
||||
# Update status based on highest usage percentage
|
||||
if max_usage_percentage >= 100:
|
||||
summary.usage_status = UsageStatus.LIMIT_REACHED
|
||||
elif max_usage_percentage >= 80:
|
||||
summary.usage_status = UsageStatus.WARNING
|
||||
else:
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
|
||||
# Get current usage
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
await self._create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
requested_billing_period = billing_period
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
|
||||
billing_period = period_keys["billing_period"]
|
||||
|
||||
logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}")
|
||||
|
||||
# Get usage summary
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if summary:
|
||||
logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}")
|
||||
else:
|
||||
logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}")
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Get recent alerts
|
||||
alerts = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(10).all()
|
||||
|
||||
if not summary:
|
||||
# If no summary exists for current period, we should initialize it
|
||||
# This handles the "start of month" case where a user logs in but hasn't made calls yet
|
||||
if not requested_billing_period:
|
||||
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
)
|
||||
try:
|
||||
self.db.add(summary)
|
||||
self.db.commit()
|
||||
self.db.refresh(summary)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize summary: {e}")
|
||||
self.db.rollback()
|
||||
# Fallback to zero-struct return if DB write fails
|
||||
pass
|
||||
|
||||
if not summary: # Still no summary after attempt
|
||||
return build_empty_usage_response(
|
||||
billing_period=billing_period,
|
||||
limits=limits,
|
||||
providers=APIProvider,
|
||||
)
|
||||
|
||||
# Provider breakdown - calculate costs first, then use for percentages
|
||||
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
|
||||
provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
|
||||
db=self.db,
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
summary_total_cost = summary.total_cost or 0.0
|
||||
calculated_total_cost, final_total_cost = calculate_final_total_cost(
|
||||
summary_total_cost=summary_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
maybe_persist_reconciled_costs(
|
||||
db=self.db,
|
||||
summary=summary,
|
||||
summary_total_cost=summary_total_cost,
|
||||
calculated_total_cost=calculated_total_cost,
|
||||
final_total_cost=final_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
# Calculate usage percentages - only for Gemini and HuggingFace
|
||||
# Use the calculated costs for accurate percentages
|
||||
usage_percentages = build_default_usage_percentages(APIProvider)
|
||||
if limits:
|
||||
# Gemini
|
||||
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
|
||||
if gemini_call_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100
|
||||
|
||||
# HuggingFace (stored as mistral in database)
|
||||
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
|
||||
if mistral_call_limit > 0:
|
||||
usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100
|
||||
|
||||
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': final_total_cost,
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'alerts': [
|
||||
{
|
||||
'id': alert.id,
|
||||
'type': alert.alert_type,
|
||||
'title': alert.title,
|
||||
'message': alert.message,
|
||||
'severity': alert.severity,
|
||||
'created_at': alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
],
|
||||
'usage_percentages': usage_percentages,
|
||||
'last_updated': summary.updated_at.isoformat()
|
||||
}
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(self.db, user_id, periods)
|
||||
self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._enforce_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
result = self.pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
self._enforce_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id)
|
||||
billing_period = period_keys["billing_period"]
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
try:
|
||||
reset_usage_summary_counters(summary)
|
||||
self.db.commit()
|
||||
|
||||
# Invalidate dashboard cache so header stats update after reset
|
||||
try:
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
|
||||
return {"reset": True, "counters_reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
return {"reset": False, "error": str(e)}
|
||||
|
||||
@@ -2,9 +2,8 @@ import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import requests
|
||||
import httpx
|
||||
from loguru import logger
|
||||
import time
|
||||
import random
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
@@ -61,30 +60,26 @@ class WritingAssistantService:
|
||||
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
|
||||
return True
|
||||
|
||||
async def suggest(self, text: str, max_results: int = 1) -> List[WritingSuggestion]:
|
||||
async def suggest(self, text: str, user_id: str | None = None) -> List[WritingSuggestion]:
|
||||
if not text or len(text.strip()) < 6:
|
||||
return []
|
||||
|
||||
# COST OPTIMIZATION: Use cached/static suggestions for common patterns
|
||||
# This reduces API calls by 90%+ while maintaining usefulness
|
||||
cached_suggestion = self._get_cached_suggestion(text)
|
||||
if cached_suggestion:
|
||||
return [cached_suggestion]
|
||||
|
||||
# COST CONTROL: Check daily usage limits
|
||||
if not self._check_daily_limit():
|
||||
logger.warning("Daily API limit reached for writing assistant")
|
||||
return []
|
||||
|
||||
# Only make expensive API calls for unique, substantial content
|
||||
if len(text.strip()) < 50: # Skip API calls for very short text
|
||||
if len(text.strip()) < 50:
|
||||
return []
|
||||
|
||||
# 1) Find relevant sources via Exa (reduced results for cost)
|
||||
# 1) Find relevant sources via Exa
|
||||
sources = await self._search_sources(text)
|
||||
|
||||
# 2) Generate continuation suggestion via Gemini
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources)
|
||||
# 2) Generate continuation suggestion via LLM grounded in sources
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
|
||||
|
||||
if not suggestion_text:
|
||||
return []
|
||||
@@ -110,12 +105,12 @@ class WritingAssistantService:
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=self.http_timeout_seconds,
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=self.http_timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
|
||||
data = resp.json()
|
||||
@@ -140,8 +135,7 @@ class WritingAssistantService:
|
||||
logger.error(f"WritingAssistant _search_sources error: {e}")
|
||||
raise
|
||||
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
|
||||
# Build compact sources context block
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]], user_id: str | None = None) -> tuple[str, float]:
|
||||
source_blocks: List[str] = []
|
||||
for i, s in enumerate(sources[:5]):
|
||||
excerpt = (s.get("text", "") or "")
|
||||
@@ -149,16 +143,14 @@ class WritingAssistantService:
|
||||
source_blocks.append(
|
||||
f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
|
||||
)
|
||||
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
|
||||
sources_text = "\n\n".join(source_blocks)
|
||||
|
||||
# Provider-agnostic behavior: short continuation with one inline citation hint
|
||||
system_prompt = (
|
||||
"You are an assistive writing continuation bot. "
|
||||
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
|
||||
"Match tone and topic. Prefer concrete, current facts from the provided sources. "
|
||||
"Include exactly one brief citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"User text to continue (do not repeat):\n{text}\n\n"
|
||||
f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
|
||||
@@ -166,13 +158,13 @@ class WritingAssistantService:
|
||||
)
|
||||
|
||||
try:
|
||||
# Inter-call jitter to reduce burst rate limits
|
||||
time.sleep(random.uniform(0.05, 0.15))
|
||||
await asyncio.sleep(random.uniform(0.05, 0.15))
|
||||
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=user_prompt,
|
||||
json_struct=None,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
if isinstance(ai_resp, dict) and ai_resp.get("text"):
|
||||
suggestion = (ai_resp.get("text", "") or "").strip()
|
||||
@@ -180,12 +172,10 @@ class WritingAssistantService:
|
||||
suggestion = (str(ai_resp or "")).strip()
|
||||
if not suggestion:
|
||||
raise Exception("Assistive writer returned empty suggestion")
|
||||
# naive confidence from number of sources present
|
||||
confidence = 0.7 if sources else 0.5
|
||||
confidence = 0.7
|
||||
return suggestion, confidence
|
||||
except Exception as e:
|
||||
logger.error(f"WritingAssistant _generate_continuation error: {e}")
|
||||
# Propagate to ensure frontend does not show stale/generic content
|
||||
raise
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user