fix: WYSIWYG editor, content generation, and writing assistant bug fixes

- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField
- Fix blog title not truncating: add min-w-0 for flex item overflow
- Fix outline generation 500: escape curly braces in f-string prompt template
- Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager
- Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient
- Fix hallucination detector 404: explicitly include router in main.py and app.py
- Fix missing error_data in task failure responses
- Hide CopilotKit web inspector button
- Remove hardcoded fallback suggestions from SmartTypingAssist
- Fix stale closure refs in SmartTypingAssist handleTypingChange
- Add two-column editor layout, stats bar, section hover menu
- Various subscription, billing, and research module improvements
This commit is contained in:
ajaysi
2026-05-14 09:11:30 +05:30
parent 7385100017
commit 928c2f20aa
113 changed files with 4344 additions and 10064 deletions

View File

@@ -9,6 +9,7 @@ import json
from typing import Dict, Any, List
from loguru import logger
from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.blog_models import (
MediumBlogGenerateRequest,
@@ -26,7 +27,7 @@ class MediumBlogGenerator:
def __init__(self):
self.cache = persistent_content_cache
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call.
Args:

View File

@@ -499,7 +499,7 @@ class DatabaseTaskManager:
)
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
"""Background task to generate a medium blog using a single structured JSON call."""
try:
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
request,
task_id,
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
user_id,
db=self.db
)

View File

@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
- Ensure engaging, actionable content throughout
Return JSON format:
{
{{
"title_options": [
"Title option 1",
"Title option 2",
"Title option 3"
],
"outline": [
{
{{
"heading": "Section heading with primary keyword",
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
"target_words": 300,
"keywords": ["primary keyword", "secondary keyword"]
}
}}
]
}"""
}}"""
def get_outline_schema(self) -> Dict[str, Any]:
"""Get the structured JSON schema for outline generation."""

View File

@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
"""
from loguru import logger
from models.blog_models import BlogOutlineSection
import json
class SectionEnhancer:
@@ -73,14 +73,45 @@ class SectionEnhancer:
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
}
enhanced_data = llm_text_gen(
raw = llm_text_gen(
prompt=enhancement_prompt,
json_struct=enhancement_schema,
system_prompt=None,
user_id=user_id
)
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
# Parse JSON from LLM response (works with both string and dict return types)
import re
if isinstance(raw, str):
cleaned = raw.strip()
if cleaned.startswith('```json'):
cleaned = cleaned[7:]
if cleaned.startswith('```'):
cleaned = cleaned[3:]
if cleaned.endswith('```'):
cleaned = cleaned[:-3]
cleaned = cleaned.strip()
try:
enhanced_data = json.loads(cleaned)
except json.JSONDecodeError:
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
if json_match:
try:
enhanced_data = json.loads(json_match.group(0))
except json.JSONDecodeError as e:
logger.warning(f"Section enhancement returned invalid JSON: {e}")
return section
else:
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
return section
elif isinstance(raw, dict):
enhanced_data = raw
else:
logger.warning(f"Unexpected LLM response type: {type(raw)}")
return section
if 'error' in enhanced_data:
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
else:
return BlogOutlineSection(
id=section.id,
heading=enhanced_data.get('heading', section.heading),

View File

@@ -6,6 +6,7 @@ Extracts competitor insights and market intelligence from research content.
from typing import Dict, Any
from loguru import logger
import json
class CompetitorAnalyzer:
@@ -22,7 +23,7 @@ class CompetitorAnalyzer:
Extract and analyze:
1. Top competitors mentioned (companies, brands, platforms)
2. Content gaps (what competitors are missing)
3. Market opportunities (untapped areas)
3. Opportunities (untapped areas)
4. Competitive advantages (what makes content unique)
5. Market positioning insights
6. Industry leaders and their strategies
@@ -55,18 +56,38 @@ class CompetitorAnalyzer:
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
}
competitor_analysis = llm_text_gen(
raw = llm_text_gen(
prompt=competitor_prompt,
json_struct=competitor_schema,
user_id=user_id
)
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
logger.info("✅ AI competitor analysis completed successfully")
return competitor_analysis
# Parse JSON from LLM response (works with both string and dict return types)
import re
if isinstance(raw, str):
cleaned = raw.strip()
if cleaned.startswith('```json'):
cleaned = cleaned[7:]
if cleaned.startswith('```'):
cleaned = cleaned[3:]
if cleaned.endswith('```'):
cleaned = cleaned[:-3]
cleaned = cleaned.strip()
try:
competitor_analysis = json.loads(cleaned)
except json.JSONDecodeError:
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
if json_match:
competitor_analysis = json.loads(json_match.group(0))
else:
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
elif isinstance(raw, dict):
competitor_analysis = raw
else:
# Fail gracefully - no fallback data
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
logger.error(f"AI competitor analysis failed: {error_msg}")
raise ValueError(f"Competitor analysis failed: {error_msg}")
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
if 'error' in competitor_analysis:
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
logger.info("✅ AI competitor analysis completed successfully")
return competitor_analysis

View File

@@ -63,18 +63,41 @@ class ContentAngleGenerator:
"required": ["content_angles"]
}
angles_result = llm_text_gen(
raw = llm_text_gen(
prompt=angles_prompt,
json_struct=angles_schema,
user_id=user_id
)
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
logger.info("✅ AI content angles generation completed successfully")
return angles_result['content_angles'][:7]
# Parse JSON from LLM response (works with both string and dict return types)
import json, re
if isinstance(raw, str):
cleaned = raw.strip()
if cleaned.startswith('```json'):
cleaned = cleaned[7:]
if cleaned.startswith('```'):
cleaned = cleaned[3:]
if cleaned.endswith('```'):
cleaned = cleaned[:-3]
cleaned = cleaned.strip()
try:
angles_result = json.loads(cleaned)
except json.JSONDecodeError:
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
if json_match:
angles_result = json.loads(json_match.group(0))
else:
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
elif isinstance(raw, dict):
angles_result = raw
else:
# Fail gracefully - no fallback data
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
logger.error(f"AI content angles generation failed: {error_msg}")
raise ValueError(f"Content angles generation failed: {error_msg}")
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
if 'error' in angles_result:
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
if 'content_angles' not in angles_result:
raise ValueError(f"Content angles missing from response")
logger.info("✅ AI content angles generation completed successfully")
return angles_result['content_angles'][:7]

View File

@@ -6,6 +6,7 @@ Extracts and analyzes keywords from research content using structured AI respons
from typing import Dict, Any, List
from loguru import logger
import json
class KeywordAnalyzer:
@@ -62,18 +63,38 @@ class KeywordAnalyzer:
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
}
keyword_analysis = llm_text_gen(
raw = llm_text_gen(
prompt=keyword_prompt,
json_struct=keyword_schema,
user_id=user_id
)
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
logger.info("✅ AI keyword analysis completed successfully")
return keyword_analysis
# Parse JSON from LLM response (works with both string and dict return types)
import re
if isinstance(raw, str):
cleaned = raw.strip()
if cleaned.startswith('```json'):
cleaned = cleaned[7:]
if cleaned.startswith('```'):
cleaned = cleaned[3:]
if cleaned.endswith('```'):
cleaned = cleaned[:-3]
cleaned = cleaned.strip()
try:
keyword_analysis = json.loads(cleaned)
except json.JSONDecodeError:
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
if json_match:
keyword_analysis = json.loads(json_match.group(0))
else:
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
elif isinstance(raw, dict):
keyword_analysis = raw
else:
# Fail gracefully - no fallback data
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
logger.error(f"AI keyword analysis failed: {error_msg}")
raise ValueError(f"Keyword analysis failed: {error_msg}")
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
if 'error' in keyword_analysis:
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
logger.info("✅ AI keyword analysis completed successfully")
return keyword_analysis

View File

@@ -111,19 +111,22 @@ class ResearchService:
# Exa research workflow
from .exa_provider import ExaResearchProvider
from services.subscription.preflight_validator import validate_exa_research_operations
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
import os
import time
# Pre-flight validation
db_val = next(get_db())
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
db_val = get_session_for_user(user_id)
if not db_val:
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
finally:
db_val.close()
if db_val:
db_val.close()
# Execute Exa search
api_start_time = time.time()
@@ -162,13 +165,15 @@ class ResearchService:
elif config.provider == ResearchProvider.TAVILY:
# Tavily research workflow
from .tavily_provider import TavilyResearchProvider
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
import os
import time
# Pre-flight validation (similar to Exa)
db_val = next(get_db())
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
db_val = get_session_for_user(user_id)
if not db_val:
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
try:
pricing_service = PricingService(db_val)
# Check Tavily usage limits
@@ -429,14 +434,16 @@ class ResearchService:
# Exa research workflow
from .exa_provider import ExaResearchProvider
from services.subscription.preflight_validator import validate_exa_research_operations
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
import os
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
# Pre-flight validation
db_val = next(get_db())
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
db_val = get_session_for_user(user_id)
if not db_val:
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
@@ -446,7 +453,8 @@ class ResearchService:
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise
finally:
db_val.close()
if db_val:
db_val.close()
# Execute Exa search
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
@@ -485,14 +493,16 @@ class ResearchService:
elif config.provider == ResearchProvider.TAVILY:
# Tavily research workflow
from .tavily_provider import TavilyResearchProvider
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
import os
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
# Pre-flight validation
db_val = next(get_db())
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
db_val = get_session_for_user(user_id)
if not db_val:
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
try:
pricing_service = PricingService(db_val)
# Check Tavily usage limits
@@ -529,7 +539,8 @@ class ResearchService:
except Exception as e:
logger.warning(f"Error checking Tavily limits: {e}")
finally:
db_val.close()
if db_val:
db_val.close()
# Execute Tavily search
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")

View File

@@ -135,11 +135,14 @@ class TavilyResearchProvider(BaseProvider):
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
"""Track Tavily API usage after successful call."""
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from sqlalchemy import text
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
return
try:
pricing_service = PricingService(db)
current_period = pricing_service.get_current_billing_period(user_id)

View File

@@ -92,6 +92,7 @@ class BlogSEORecommendationApplier:
None,
schema,
user_id, # Pass user_id for subscription checking
max_tokens=8192,
)
if not result or result.get("error"):

View File

@@ -233,7 +233,7 @@ def create_blog_post(
# BACK TO BASICS MODE: Try simplest possible structure FIRST
# Since posting worked before Ricos/SEO, let's test with absolute minimum
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
wix_logger.reset()
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
@@ -257,8 +257,7 @@ def create_blog_post(
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
'decorations': []
}
}],
'paragraphData': {}
}]
}]
}

View File

@@ -256,17 +256,16 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
quote_content = ' '.join(quote_lines)
text_nodes = parse_markdown_inline(quote_content)
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
# Wix API: omit empty data objects, don't include them as {}
paragraph_node = {
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
blockquote_node = {
'id': node_id,
'type': 'BLOCKQUOTE',
'nodes': [paragraph_node],
'blockquoteData': {}
}
nodes.append(blockquote_node)
@@ -332,7 +331,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
list_item_node = {
'id': item_node_id,
@@ -345,7 +343,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'id': node_id,
'type': 'BULLETED_LIST',
'nodes': list_node_items,
'bulletedListData': {}
}
nodes.append(bulleted_list_node)
@@ -373,7 +370,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
list_item_node = {
'id': item_node_id,
@@ -386,7 +382,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'id': node_id,
'type': 'ORDERED_LIST',
'nodes': list_node_items,
'orderedListData': {}
}
nodes.append(ordered_list_node)
@@ -442,7 +437,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'id': node_id,
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
nodes.append(paragraph_node)
@@ -461,7 +455,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
'decorations': []
}
}],
'paragraphData': {}
}
nodes.append(fallback_paragraph)

View File

@@ -20,13 +20,14 @@ class SemanticHarvesterService:
"last_harvest_time": None
}
async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
Deep crawl a website using Exa AI.
Args:
website_url: The root URL to crawl.
limit: Maximum number of pages to retrieve.
user_id: Optional user ID for usage tracking and preflight checks.
Returns:
List of pages with content and metadata.
@@ -59,6 +60,30 @@ class SemanticHarvesterService:
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
return self._get_placeholder_data(website_url)
# Preflight subscription check if user_id provided
if user_id:
try:
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
db = get_session_for_user(user_id)
if db:
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.EXA,
tokens_requested=0,
actual_provider_name="exa",
)
if not can_proceed:
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
return []
finally:
db.close()
except Exception as e:
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
# Use Exa to search for all pages in this domain
search_response = self.exa_service.exa.search_and_contents(
query=f"site:{website_url}",
@@ -82,6 +107,38 @@ class SemanticHarvesterService:
})
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
# Track Exa usage if user_id provided
if user_id and results:
try:
from services.database import get_session_for_user
from services.subscription import PricingService
from sqlalchemy import text
db = get_session_for_user(user_id)
if db:
try:
pricing_service = PricingService(db)
current_period = pricing_service.get_current_billing_period(user_id)
cost = 0.005 # Exa search cost estimate
update_query = text("""
UPDATE usage_summaries
SET exa_calls = COALESCE(exa_calls, 0) + 1,
exa_cost = COALESCE(exa_cost, 0) + :cost,
total_calls = COALESCE(total_calls, 0) + 1,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_query, {
'cost': cost, 'user_id': user_id, 'period': current_period,
})
db.commit()
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
finally:
db.close()
except Exception as track_err:
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
return results
except Exception as e:

View File

@@ -133,9 +133,9 @@ def edit_image(
raise
except Exception as e:
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
# In podcast-only mode, allow the operation to continue on validation errors
if os.getenv("ALWRITY_ENABLED_FEATURES") == "podcast":
logger.warning(f"[Image Editing] ⚠️ Validation error in podcast mode - allowing operation to continue")
# In feature-limited mode, allow the operation to continue on validation errors
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
else:
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
finally:

View File

@@ -45,6 +45,7 @@ def llm_text_gen(
preferred_hf_models: Optional[List[str]] = None,
preferred_provider: Optional[str] = None,
flow_type: Optional[str] = None,
max_tokens: Optional[int] = None,
) -> str:
"""
Generate text using Language Model (LLM) based on the provided prompt.
@@ -75,7 +76,8 @@ def llm_text_gen(
gpt_provider = "google" # Default to Google Gemini
model = "gemini-2.0-flash-001"
temperature = 0.7
max_tokens = 4000
if max_tokens is None:
max_tokens = 4000
top_p = 0.9
n = 1
fp = 16
@@ -371,16 +373,27 @@ def llm_text_gen(
system_prompt=system_instructions
)
elif gpt_provider == "wavespeed":
from services.llm_providers.wavespeed_provider import wavespeed_text_response
llm_start = time.time()
response_text = wavespeed_text_response(
prompt=prompt,
model=model or "openai/gpt-oss-120b",
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
if json_struct:
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
response_text = wavespeed_structured_json_response(
prompt=prompt,
schema=json_struct,
model=model or "openai/gpt-oss-120b",
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
from services.llm_providers.wavespeed_provider import wavespeed_text_response
response_text = wavespeed_text_response(
prompt=prompt,
model=model or "openai/gpt-oss-120b",
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
llm_ms = (time.time() - llm_start) * 1000
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
else:

View File

@@ -179,6 +179,43 @@ def get_wavespeed_api_key() -> str:
return api_key
def _retry_with_increased_tokens(
    client: "OpenAI",
    messages: List[Dict[str, str]],
    model: str,
    fallback_models: Optional[List[str]],
    temperature: float,
    max_tokens: int,
) -> Optional[str]:
    """Re-issue a chat completion with a larger token budget and return its cleaned text.

    Intended for the case where a previous response could not be parsed as
    JSON, presumably because it was cut off. The caller passes the already
    increased ``max_tokens``; this function only caps it at 16384.

    Args:
        client: Client object exposing ``chat.completions.create``.
        messages: Chat messages forwarded unchanged to the API.
        model: Primary model name, tried via ``_fallback_model_sequence``.
        fallback_models: Optional extra model names handed to
            ``_fallback_model_sequence`` together with ``model``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Requested token budget; values above 16384 are capped.

    Returns:
        The response text from the first candidate model that does not raise
        ``NotFoundError``, stripped of surrounding whitespace and Markdown
        code fences (possibly an empty string). ``None`` when no candidate
        produced a response.

    Only ``NotFoundError`` moves on to the next candidate model; any other
    exception propagates to the caller.
    """
    token_budget = min(max_tokens, 16384)
    last_error = None

    for candidate_model in _fallback_model_sequence(model, fallback_models):
        try:
            completion = client.chat.completions.create(
                model=candidate_model,
                messages=messages,
                temperature=temperature,
                max_tokens=token_budget,
            )
            content = completion.choices[0].message.content
            body = content.strip() if content else ""
            # Strip an opening fence; the two checks run in sequence, exactly
            # like two consecutive `if` statements.
            for fence in ("```json", "```"):
                if body.startswith(fence):
                    body = body[len(fence):]
            # Strip a closing fence.
            if body.endswith("```"):
                body = body[:-3]
            return body.strip()
        except NotFoundError as nf_err:
            # Model not available: remember the error and try the next one.
            last_error = nf_err

    if last_error:
        logger.error(f"All fallback models failed on retry with increased tokens: {last_error}")
    return None
@retry(
retry=retry_if_exception(_should_retry_wavespeed_error),
wait=wait_random_exponential(min=1, max=60),
@@ -446,24 +483,69 @@ def wavespeed_structured_json_response(
raise last_error or Exception("WaveSpeed structured generation failed: all fallback models failed")
response_text = response.choices[0].message.content
response_text = response_text.strip() if response_text else ""
# If response_format returned empty content, retry without it
if not response_text:
logger.warning("WaveSpeed structured call returned empty content with response_format, retrying without it...")
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model, fallback_models):
try:
response = client.chat.completions.create(
model=candidate_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
break
except NotFoundError as nf_err:
last_error = nf_err
continue
if response is not None:
response_text = response.choices[0].message.content
response_text = response_text.strip() if response_text else ""
# Clean up response text if needed
response_text = response_text.strip()
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.startswith("```"):
response_text = response_text[3:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try:
parsed_json = json.loads(response_text)
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
return parsed_json
parsed_json = json.loads(response_text) if response_text else None
if parsed_json is not None:
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
return parsed_json
except json.JSONDecodeError as json_err:
logger.error(f"❌ JSON parsing failed: {json_err}")
logger.error(f"Raw response: {response_text}")
# Retry once with increased max_tokens — likely a truncation issue
if max_tokens < 16384:
logger.warning(f"Retrying with increased max_tokens ({max_tokens}{max_tokens * 2}) due to JSON parse failure")
response_text = _retry_with_increased_tokens(
client=client,
messages=messages,
model=model,
fallback_models=fallback_models,
temperature=temperature,
max_tokens=max_tokens * 2,
)
if response_text:
try:
parsed_json = json.loads(response_text)
if parsed_json is not None:
logger.info("✅ WaveSpeed structured JSON parsed successfully after max_tokens increase")
return parsed_json
except json.JSONDecodeError:
logger.error("❌ JSON parsing failed even after max_tokens increase")
# Try to extract JSON from the response using regex
logger.error(f"Raw response: {response_text}")
# Try to extract JSON from the response using regex
if response_text:
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
try:
@@ -472,8 +554,8 @@ def wavespeed_structured_json_response(
return extracted_json
except json.JSONDecodeError:
pass
return {"error": "Failed to parse JSON response", "raw_response": response_text}
return {"error": "Failed to parse JSON response", "raw_response": response_text}
except Exception as e:
logger.error(f"❌ WaveSpeed API call failed: {e}")
@@ -501,14 +583,24 @@ def wavespeed_structured_json_response(
if response is None:
raise last_error or e
response_text = response.choices[0].message.content
# ... (same parsing logic would apply, simplified here for brevity)
response_text = response_text.strip() if response_text else ""
# Parse JSON with robust cleaning
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.startswith("```"):
response_text = response_text[3:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try:
return json.loads(response_text)
except:
# Regex fallback
return json.loads(response_text) if response_text else {"error": "Empty response"}
except json.JSONDecodeError:
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
return json.loads(json_match.group())
try:
return json.loads(json_match.group())
except json.JSONDecodeError:
pass
return {"error": "Failed to parse JSON response", "raw_response": response_text}
raise e

View File

@@ -19,11 +19,11 @@ from services.database import get_db_session
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences
from models.persona_models import WritingPersona, PlatformPersona, PersonaAnalysisResult
def _get_podcast_mode():
"""Check if running in podcast-only mode to skip heavy initialization."""
def _is_feature_limited_mode():
"""Check if running in feature-limited mode to skip heavy initialization."""
import os
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower()
return env_val == "podcast"
return env_val not in ("", "all")
class PersonaAnalysisService:
"""Service for analyzing onboarding data and generating writing personas using Gemini AI."""
@@ -40,9 +40,9 @@ class PersonaAnalysisService:
def __init__(self):
"""Initialize the persona analysis service (only once)."""
if not self._initialized:
# Skip heavy initialization in podcast-only mode
if _get_podcast_mode():
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
# Skip heavy initialization in feature-limited mode
if _is_feature_limited_mode():
logger.debug(f"PersonaAnalysisService: Skipping heavy init in feature-limited mode")
self._initialized = True
return
@@ -55,8 +55,8 @@ class PersonaAnalysisService:
return
# Check again in case mode changed
if _get_podcast_mode():
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
if _is_feature_limited_mode():
logger.debug("PersonaAnalysisService: Skipping heavy init in feature-limited mode")
self._heavy_init_done = True
return
@@ -89,9 +89,9 @@ class PersonaAnalysisService:
# Ensure heavy services are initialized
self._ensure_heavy_init()
# Check if heavy init failed (podcast mode)
# Check if heavy init failed (feature-limited mode)
if not getattr(self, '_heavy_init_done', False):
return {"error": "Persona service unavailable in podcast-only mode"}
return {"error": "Persona service unavailable in feature-limited mode"}
try:
logger.info(f"Generating persona for user {user_id}")

View File

@@ -296,6 +296,33 @@ class ResearchEngine:
target_audience = request.target_audience or "General"
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
# Preflight subscription check
try:
db = self._db_session
if not db:
from services.database import get_db_session
db = get_db_session()
if db:
from services.subscription import PricingService
from models.subscription_models import APIProvider
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.EXA,
tokens_requested=0,
actual_provider_name="exa",
)
if not can_proceed:
raise HTTPException(status_code=429, detail={
'error': message, 'message': message,
'provider': 'exa', 'usage_info': usage_info or {}
})
logger.info(f"[ResearchEngine] Exa preflight check passed for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.warning(f"[ResearchEngine] Exa preflight check failed: {e}")
# Execute Exa search
try:
@@ -341,6 +368,33 @@ class ResearchEngine:
target_audience = request.target_audience or "General"
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
# Preflight subscription check
try:
db = self._db_session
if not db:
from services.database import get_db_session
db = get_db_session()
if db:
from services.subscription import PricingService
from models.subscription_models import APIProvider
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.TAVILY,
tokens_requested=0,
actual_provider_name="tavily",
)
if not can_proceed:
raise HTTPException(status_code=429, detail={
'error': message, 'message': message,
'provider': 'tavily', 'usage_info': usage_info or {}
})
logger.info(f"[ResearchEngine] Tavily preflight check passed for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.warning(f"[ResearchEngine] Tavily preflight check failed: {e}")
# Execute Tavily search
try:

View File

@@ -83,6 +83,30 @@ class DeepCrawlService:
tavily_results.append(res)
logger.info(f"Found {len(tavily_urls)} URLs from Tavily")
# Track Tavily usage
try:
from services.subscription import PricingService
from sqlalchemy import text
pricing_service = PricingService(db)
current_period = pricing_service.get_current_billing_period(user_id)
cost = 0.005 # Tavily crawl cost estimate
update_query = text("""
UPDATE usage_summaries
SET tavily_calls = COALESCE(tavily_calls, 0) + 1,
tavily_cost = COALESCE(tavily_cost, 0) + :cost,
total_calls = COALESCE(total_calls, 0) + 1,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_query, {
'cost': cost, 'user_id': user_id, 'period': current_period,
})
db.commit()
logger.info(f"[DeepCrawl] Tracked Tavily crawl usage: user={user_id}, cost=${cost}")
except Exception as track_err:
logger.warning(f"[DeepCrawl] Failed to track Tavily usage: {track_err}")
except Exception as e:
logger.warning(f"Tavily crawl failed: {e}")

View File

@@ -49,9 +49,11 @@ except Exception as _patch_err:
# Now safe to import pytrends
try:
from pytrends.request import TrendReq as _TrendReq
from pytrends.exceptions import TooManyRequestsError as _TooManyRequestsError
PYTrends_AVAILABLE = True
except ImportError:
PYTrends_AVAILABLE = False
_TooManyRequestsError = None
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
# Patch 2: pytrends related_topics() and related_queries() use keyword[0]
@@ -139,6 +141,8 @@ class GoogleTrendsService:
Uses TrendReq with no retries (fail-fast) to avoid hitting CAPTCHA on blocks.
429 retry handling (1s, 2s, 4s backoff). Random user-agent is set
per instance to reduce fingerprinting.
Rate limiter is shared across all instances to enforce global rate limiting.
"""
USER_AGENTS = [
@@ -150,15 +154,28 @@ class GoogleTrendsService:
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0",
]
# Class-level shared resources (shared across all instances)
_shared_rate_limiter = None
_shared_cache = None
_cache_ttl = timedelta(hours=24)
_last_429_time = 0 # Timestamp of last 429 error (Unix epoch)
_429_cooldown_period = 1800 # 30 minutes cooldown after 429
def __init__(self):
    """Bind this instance to the class-wide rate limiter and cache.

    The rate limiter and the cache live on the class (``_shared_rate_limiter``
    and ``_shared_cache``) and are created lazily by the first instance, so
    every instance throttles against, and reads from, the same objects.

    Raises:
        RuntimeError: If the pytrends import failed at module load time.
    """
    if not PYTrends_AVAILABLE:
        raise RuntimeError("pytrends library is required. Install with: pip install pytrends")

    cls = self.__class__

    # Lazily create the shared resources the first time any instance is built.
    if cls._shared_rate_limiter is None:
        cls._shared_rate_limiter = RateLimiter(max_calls=1, period=3.0)  # 1 call per 3 seconds
    if cls._shared_cache is None:
        cls._shared_cache = {}

    # Instance attributes are aliases of the shared objects, not copies.
    self.rate_limiter = cls._shared_rate_limiter
    self.cache = cls._shared_cache
    self.cache_ttl = self._cache_ttl

    logger.info("GoogleTrendsService initialized (pytrends 4.9.2, shared rate limiter, 3s period, shared cache, 30min 429 cooldown)")
# -----------------------------------------------------------------------
# Public API
@@ -173,7 +190,7 @@ class GoogleTrendsService:
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Comprehensive trends analysis.
Comprehensive trends analysis with retry logic for 429 errors.
Args:
keywords: List of keywords to analyze (1-5)
@@ -193,11 +210,97 @@ class GoogleTrendsService:
keywords = keywords[:5]
cache_key = self._build_cache_key(keywords, timeframe, geo)
# Check if we're in a 429 cooldown period
now = time.time()
if now - self.__class__._last_429_time < self.__class__._429_cooldown_period:
remaining_cooldown = int(self.__class__._429_cooldown_period - (now - self.__class__._last_429_time))
logger.warning(
f"[Trends] In 429 cooldown period. {remaining_cooldown}s remaining. "
f"Returning cached data if available."
)
cached_data = self._get_from_cache(cache_key, ignore_ttl=True) # Use stale cache
if cached_data:
logger.info(f"[Trends] Returning stale cached data for {keywords} during cooldown")
return {**cached_data, "cached": True, "cooldown_active": True}
return self._create_fallback_response(
keywords, timeframe, geo, gprop,
f"Rate limited by Google. Cooldown active for {remaining_cooldown}s. Try again later."
)
# Check fresh cache
cached_data = self._get_from_cache(cache_key)
if cached_data:
logger.info(f"Returning cached trends data for: {keywords}")
return {**cached_data, "cached": True}
# Retry logic for 429 errors
max_retries = 3
retry_delays = [30, 60, 120] # Longer delays: 30s, 60s, 120s
for attempt in range(max_retries + 1):
try:
return await self._do_analyze_trends(
keywords, timeframe, geo, gprop, cache_key, attempt, max_retries
)
except Exception as e:
# Check if this is a 429 error (pytrends raises TooManyRequestsError)
is_429 = False
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
is_429 = True
else:
error_str = str(e).lower()
is_429 = "429" in error_str or "rate limit" in error_str or "too many requests" in error_str
if is_429:
# Update the last 429 time for cooldown
self.__class__._last_429_time = time.time()
if attempt < max_retries:
delay = retry_delays[attempt]
logger.warning(
f"[Trends] 429 rate limit hit (attempt {attempt + 1}/{max_retries + 1}), "
f"retrying in {delay}s..."
)
await asyncio.sleep(delay)
continue
else:
# Out of retries - enter cooldown
logger.error(
f"[Trends] 429 rate limit persisted after {max_retries + 1} attempts. "
f"Entering {self.__class__._429_cooldown_period}s cooldown period."
)
# Try to return stale cache
stale_cache = self._get_from_cache(cache_key, ignore_ttl=True)
if stale_cache:
logger.info(f"[Trends] Returning stale cache after 429 exhaustion for {keywords}")
result = {**stale_cache}
result["cached"] = True
result["cooldown_active"] = True
return result
return self._create_fallback_response(
keywords, timeframe, geo, gprop,
f"Google is rate limiting requests. Cooldown active for {self.__class__._429_cooldown_period}s. Try again later."
)
else:
# Non-429 error
logger.error(f"Google Trends analysis failed after {attempt + 1} attempts: {e}")
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
# Should not reach here, but just in case
return self._create_fallback_response(keywords, timeframe, geo, gprop, "Max retries exceeded")
async def _do_analyze_trends(
self,
keywords: List[str],
timeframe: str,
geo: str,
gprop: str,
cache_key: str,
attempt: int,
max_retries: int,
) -> Dict[str, Any]:
"""Internal method to perform the actual trends analysis."""
await self.rate_limiter.acquire()
total_start = time.monotonic()
@@ -207,95 +310,63 @@ class GoogleTrendsService:
related_topics: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
related_queries: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
try:
logger.info(f"[Trends] ===== START analyze_trends ===== keywords={keywords} timeframe={timeframe} geo={geo}")
logger.info(
f"[Trends] ===== START analyze_trends (attempt {attempt + 1}/{max_retries + 1}) ===== "
f"keywords={keywords} timeframe={timeframe} geo={geo}"
)
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
init_start = time.monotonic()
pytrends = await asyncio.to_thread(
self._create_pytrends,
keywords,
timeframe,
geo,
gprop,
)
init_ms = int((time.monotonic() - init_start) * 1000)
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
init_start = time.monotonic()
pytrends = await asyncio.to_thread(
self._create_pytrends,
keywords,
timeframe,
geo,
gprop,
)
init_ms = int((time.monotonic() - init_start) * 1000)
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
# --- Interest Over Time ---
iot_start = time.monotonic()
interest_over_time = await asyncio.to_thread(
lambda: self._fetch_interest_over_time(pytrends)
)
iot_ms = int((time.monotonic() - iot_start) * 1000)
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
# --- Interest Over Time ONLY (skip others to avoid 429) ---
await self.rate_limiter.acquire() # Rate limit check BEFORE each request
iot_start = time.monotonic()
interest_over_time = await asyncio.to_thread(
lambda: self._fetch_interest_over_time(pytrends)
)
iot_ms = int((time.monotonic() - iot_start) * 1000)
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
await asyncio.sleep(2)
# Skip other requests to avoid 429 - only fetch interest_over_time for now
logger.info(f"[Trends] Skipping other requests to avoid 429 (interest_by_region, related_topics, related_queries)")
# --- Interest By Region ---
ibr_start = time.monotonic()
interest_by_region = await asyncio.to_thread(
lambda: self._fetch_interest_by_region(pytrends)
)
ibr_ms = int((time.monotonic() - ibr_start) * 1000)
logger.info(f"[Trends] interest_by_region took {ibr_ms}ms, returned {len(interest_by_region)} regions")
total_ms = int((time.monotonic() - total_start) * 1000)
logger.info(
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
f"rt_top={rt_top} rq_top={rq_top}"
)
await asyncio.sleep(2)
result = {
"interest_over_time": interest_over_time,
"interest_by_region": interest_by_region,
"related_topics": related_topics,
"related_queries": related_queries,
"timeframe": timeframe,
"geo": geo,
"keywords": keywords,
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
"timestamp": datetime.utcnow().isoformat(),
"cached": False,
}
# --- Related Topics ---
rt_start = time.monotonic()
related_topics = await asyncio.to_thread(
lambda: self._fetch_related_topics(pytrends)
)
rt_ms = int((time.monotonic() - rt_start) * 1000)
rt_top = len(related_topics.get("top", []))
rt_rising = len(related_topics.get("rising", []))
logger.info(f"[Trends] related_topics took {rt_ms}ms, top={rt_top} rising={rt_rising}")
self._save_to_cache(cache_key, result)
await asyncio.sleep(2)
logger.info(
f"Google Trends data fetched successfully: "
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
)
# --- Related Queries ---
rq_start = time.monotonic()
related_queries = await asyncio.to_thread(
lambda: self._fetch_related_queries(pytrends)
)
rq_ms = int((time.monotonic() - rq_start) * 1000)
rq_top = len(related_queries.get("top", []))
rq_rising = len(related_queries.get("rising", []))
logger.info(f"[Trends] related_queries took {rq_ms}ms, top={rq_top} rising={rq_rising}")
total_ms = int((time.monotonic() - total_start) * 1000)
logger.info(
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
f"rt_top={rt_top} rq_top={rq_top}"
)
result = {
"interest_over_time": interest_over_time,
"interest_by_region": interest_by_region,
"related_topics": related_topics,
"related_queries": related_queries,
"timeframe": timeframe,
"geo": geo,
"keywords": keywords,
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
"timestamp": datetime.utcnow().isoformat(),
"cached": False,
}
self._save_to_cache(cache_key, result)
logger.info(
f"Google Trends data fetched successfully: "
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
)
return result
except Exception as e:
logger.error(f"Google Trends analysis failed: {e}")
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
return result
# -----------------------------------------------------------------------
# TrendReq factory
@@ -346,6 +417,12 @@ class GoogleTrendsService:
return result
except Exception as e:
elapsed = int((time.monotonic() - start) * 1000)
# Re-raise 429 errors so retry logic can handle them
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
raise
error_str = str(e).lower()
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
raise
logger.error(f"[Trends] interest_over_time failed in {elapsed}ms: {e}")
return []
@@ -363,6 +440,12 @@ class GoogleTrendsService:
return result
except Exception as e:
elapsed = int((time.monotonic() - start) * 1000)
# Re-raise 429 errors so retry logic can handle them
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
raise
error_str = str(e).lower()
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
raise
logger.error(f"[Trends] interest_by_region failed in {elapsed}ms: {e}")
return []
@@ -409,6 +492,12 @@ class GoogleTrendsService:
return result
except Exception as e:
elapsed = int((time.monotonic() - start) * 1000)
# Re-raise 429 errors so retry logic can handle them
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
raise
error_str = str(e).lower()
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
raise
logger.error(f"[Trends] related_topics failed in {elapsed}ms: {e}")
return result
@@ -452,6 +541,12 @@ class GoogleTrendsService:
return result
except Exception as e:
elapsed = int((time.monotonic() - start) * 1000)
# Re-raise 429 errors so retry logic can handle them
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
raise
error_str = str(e).lower()
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
raise
logger.error(f"[Trends] related_queries failed in {elapsed}ms: {e}")
return result
@@ -503,14 +598,18 @@ class GoogleTrendsService:
keywords_str = ":".join(sorted(keywords))
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
def _get_from_cache(self, cache_key: str, ignore_ttl: bool = False) -> Optional[Dict[str, Any]]:
"""Get cached data. If ignore_ttl=True, return stale data too (for 429 cooldown)."""
if cache_key not in self.cache:
return None
cached_entry = self.cache[cache_key]
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
if datetime.utcnow() - cached_time > self.cache_ttl:
del self.cache[cache_key]
return None
if not ignore_ttl:
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
if datetime.utcnow() - cached_time > self.cache_ttl:
del self.cache[cache_key]
return None
result = {**cached_entry}
result.pop("cached", None)
return result

View File

@@ -157,10 +157,10 @@ def _check_production_api_key_loading(
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
return
# Also skip in podcast-only mode (no production API keys needed)
# Skip when in feature-limited mode (no production API keys needed)
enabled_features = os.getenv("ALWRITY_ENABLED_FEATURES", "all").strip().lower()
if enabled_features == "podcast":
_record_check(checks, "production_api_key_loading", True, "skipped in podcast-only mode")
if enabled_features and enabled_features not in ("", "all"):
_record_check(checks, "production_api_key_loading", True, f"skipped in feature-limited mode: {enabled_features}")
return
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()

View File

@@ -12,7 +12,7 @@ from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.subscription_models import APIProvider, UsageAlert
from models.subscription_models import APIProvider, UsageAlert, UserSubscription
class SubscriptionErrorType(Enum):
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
@@ -248,6 +248,18 @@ class SubscriptionExceptionHandler:
return
try:
# Get billing period from subscription, fallback to calendar month
billing_period = datetime.now().strftime("%Y-%m") # default
try:
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == error.user_id,
UserSubscription.is_active == True
).first()
if subscription and subscription.current_period_start:
billing_period = subscription.current_period_start.strftime("%Y-%m")
except:
pass # Use default calendar period
alert = UsageAlert(
user_id=error.user_id,
alert_type="system_error",
@@ -256,7 +268,7 @@ class SubscriptionExceptionHandler:
title=f"System Error: {error.error_type.value}",
message=error.message,
severity=error.severity.value,
billing_period=datetime.now().strftime("%Y-%m")
billing_period=billing_period
)
self.db.add(alert)

View File

@@ -157,39 +157,38 @@ class LimitValidator:
user_tier = limits.get('tier', 'free') if limits else 'free'
# Get current usage for this billing period with error handling
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
# Use subscription period, not calendar month
current_period = self.pricing_service.get_current_billing_period(user_id)
# Only expire specific objects that might have changed after renewal
# (subscription was already checked above; plan was expired above)
# The usage record is the main object we need fresh, and we query it directly below
if subscription:
self.db.expire(subscription)
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Only expire specific objects that might have changed after renewal
# (subscription was already checked above; plan was expired above)
# The usage record is the main object we need fresh, and we query it directly below
if subscription:
self.db.expire(subscription)
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
from sqlalchemy import text
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
from sqlalchemy import text
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
if not usage:
# First usage this period, create summary
@@ -448,7 +447,7 @@ class LimitValidator:
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
current_period = self.pricing_service.get_current_billing_period(user_id)
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")

View File

@@ -67,15 +67,56 @@ class PricingService:
self.db.rollback()
return True
def get_current_billing_period(self, user_id: str) -> str:
    """Return the billing period key (YYYY-MM) to use for ``user_id``.

    Resolution order:
      1. The month of the active subscription's ``current_period_start``,
         if a usage summary already exists for it.
      2. The current calendar month, if it differs from (1) and a usage
         summary exists for it (backward compatibility with rows written
         under calendar-month keys).
      3. The subscription period from (1), even without data.
      4. With no active subscription or no period start: the most recent
         ``billing_period`` among the user's usage summaries.
      5. The current calendar month.

    Args:
        user_id: User whose billing period is being resolved.

    Returns:
        A ``YYYY-MM`` string; never ``None``.
    """
    # Imported locally, as the original did at each point of use.
    from models.subscription_models import UsageSummary

    def _has_usage(period: str) -> bool:
        """Report whether a usage summary row exists for this user and period."""
        return self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == period
        ).first() is not None

    subscription = self.db.query(UserSubscription).filter(
        UserSubscription.user_id == user_id,
        UserSubscription.is_active == True  # noqa: E712 - SQLAlchemy column comparison
    ).first()

    # Ensure subscription is current (advance if auto_renew)
    self._ensure_subscription_current(subscription)

    # Use subscription's billing period, NOT calendar month
    if subscription and subscription.current_period_start:
        sub_period = subscription.current_period_start.strftime("%Y-%m")
        if _has_usage(sub_period):
            return sub_period

        # No data under the subscription period: fall back to calendar-month
        # data so existing users keep their accumulated usage.
        calendar_period = datetime.now().strftime("%Y-%m")
        if calendar_period != sub_period and _has_usage(calendar_period):
            logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {sub_period} has no data)")
            return calendar_period

        return sub_period

    # Fallback: reuse the latest period the user already has a summary for.
    latest_summary = self.db.query(UsageSummary).filter(
        UsageSummary.user_id == user_id
    ).order_by(UsageSummary.billing_period.desc()).first()
    if latest_summary:
        logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period}")
        return latest_summary.billing_period

    # Last fallback to calendar month for free tier / no data
    return datetime.now().strftime("%Y-%m")
@classmethod
@@ -830,6 +871,7 @@ class PricingService:
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'exa_calls': getattr(plan, 'exa_calls_limit', 0), # Exa research API
'stability_calls': plan.stability_calls_limit,
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
from services.subscription.pricing_service import PricingService
from datetime import datetime
from datetime import datetime, timedelta
REQUIRED_STRIPE_PLAN_KEYS = {
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
@@ -421,10 +421,6 @@ class StripeService:
try:
sub = stripe.Subscription.retrieve(subscription_id)
price_id = sub['items']['data'][0]['price']['id']
# Map price_id to internal plan_id
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
# For now, we'll assume the metadata or a lookup.
# Ideally, store price_id in SubscriptionPlan table or config.
# Update DB
self._update_user_subscription(
@@ -434,6 +430,24 @@ class StripeService:
status="active",
price_id=price_id
)
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(user_id)
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(user_id)
logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
except Exception as cache_err:
logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")
# Expire all SQLAlchemy objects to force fresh reads
self.db.expire_all()
logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")
except Exception as e:
logger.error(f"Error processing checkout subscription: {e}")
@@ -457,11 +471,28 @@ class StripeService:
logger.info(f"Payment succeeded for user {subscription.user_id}")
subscription.status = UsageStatus.ACTIVE
subscription.is_active = True
# Update period end based on invoice lines period
subscription.auto_renew = True
# Update period start/end based on invoice lines period
if invoice.get('lines'):
period_start = invoice['lines']['data'][0]['period']['start']
period_end = invoice['lines']['data'][0]['period']['end']
subscription.current_period_start = datetime.fromtimestamp(period_start)
subscription.current_period_end = datetime.fromtimestamp(period_end)
self.db.commit()
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(subscription.user_id)
logger.info(f"Cleared subscription cache for user {subscription.user_id} after payment success")
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after payment success for user {subscription.user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(subscription.user_id)
except Exception as dash_cache_err:
logger.warning(f"Failed to clear dashboard cache after payment success for user {subscription.user_id}: {dash_cache_err}")
self.db.expire_all()
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
subscription_id = invoice.get("subscription")
@@ -497,6 +528,12 @@ class StripeService:
if status in ["active", "trialing"]:
subscription.status = UsageStatus.ACTIVE
subscription.is_active = True
subscription.auto_renew = True
# Update period boundaries from Stripe event
current_period = subscription_obj.get("current_period", {})
if current_period:
subscription.current_period_start = datetime.fromtimestamp(current_period.get("start", 0))
subscription.current_period_end = datetime.fromtimestamp(current_period.get("end", 0))
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
subscription.status = UsageStatus.PAST_DUE
subscription.is_active = False
@@ -506,6 +543,20 @@ class StripeService:
subscription.auto_renew = False
self.db.commit()
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(subscription.user_id)
logger.info(f"Cleared subscription cache for user {subscription.user_id} after subscription update")
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after subscription update for user {subscription.user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(subscription.user_id)
except Exception as dash_cache_err:
logger.warning(f"Failed to clear dashboard cache after subscription update for user {subscription.user_id}: {dash_cache_err}")
self.db.expire_all()
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
"""
@@ -610,6 +661,11 @@ class StripeService:
)
now = datetime.utcnow()
# Calculate billing period end based on cycle
if billing_cycle == BillingCycle.YEARLY:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
if not subscription:
subscription = UserSubscription(
@@ -617,7 +673,7 @@ class StripeService:
plan_id=plan.id,
billing_cycle=billing_cycle,
current_period_start=now,
current_period_end=now,
current_period_end=period_end,
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
is_active=status == "active",
auto_renew=True,
@@ -627,6 +683,11 @@ class StripeService:
subscription.plan_id = plan.id
subscription.billing_cycle = billing_cycle
subscription.is_active = status == "active"
subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
# Reset billing period on upgrade/plan change
subscription.current_period_start = now
subscription.current_period_end = period_end
subscription.auto_renew = True
subscription.stripe_customer_id = stripe_customer_id
subscription.stripe_subscription_id = stripe_subscription_id

View File

@@ -0,0 +1,21 @@
"""
Usage tracking modules package.
Split from the monolithic usage_tracking_service.py for better maintainability.
"""
from .historical_usage import get_all_historical_usage, get_current_period_usage, get_usage_for_period
from .usage_stats import get_user_usage_stats
from .usage_trends import get_usage_trends
from .limits_enforcement import enforce_usage_limits
from .alerts import check_usage_alerts, create_usage_alert
# Public API of the package: exactly the names re-exported by the
# relative imports above (historical usage, stats, trends, limit
# enforcement and alert helpers).
__all__ = [
'get_all_historical_usage',
'get_current_period_usage',
'get_usage_for_period',
'get_user_usage_stats',
'get_usage_trends',
'enforce_usage_limits',
'check_usage_alerts',
'create_usage_alert',
]

View File

@@ -0,0 +1,101 @@
"""
Usage alert functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import UsageAlert, UsageSummary, APIProvider, UsageStatus
def check_usage_alerts(user_id: str, provider: APIProvider,
                       billing_period: str, db: Session, pricing_service):
    """Create usage alerts for every threshold (80%, 90%, 100%) that is reached.

    The provider's call count is read from the user's ``UsageSummary`` for
    ``billing_period`` and compared with the matching ``<provider>_calls``
    limit from ``pricing_service.get_user_limits``. An alert is created per
    reached threshold unless one with ``is_sent == True`` already exists for
    the same user, period, provider and threshold.

    Args:
        user_id: User whose usage is checked.
        provider: Provider whose ``<provider.value>_calls`` counter is checked.
        billing_period: Billing period key of the usage summary to read.
        db: Session used for the lookups; alerts are added to it but this
            function does not commit.
        pricing_service: Object exposing ``get_user_limits(user_id)``.

    Returns:
        None. Returns early when there is no summary, no limits, or no
        positive call limit for the provider.
    """
    summary = db.query(UsageSummary).filter(
        UsageSummary.user_id == user_id,
        UsageSummary.billing_period == billing_period
    ).first()
    if not summary:
        return

    limits = pricing_service.get_user_limits(user_id)
    if not limits:
        return

    # Usage and limit do not change between thresholds; compute them once.
    # `or 0` covers counters/limits stored as NULL.
    provider_name = provider.value
    counter_key = f"{provider_name}_calls"
    current_calls = getattr(summary, counter_key, 0) or 0
    call_limit = (limits.get('limits') or {}).get(counter_key, 0) or 0
    if call_limit <= 0:
        return

    usage_percentage = (current_calls / call_limit) * 100

    # Thresholds are ascending, so the first unreached one ends the scan.
    for threshold in (80, 90, 100):
        if usage_percentage < threshold:
            break

        # Skip thresholds that already have a sent alert for this period.
        existing_alert = db.query(UsageAlert).filter(
            UsageAlert.user_id == user_id,
            UsageAlert.billing_period == billing_period,
            UsageAlert.threshold_percentage == threshold,
            UsageAlert.provider == provider,
            UsageAlert.is_sent == True  # noqa: E712 - SQLAlchemy column comparison
        ).first()
        if existing_alert:
            continue

        create_usage_alert(
            user_id=user_id,
            provider=provider,
            threshold=threshold,
            current_usage=current_calls,
            limit=call_limit,
            billing_period=billing_period,
            db=db
        )
def create_usage_alert(user_id: str, provider: APIProvider,
                       threshold: int, current_usage: int, limit: int,
                       billing_period: str, db: Session):
    """Build a ``UsageAlert`` for a reached threshold and add it to the session.

    Severity follows the threshold: below 90 is an informational notice,
    90-99 is a warning, and 100 or more is a limit-reached error. The alert
    is added to ``db`` but not committed here.

    Args:
        user_id: User the alert belongs to.
        provider: Provider whose usage triggered the alert.
        threshold: Threshold percentage that was reached.
        current_usage: Calls used so far (shown in warning/notice messages).
        limit: Call limit for the billing period.
        billing_period: Billing period key stored on the alert.
        db: Session the alert is added to.
    """
    provider_label = provider.value
    provider_heading = provider_label.title()

    # Pick alert type, severity and title from the threshold band.
    if threshold >= 100:
        kind, level = "limit_reached", "error"
        heading = f"API Limit Reached - {provider_heading}"
    elif threshold >= 90:
        kind, level = "usage_warning", "warning"
        heading = f"API Usage Warning - {provider_heading}"
    else:
        kind, level = "usage_warning", "info"
        heading = f"API Usage Notice - {provider_heading}"

    # Only the limit-reached message omits the current usage figure.
    if kind == "limit_reached":
        body = f"You have reached your {provider_label} API limit of {limit:,} calls for this billing period."
    else:
        body = f"You have used {current_usage:,} of {limit:,} {provider_label} API calls ({threshold}% of your limit)."

    db.add(UsageAlert(
        user_id=user_id,
        alert_type=kind,
        threshold_percentage=threshold,
        provider=provider,
        title=heading,
        message=body,
        severity=level,
        billing_period=billing_period
    ))
    logger.info(f"Created usage alert for {user_id}: {heading}")

View File

@@ -0,0 +1,250 @@
"""
Historical usage aggregation functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
from models.subscription_models import UsageSummary, UsageStatus
# Shared provider mapping: DB column → frontend key
# Keys are the per-provider call-counter attribute names read from UsageSummary
# records (via getattr in _build_provider_breakdown); values are the keys used
# in the provider_breakdown dict returned to callers.
PROVIDER_MAPPING = {
'gemini_calls': 'gemini',
'openai_calls': 'openai',
'anthropic_calls': 'anthropic',
'mistral_calls': 'huggingface', # HuggingFace stored as mistral
'wavespeed_calls': 'wavespeed',
'exa_calls': 'exa',
'tavily_calls': 'tavily',
'serper_calls': 'serper',
'firecrawl_calls': 'firecrawl',
'metaphor_calls': 'metaphor',
'stability_calls': 'stability',
'video_calls': 'video',
'image_edit_calls': 'image_edit',
'audio_calls': 'audio',
}
def _build_provider_breakdown(summaries: list, mapping: dict) -> dict:
"""Build provider_breakdown dict from a list of UsageSummary records."""
breakdown = {}
for db_col, frontend_key in mapping.items():
total = sum(getattr(s, db_col, 0) or 0 for s in summaries)
breakdown[frontend_key] = {'calls': total, 'cost': 0, 'tokens': 0}
return breakdown
def _build_usage_percentages(provider_breakdown: dict, limits: dict) -> dict:
"""Build usage_percentages dict from provider_breakdown and per-period limits."""
pcts = {}
if not limits or not limits.get('limits'):
return pcts
limit_map = {
'gemini_calls': ('gemini', 'gemini_calls'),
'huggingface_calls': ('huggingface', 'mistral_calls'),
'stability_calls': ('stability', 'stability_calls'),
'video_calls': ('video', 'video_calls'),
'audio_calls': ('audio', 'audio_calls'),
'image_edit_calls': ('image_edit', 'image_edit_calls'),
'wavespeed_calls': ('wavespeed', 'wavespeed_calls'),
'tavily_calls': ('tavily', 'tavily_calls'),
'serper_calls': ('serper', 'serper_calls'),
'firecrawl_calls': ('firecrawl', 'firecrawl_calls'),
'metaphor_calls': ('metaphor', 'metaphor_calls'),
'exa_calls': ('exa', 'exa_calls'),
}
for pct_key, (bk_key, limit_key) in limit_map.items():
used = provider_breakdown.get(bk_key, {}).get('calls', 0)
limit_val = limits.get('limits', {}).get(limit_key, 0) or 0
if limit_val > 0:
pcts[pct_key] = (used / limit_val) * 100
# Cost percentage
total_cost = provider_breakdown.get('total_cost', 0)
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
if cost_limit > 0:
pcts['cost'] = (total_cost / cost_limit) * 100
return pcts
def _summaries_usage_status(summaries: list) -> str:
"""Derive overall usage_status from a list of summaries."""
status = 'active'
for s in summaries:
try:
st = s.usage_status.value
except Exception:
st = str(s.usage_status)
if st == 'limit_reached':
return 'limit_reached'
if st == 'warning' and status != 'limit_reached':
status = 'warning'
return status
def _empty_usage_response(billing_period: str, limits: dict) -> Dict[str, Any]:
"""Return a zeroed UsageStats-shaped response."""
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'avg_response_time': 0.0,
'error_rate': 0.0,
'limits': limits,
'provider_breakdown': {},
'usage_percentages': {},
'historical_breakdown': [],
'last_updated': datetime.now().isoformat()
}
def get_all_historical_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
    """Aggregate every stored ``UsageSummary`` of a user into one payload.

    Summaries are read newest billing period first. Totals are summed, the
    average response time is weighted by each period's call count, and the
    error rate is rebuilt from each period's rate and call count. A per-period
    list is returned under ``'historical_breakdown'``. ``'usage_percentages'``
    is left empty because the limits are per-period. When the user has no
    summaries a zeroed response for period ``'all'`` is returned.
    """
    records = (
        db.query(UsageSummary)
        .filter(UsageSummary.user_id == user_id)
        .order_by(UsageSummary.billing_period.desc())
        .all()
    )
    limits = pricing_service.get_user_limits(user_id)
    if not records:
        return _empty_usage_response('all', limits)

    # Single pass over the records for all running totals.
    calls_sum = 0
    tokens_sum = 0
    cost_sum = 0
    weighted_time_sum = 0
    errors_sum = 0
    for rec in records:
        calls_sum += rec.total_calls or 0
        tokens_sum += rec.total_tokens or 0
        cost_sum += float(rec.total_cost or 0)
        weighted_time_sum += (rec.avg_response_time or 0) * (rec.total_calls or 0)
        errors_sum += (rec.total_calls or 0) * (rec.error_rate or 0) / 100

    if calls_sum > 0:
        avg_response_time = weighted_time_sum / calls_sum
        error_rate = errors_sum / calls_sum * 100
    else:
        avg_response_time = 0.0
        error_rate = 0.0

    provider_breakdown = _build_provider_breakdown(records, PROVIDER_MAPPING)

    # One entry per stored billing period, in query order.
    per_period = []
    for rec in records:
        try:
            period_status = rec.usage_status.value
        except Exception:
            period_status = str(rec.usage_status)
        updated = rec.updated_at.isoformat() if rec.updated_at else None
        per_period.append({
            'billing_period': rec.billing_period,
            'total_calls': rec.total_calls or 0,
            'total_tokens': rec.total_tokens or 0,
            'total_cost': float(rec.total_cost or 0),
            'usage_status': period_status,
            'updated_at': updated
        })

    payload = {
        'billing_period': 'all',
        'usage_status': _summaries_usage_status(records),
        'total_calls': calls_sum,
        'total_tokens': tokens_sum,
        'total_cost': round(cost_sum, 2),
        'avg_response_time': round(avg_response_time, 2),
        'error_rate': round(error_rate, 2),
        'limits': limits,
        'provider_breakdown': provider_breakdown,
        'usage_percentages': {},  # misleading for all-time vs per-period limits
        'historical_breakdown': per_period,
        'last_updated': datetime.now().isoformat()
    }
    return payload
def get_current_period_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
    """Get usage data for the user's current billing period.

    The current period is resolved with
    ``pricing_service.get_current_billing_period(user_id)`` and the payload is
    produced by ``get_usage_for_period``, so both entry points always return
    the same shape for the same period instead of maintaining two copies of
    the same logic.

    Returns:
        A UsageStats-shaped dict with ``provider_breakdown`` and
        ``usage_percentages`` computed against the plan's per-period limits.
    """
    current_period = pricing_service.get_current_billing_period(user_id)
    return get_usage_for_period(user_id, current_period, db, pricing_service)
def get_usage_for_period(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
    """Get usage data for a specific billing period.

    Args:
        user_id: Owner of the usage summary.
        billing_period: Billing period key to look up.
        db: SQLAlchemy session used for the ``UsageSummary`` query.
        pricing_service: Provides ``get_user_limits(user_id)``.

    Returns:
        A UsageStats-shaped dict with that period's ``provider_breakdown`` and
        ``usage_percentages`` computed against plan limits. When no summary
        exists for the period a zeroed response is returned.
    """
    limits = pricing_service.get_user_limits(user_id)
    summary = db.query(UsageSummary).filter(
        UsageSummary.user_id == user_id,
        UsageSummary.billing_period == billing_period
    ).first()

    if not summary:
        result = _empty_usage_response(billing_period, limits)
        result['usage_percentages'] = _build_usage_percentages({}, limits)
        return result

    total_cost = float(summary.total_cost or 0)
    provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)

    # _build_usage_percentages reads the period cost from a 'total_cost' key,
    # which the per-provider breakdown never contains; without it the 'cost'
    # percentage is always 0. Pass a copy so the returned breakdown keeps its
    # provider-only shape.
    usage_percentages = _build_usage_percentages(
        {**provider_breakdown, 'total_cost': total_cost}, limits
    )

    try:
        status_val = summary.usage_status.value
    except AttributeError:
        # usage_status is not an enum member (e.g. a plain string).
        status_val = str(summary.usage_status)

    return {
        'billing_period': billing_period,
        'usage_status': status_val,
        'total_calls': summary.total_calls or 0,
        'total_tokens': summary.total_tokens or 0,
        'total_cost': round(total_cost, 2),
        'avg_response_time': summary.avg_response_time or 0.0,
        'error_rate': summary.error_rate or 0.0,
        'limits': limits,
        'provider_breakdown': provider_breakdown,
        'usage_percentages': usage_percentages,
        'historical_breakdown': [],
        'last_updated': datetime.now().isoformat()
    }

View File

@@ -0,0 +1,38 @@
"""
Usage limit enforcement functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Tuple, Dict, Any
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import APIProvider
from services.subscription.pricing_service import PricingService
def enforce_usage_limits(user_id: str, provider: APIProvider,
                         tokens_requested: int, db: Session,
                         pricing_service: PricingService) -> Tuple[bool, str, Dict[str, Any]]:
    """Enforce usage limits before making an API call.

    Delegates to ``pricing_service.check_usage_limits`` and returns its result
    as a tuple. This function keeps no state, so it does no caching; any
    short-lived caching of results has to be done by the caller.

    Args:
        user_id: User whose limits are checked.
        provider: Provider the call is about to be made to.
        tokens_requested: Number of tokens the call intends to use.
        db: Accepted for signature parity with the other module functions;
            not used here.
        pricing_service: Service performing the actual limit check.
    """
    result = pricing_service.check_usage_limits(
        user_id=user_id,
        provider=provider,
        tokens_requested=tokens_requested
    )
    return tuple(result)

View File

@@ -0,0 +1,29 @@
"""
Usage statistics functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
from models.subscription_models import UsageSummary, UsageStatus, APIProvider
from services.subscription.usage_tracking_modules.historical_usage import get_all_historical_usage, get_usage_for_period
def get_user_usage_stats(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
    """Return usage statistics for a user.

    With a ``billing_period`` the result covers only that period; without one
    it aggregates ALL historical usage.

    Raises:
        ValueError: if ``user_id`` is empty.
    """
    if user_id:
        if billing_period:
            # A specific period was requested.
            return get_usage_for_period(user_id, billing_period, db, pricing_service)
        # No period requested: aggregate everything on record.
        return get_all_historical_usage(user_id, db, pricing_service)

    logger.error("get_user_usage_stats called without user_id")
    raise ValueError("user_id is required")

View File

@@ -0,0 +1,18 @@
"""
Usage trends functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
def get_usage_trends(user_id: str, months: int, db: Session) -> Dict[str, Any]:
    """Build the usage-trends payload for the last ``months`` billing periods.

    Stored summaries are loaded for the periods, passed through the
    self-healing helper (which works from the usage logs), and then turned
    into the response.
    """
    # Imported here rather than at module import time.
    from services.subscription.usage_tracking_helpers import (
        build_billing_periods,
        query_usage_summaries,
        self_heal_summaries_from_logs,
        build_usage_trends_response,
    )

    billing_periods = build_billing_periods(months)
    summaries_by_period = query_usage_summaries(db, user_id, billing_periods)
    self_heal_summaries_from_logs(db, user_id, billing_periods, summaries_by_period)
    trends = build_usage_trends_response(billing_periods, summaries_by_period)
    return trends

View File

@@ -1,41 +1,60 @@
"""
Usage Tracking Service
Comprehensive tracking of API usage, costs, and subscription limits.
Usage Tracking Service - Refactored into modular components.
This file now serves as a facade that delegates to specialized modules
in the usage_tracking_modules package.
Modules:
- historical_usage: Functions for aggregating historical usage data
- usage_stats: Functions for getting user usage statistics
- usage_trends: Functions for usage trend analysis
- limit_enforcement: Functions for enforcing usage limits
- alerts: Functions for usage alerts
"""
# Ensure Optional is available in global scope for dynamic imports
from typing import Optional
import asyncio
from typing import Dict, Any, List, Tuple
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple, Optional
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import text
from loguru import logger
import json
from api.subscription.cache import clear_dashboard_cache
from datetime import datetime, timedelta
import time
from models.subscription_models import (
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
UserSubscription, UsageStatus
APIProvider, UsageStatus, UserSubscription,
UsageSummary, APIUsageLog, UsageAlert
)
from .pricing_service import PricingService
from .provider_detection import detect_actual_provider
from .usage_tracking_helpers import (
build_billing_periods,
build_default_usage_percentages,
build_empty_usage_response,
from services.subscription.pricing_service import PricingService
from services.subscription.provider_detection import detect_actual_provider
from services.subscription.usage_tracking_helpers import (
build_provider_breakdown,
build_usage_trends_response,
build_default_usage_percentages,
calculate_final_total_cost,
maybe_persist_reconciled_costs,
build_usage_trends_response,
build_billing_periods,
query_usage_summaries,
reset_usage_summary_counters,
self_heal_summaries_from_logs,
reset_usage_summary_counters,
)
# Import clear_dashboard_cache lazily to avoid circular import
def _clear_dashboard_cache_for_user(user_id: str):
    """Clear the dashboard cache for ``user_id`` and return the cache helper's result.

    The cache module is imported at call time, not at module import time, to
    avoid a circular import.
    """
    from api.subscription.cache import clear_dashboard_cache
    outcome = clear_dashboard_cache(user_id)
    return outcome
from .usage_tracking_modules import (
get_all_historical_usage,
get_current_period_usage,
get_usage_for_period,
get_user_usage_stats,
get_usage_trends,
enforce_usage_limits,
check_usage_alerts,
create_usage_alert,
)
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
"""Service for tracking API usage and managing billing information."""
def __init__(self, db: Session):
self.db = db
@@ -43,13 +62,14 @@ class UsageTrackingService:
# TTL cache (30s) for enforcement results to cut DB chatter
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
"""Return authoritative billing period lookup keys. Always uses calendar month for consistency."""
"""Return authoritative billing period lookup keys. Always uses subscription period for consistency.
Maintains backward compatibility with existing calendar-month data."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id
).first()
# If caller explicitly requested a billing period, use it
if billing_period:
return {
@@ -58,26 +78,125 @@ class UsageTrackingService:
"period_start": subscription.current_period_start if subscription else None,
"period_end": subscription.current_period_end if subscription else None,
}
# ALWAYS use current calendar month for billing period to ensure consistency
# This prevents data loss when subscription spans month boundaries
current_period = datetime.now().strftime("%Y-%m")
# Get subscription period if available
subscription_period = None
if subscription and subscription.current_period_start:
subscription_period = subscription.current_period_start.strftime("%Y-%m")
# Get calendar period
calendar_period = datetime.now().strftime("%Y-%m")
# Check which period has usage data
from models.subscription_models import UsageSummary
if subscription_period:
# Check if data exists for subscription period
sub_data = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == subscription_period
).first()
if sub_data:
# Use subscription period (has data)
return {
"billing_period": subscription_period,
"lookup_periods": [subscription_period],
"period_start": subscription.current_period_start,
"period_end": subscription.current_period_end,
}
# No data for subscription period, check calendar period (backward compatibility)
if calendar_period != subscription_period:
cal_data = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == calendar_period
).first()
if cal_data:
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
return {
"billing_period": calendar_period,
"lookup_periods": [calendar_period],
"period_start": None,
"period_end": None,
}
# No data in either period, use subscription period
return {
"billing_period": subscription_period,
"lookup_periods": [subscription_period],
"period_start": subscription.current_period_start,
"period_end": subscription.current_period_end,
}
# No subscription, check for any existing data
latest_summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id
).order_by(UsageSummary.billing_period.desc()).first()
if latest_summary:
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
return {
"billing_period": latest_summary.billing_period,
"lookup_periods": [latest_summary.billing_period],
"period_start": None,
"period_end": None,
}
# Last fallback to calendar month for free tier / no subscription
return {
"billing_period": current_period,
"lookup_periods": [current_period],
"period_start": subscription.current_period_start if subscription else None,
"period_end": subscription.current_period_end if subscription else None,
"billing_period": calendar_period,
"lookup_periods": [calendar_period],
"period_start": None,
"period_end": None,
}
# Delegate to modular functions
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
    """Return usage statistics for ``user_id`` by delegating to the module-level function of the same name."""
    session, pricing = self.db, self.pricing_service
    return get_user_usage_stats(user_id, billing_period, session, pricing)
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
    """Return usage aggregated over every billing period on record for ``user_id``."""
    session, pricing = self.db, self.pricing_service
    return get_all_historical_usage(user_id, session, pricing)
def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
    """Return usage for the user's current billing period, with percentages against per-period limits."""
    session, pricing = self.db, self.pricing_service
    return get_current_period_usage(user_id, session, pricing)
def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
    """Return usage for one explicitly named billing period."""
    session, pricing = self.db, self.pricing_service
    return get_usage_for_period(user_id, billing_period, session, pricing)
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
    """Return usage trends for the last ``months`` periods (self-healed from logs by the delegate)."""
    session = self.db
    return get_usage_trends(user_id, months, session)
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                               tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
    """Enforce usage limits before making an API call.

    Results are cached per ``user_id``/``provider`` for 30 seconds in
    ``self._enforce_cache`` to cut DB chatter. The module-level
    ``enforce_usage_limits`` is stateless, so the cache has to live here.
    ``tokens_requested`` is not part of the cache key.

    Returns:
        The ``(allowed, message, details)`` tuple from the limit check.
    """
    cache_key = f"{user_id}:{provider.value}"
    now = datetime.utcnow()

    cached = self._enforce_cache.get(cache_key)
    if cached and cached.get('expires_at') and cached['expires_at'] > now:
        return tuple(cached['result'])

    result = enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)
    self._enforce_cache[cache_key] = {
        'result': result,
        'expires_at': now + timedelta(seconds=30)
    }
    return result
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
    """Run the usage-alert check for one provider and billing period; the delegate's result is discarded."""
    session, pricing = self.db, self.pricing_service
    check_usage_alerts(user_id, provider, billing_period, session, pricing)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                              threshold: int, current_usage: int, limit: int,
                              billing_period: str):
    """Create a usage alert on this service's session via the module-level ``create_usage_alert``."""
    session = self.db
    create_usage_alert(
        user_id, provider, threshold, current_usage, limit, billing_period, session
    )
# Keep the track_api_usage method here as it's the core functionality
async def track_api_usage(self, user_id: str, provider: APIProvider,
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
"""Track an API usage event and update billing information."""
try:
@@ -165,394 +284,81 @@ class UsageTrackingService:
# Invalidate dashboard cache so header stats update immediately
try:
clear_dashboard_cache(user_id)
_clear_dashboard_cache_for_user(user_id)
except Exception as cache_err:
logger.debug(f"Could not clear dashboard cache: {cache_err}")
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
logger.warning(f"Failed to clear dashboard cache: {cache_err}")
return {
'usage_logged': True,
'cost': cost_data['cost_total'],
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
'billing_period': billing_period
"success": True,
"cost": cost_data['cost_total'],
"tokens": (tokens_input or 0) + (tokens_output or 0),
"billing_period": billing_period
}
except Exception as e:
logger.error(f"Error tracking API usage: {str(e)}")
logger.error(f"Failed to track API usage: {e}")
self.db.rollback()
return {
'usage_logged': False,
'error': str(e)
"success": False,
"error": str(e)
}
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
tokens_used: int, cost: float, billing_period: str,
response_time: float, is_error: bool):
"""Update the usage summary for a user."""
tokens_used: int, cost: float,
billing_period: str,
response_time: float = 0.0,
is_error: bool = False):
"""Update or create usage summary for the billing period."""
# Get or create usage summary
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
# Get or create summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
UsageSummary.billing_period == billing_period
).first()
if not summary:
logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}")
summary = UsageSummary(
user_id=user_id,
billing_period=period_keys["billing_period"]
billing_period=billing_period,
usage_status=UsageStatus.ACTIVE,
total_calls=0,
total_tokens=0,
total_cost=0.0
)
self.db.add(summary)
else:
logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}")
# Update provider-specific counters
# Update counts
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
summary.total_cost = (summary.total_cost or 0.0) + cost
# Update provider-specific counts
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
setattr(summary, f"{provider_name}_calls", current_calls + 1)
# Update token usage for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
# Update provider-specific tokens
tokens_attr = f"{provider_name}_tokens"
if hasattr(summary, tokens_attr):
current_tokens = getattr(summary, tokens_attr, 0) or 0
setattr(summary, tokens_attr, current_tokens + tokens_used)
# Update cost
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
setattr(summary, f"{provider_name}_cost", current_cost + cost)
# Update provider-specific cost
cost_attr = f"{provider_name}_cost"
if hasattr(summary, cost_attr):
current_cost = getattr(summary, cost_attr, 0.0) or 0.0
setattr(summary, cost_attr, current_cost + cost)
# Update totals
summary.total_calls += 1
summary.total_tokens += tokens_used
summary.total_cost += cost
# Update response time (rolling average)
if response_time > 0:
current_avg = summary.avg_response_time or 0.0
current_calls = summary.total_calls or 1
summary.avg_response_time = ((current_avg * (current_calls - 1)) + response_time) / current_calls
# Update performance metrics
if summary.total_calls > 0:
# Update average response time
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
summary.avg_response_time = total_response_time / summary.total_calls
# Update error rate
if is_error:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
summary.error_rate = (error_count / summary.total_calls) * 100
else:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
summary.error_rate = (error_count / summary.total_calls) * 100
# Update usage status based on limits
await self._update_usage_status(summary)
# Update error rate
if is_error:
summary.error_count = (summary.error_count or 0) + 1
total_calls = summary.total_calls or 1
summary.error_rate = (summary.error_count / total_calls) * 100
summary.updated_at = datetime.utcnow()
async def _update_usage_status(self, summary: UsageSummary):
    """Update ``summary.usage_status`` from the user's subscription limits.

    The status follows the highest usage percentage found across the monthly
    cost limit and every provider's call limit: 100% or more sets
    ``LIMIT_REACHED``, 80% or more sets ``WARNING``, anything lower sets
    ``ACTIVE``. Limits that are missing, ``None`` or not positive are ignored.
    Nothing is changed when the pricing service returns no limits.
    """
    limits = self.pricing_service.get_user_limits(summary.user_id)
    if not limits:
        return

    plan_limits = limits['limits']
    usage_percentages = []

    # Cost limit. Counters and limits may be None, so coerce before dividing.
    cost_limit = plan_limits.get('monthly_cost', 0) or 0
    if cost_limit > 0:
        usage_percentages.append(((summary.total_cost or 0) / cost_limit) * 100)

    # Call limits for each provider.
    for provider in APIProvider:
        provider_name = provider.value
        call_limit = plan_limits.get(f"{provider_name}_calls", 0) or 0
        if call_limit > 0:
            current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
            usage_percentages.append((current_calls / call_limit) * 100)

    max_usage_percentage = max(usage_percentages, default=0.0)

    if max_usage_percentage >= 100:
        summary.usage_status = UsageStatus.LIMIT_REACHED
    elif max_usage_percentage >= 80:
        summary.usage_status = UsageStatus.WARNING
    else:
        summary.usage_status = UsageStatus.ACTIVE
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
    """Create alerts for the 80%, 90% and 100% call-usage thresholds of one provider.

    Does nothing when the user has no usage summary for the period or no
    limits. A threshold is skipped when an alert for it has already been
    marked as sent, or when the provider has no positive call limit.
    """
    period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
    summary = (
        self.db.query(UsageSummary)
        .filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period.in_(period_keys["lookup_periods"])
        )
        .first()
    )
    if not summary:
        return

    limits = self.pricing_service.get_user_limits(user_id)
    if not limits:
        return

    # Alert thresholds, as percentages of the call limit.
    for threshold in (80, 90, 100):
        already_sent = self.db.query(UsageAlert).filter(
            UsageAlert.user_id == user_id,
            UsageAlert.billing_period == billing_period,
            UsageAlert.threshold_percentage == threshold,
            UsageAlert.provider == provider,
            UsageAlert.is_sent == True
        ).first()
        if already_sent:
            continue

        provider_name = provider.value
        current_calls = getattr(summary, f"{provider_name}_calls", 0)
        call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
        if not call_limit > 0:
            continue

        usage_percentage = (current_calls / call_limit) * 100
        if usage_percentage < threshold:
            continue

        await self._create_usage_alert(
            user_id=user_id,
            provider=provider,
            threshold=threshold,
            current_usage=current_calls,
            limit=call_limit,
            billing_period=billing_period
        )
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                              threshold: int, current_usage: int, limit: int,
                              billing_period: str):
    """Build a ``UsageAlert`` for a crossed threshold and add it to ``self.db``.

    Type, severity, title and message depend on ``threshold``: 100 and above
    is a "limit reached" error, 90-99 a warning, anything lower an
    informational notice. The alert is added to the session only; no commit
    happens here.
    """
    provider_slug = provider.value
    provider_label = provider_slug.title()

    def _usage_text() -> str:
        # Shared wording for the two "usage_warning" variants.
        return f"You have used {current_usage:,} of {limit:,} {provider_slug} API calls ({threshold}% of your limit)."

    if threshold < 90:
        alert_type, severity = "usage_warning", "info"
        title = f"API Usage Notice - {provider_label}"
        message = _usage_text()
    elif threshold < 100:
        alert_type, severity = "usage_warning", "warning"
        title = f"API Usage Warning - {provider_label}"
        message = _usage_text()
    else:
        alert_type, severity = "limit_reached", "error"
        title = f"API Limit Reached - {provider_label}"
        message = f"You have reached your {provider_slug} API limit of {limit:,} calls for this billing period."

    alert_fields = {
        "user_id": user_id,
        "alert_type": alert_type,
        "threshold_percentage": threshold,
        "provider": provider,
        "title": title,
        "message": message,
        "severity": severity,
        "billing_period": billing_period,
    }
    self.db.add(UsageAlert(**alert_fields))
    logger.info(f"Created usage alert for {user_id}: {title}")
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
    """Get comprehensive usage statistics for a user.

    Args:
        user_id: User to report on.
        billing_period: Optional explicit billing period. When omitted the
            authoritative current period is used and an empty summary is
            created for it if none exists yet.

    Returns:
        A dict with totals, limits, provider breakdown, unread alerts and
        usage percentages. A zeroed response is returned when no summary
        exists and none could be created.

    Raises:
        ValueError: if ``user_id`` is empty.
    """
    if not user_id:
        logger.error("get_user_usage_stats called without user_id")
        raise ValueError("user_id is required")

    requested_billing_period = billing_period
    period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
    billing_period = period_keys["billing_period"]
    logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}")

    # Get usage summary
    summary = self.db.query(UsageSummary).filter(
        UsageSummary.user_id == user_id,
        UsageSummary.billing_period.in_(period_keys["lookup_periods"])
    ).first()
    if summary:
        logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}")
    else:
        logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}")

    # Get user limits
    limits = self.pricing_service.get_user_limits(user_id)

    # Get recent unread alerts
    alerts = self.db.query(UsageAlert).filter(
        UsageAlert.user_id == user_id,
        UsageAlert.billing_period == billing_period,
        UsageAlert.is_read == False
    ).order_by(UsageAlert.created_at.desc()).limit(10).all()

    if not summary and not requested_billing_period:
        # No summary for the current period yet: initialize it. This handles
        # the "start of month" case where a user logs in before making calls.
        logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
        summary = UsageSummary(
            user_id=user_id,
            billing_period=billing_period,
            usage_status=UsageStatus.ACTIVE,
            total_calls=0,
            total_tokens=0,
            total_cost=0.0
        )
        try:
            self.db.add(summary)
            self.db.commit()
            self.db.refresh(summary)
        except Exception as e:
            logger.error(f"Failed to initialize summary: {e}")
            self.db.rollback()
            # Drop the unsaved object so the zero-struct fallback below is
            # taken; otherwise the rolled-back instance would be used as if
            # it had been persisted.
            summary = None

    if not summary:
        return build_empty_usage_response(
            billing_period=billing_period,
            limits=limits,
            providers=APIProvider,
        )

    # Provider breakdown - calculate costs first, then use them for percentages
    provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
        db=self.db,
        user_id=user_id,
        billing_period=billing_period,
        summary=summary,
    )
    summary_total_cost = summary.total_cost or 0.0
    calculated_total_cost, final_total_cost = calculate_final_total_cost(
        summary_total_cost=summary_total_cost,
        resolved_costs=resolved_costs,
    )
    maybe_persist_reconciled_costs(
        db=self.db,
        summary=summary,
        summary_total_cost=summary_total_cost,
        calculated_total_cost=calculated_total_cost,
        final_total_cost=final_total_cost,
        resolved_costs=resolved_costs,
    )

    # Usage percentages - only Gemini and HuggingFace call limits plus cost
    usage_percentages = build_default_usage_percentages(APIProvider)
    if limits:
        plan_limits = limits['limits']

        # Gemini
        gemini_call_limit = plan_limits.get("gemini_calls", 0) or 0
        if gemini_call_limit > 0:
            usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100

        # HuggingFace (stored as mistral in database)
        mistral_call_limit = plan_limits.get("mistral_calls", 0) or 0
        if mistral_call_limit > 0:
            usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100

        # Cost usage percentage - uses the reconciled final_total_cost
        cost_limit = plan_limits.get('monthly_cost', 0) or 0
        if cost_limit > 0:
            usage_percentages['cost'] = (final_total_cost / cost_limit) * 100

    return {
        'billing_period': billing_period,
        'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
        'total_calls': summary.total_calls or 0,
        'total_tokens': summary.total_tokens or 0,
        'total_cost': final_total_cost,
        'avg_response_time': summary.avg_response_time or 0.0,
        'error_rate': summary.error_rate or 0.0,
        'limits': limits,
        'provider_breakdown': provider_breakdown,
        'alerts': [
            {
                'id': alert.id,
                'type': alert.alert_type,
                'title': alert.title,
                'message': alert.message,
                'severity': alert.severity,
                # Timestamps may be unset; report None instead of crashing.
                'created_at': alert.created_at.isoformat() if alert.created_at else None
            }
            for alert in alerts
        ],
        'usage_percentages': usage_percentages,
        'last_updated': summary.updated_at.isoformat() if summary.updated_at else None
    }
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
    """Assemble the usage-trends payload for a user.

    The billing periods come from ``build_billing_periods(months)``; the
    stored summaries for those periods are loaded, handed to the log-based
    self-healing step, and then turned into the response.

    Args:
        user_id: User whose usage history is requested.
        months: Value forwarded to ``build_billing_periods`` (default 6).

    Returns:
        Whatever ``build_usage_trends_response`` produces for the periods
        and their summaries.
    """
    billing_periods = build_billing_periods(months)
    summaries_by_period = query_usage_summaries(self.db, user_id, billing_periods)
    # The helper's return value is unused, so it presumably patches
    # `summaries_by_period` in place - confirm against its definition.
    self_heal_summaries_from_logs(
        self.db, user_id, billing_periods, summaries_by_period
    )
    return build_usage_trends_response(billing_periods, summaries_by_period)
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                               tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
    """Check a user's usage limits before an API call, with a 30-second cache.

    Results are memoised in ``self._enforce_cache`` under the key
    ``"<user_id>:<provider.value>"``. ``tokens_requested`` is not part of
    the key, so a cached verdict is reused for any request size until the
    entry expires.

    Args:
        user_id: User about to make the call.
        provider: Provider being called; its ``.value`` forms the cache key.
        tokens_requested: Token count forwarded to the pricing service.

    Returns:
        The result of ``pricing_service.check_usage_limits``. A cache hit
        returns the stored result converted to a tuple.
    """
    entry_key = f"{user_id}:{provider.value}"
    current_time = datetime.utcnow()

    # Serve from the cache only while the stored entry is still fresh.
    entry = self._enforce_cache.get(entry_key)
    if entry:
        expiry = entry.get('expires_at')
        if expiry and expiry > current_time:
            return tuple(entry['result'])  # type: ignore

    outcome = self.pricing_service.check_usage_limits(
        user_id=user_id,
        provider=provider,
        tokens_requested=tokens_requested,
    )
    # Remember the fresh verdict for the next 30 seconds.
    self._enforce_cache[entry_key] = {
        'result': outcome,
        'expires_at': current_time + timedelta(seconds=30),
    }
    return outcome
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
    """Zero the usage counters for the user's current billing period.

    Intended for use after a plan renewal or change. The summary row is
    looked up using the period keys from
    ``_get_authoritative_billing_period_keys``.

    Args:
        user_id: User whose current-period usage should be reset.

    Returns:
        ``{"reset": False, "reason": "no_summary"}`` when no summary row
        matches, ``{"reset": True, "counters_reset": True}`` on success, or
        ``{"reset": False, "error": <message>}`` when the reset fails (the
        session is rolled back in that case).
    """
    keys = self._get_authoritative_billing_period_keys(user_id)
    billing_period = keys["billing_period"]

    matching_rows = self.db.query(UsageSummary).filter(
        UsageSummary.user_id == user_id,
        UsageSummary.billing_period.in_(keys["lookup_periods"]),
    )
    usage_row = matching_rows.first()
    if not usage_row:
        return {"reset": False, "reason": "no_summary"}

    try:
        reset_usage_summary_counters(usage_row)
        self.db.commit()

        # Best effort: a stale dashboard cache must not fail the reset,
        # so header stats are refreshed when possible and errors only logged.
        try:
            clear_dashboard_cache(user_id)
        except Exception as cache_err:
            logger.debug(f"Could not clear dashboard cache: {cache_err}")

        logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
        return {"reset": True, "counters_reset": True}
    except Exception as e:
        self.db.rollback()
        logger.error(f"Error resetting usage status: {e}")
        return {"reset": False, "error": str(e)}

View File

@@ -2,9 +2,8 @@ import os
import asyncio
from typing import Any, Dict, List
from dataclasses import dataclass
import requests
import httpx
from loguru import logger
import time
import random
from services.llm_providers.main_text_generation import llm_text_gen
@@ -61,30 +60,26 @@ class WritingAssistantService:
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
return True
async def suggest(self, text: str, max_results: int = 1) -> List[WritingSuggestion]:
async def suggest(self, text: str, user_id: str | None = None) -> List[WritingSuggestion]:
if not text or len(text.strip()) < 6:
return []
# COST OPTIMIZATION: Use cached/static suggestions for common patterns
# This reduces API calls by 90%+ while maintaining usefulness
cached_suggestion = self._get_cached_suggestion(text)
if cached_suggestion:
return [cached_suggestion]
# COST CONTROL: Check daily usage limits
if not self._check_daily_limit():
logger.warning("Daily API limit reached for writing assistant")
return []
# Only make expensive API calls for unique, substantial content
if len(text.strip()) < 50: # Skip API calls for very short text
if len(text.strip()) < 50:
return []
# 1) Find relevant sources via Exa (reduced results for cost)
# 1) Find relevant sources via Exa
sources = await self._search_sources(text)
# 2) Generate continuation suggestion via Gemini
suggestion_text, confidence = await self._generate_continuation(text, sources)
# 2) Generate continuation suggestion via LLM grounded in sources
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
if not suggestion_text:
return []
@@ -110,12 +105,12 @@ class WritingAssistantService:
}
try:
resp = requests.post(
"https://api.exa.ai/search",
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
json=payload,
timeout=self.http_timeout_seconds,
)
async with httpx.AsyncClient(timeout=self.http_timeout_seconds) as client:
resp = await client.post(
"https://api.exa.ai/search",
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
json=payload,
)
if resp.status_code != 200:
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
data = resp.json()
@@ -140,8 +135,7 @@ class WritingAssistantService:
logger.error(f"WritingAssistant _search_sources error: {e}")
raise
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
# Build compact sources context block
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]], user_id: str | None = None) -> tuple[str, float]:
source_blocks: List[str] = []
for i, s in enumerate(sources[:5]):
excerpt = (s.get("text", "") or "")
@@ -149,16 +143,14 @@ class WritingAssistantService:
source_blocks.append(
f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
)
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
sources_text = "\n\n".join(source_blocks)
# Provider-agnostic behavior: short continuation with one inline citation hint
system_prompt = (
"You are an assistive writing continuation bot. "
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
"Match tone and topic. Prefer concrete, current facts from the provided sources. "
"Include exactly one brief citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
)
user_prompt = (
f"User text to continue (do not repeat):\n{text}\n\n"
f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
@@ -166,13 +158,13 @@ class WritingAssistantService:
)
try:
# Inter-call jitter to reduce burst rate limits
time.sleep(random.uniform(0.05, 0.15))
await asyncio.sleep(random.uniform(0.05, 0.15))
ai_resp = llm_text_gen(
prompt=user_prompt,
json_struct=None,
system_prompt=system_prompt,
user_id=user_id,
)
if isinstance(ai_resp, dict) and ai_resp.get("text"):
suggestion = (ai_resp.get("text", "") or "").strip()
@@ -180,12 +172,10 @@ class WritingAssistantService:
suggestion = (str(ai_resp or "")).strip()
if not suggestion:
raise Exception("Assistive writer returned empty suggestion")
# naive confidence from number of sources present
confidence = 0.7 if sources else 0.5
confidence = 0.7
return suggestion, confidence
except Exception as e:
logger.error(f"WritingAssistant _generate_continuation error: {e}")
# Propagate to ensure frontend does not show stale/generic content
raise