fix: WYSIWYG editor, content generation, and writing assistant bug fixes
- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField - Fix blog title not truncating: add min-w-0 for flex item overflow - Fix outline generation 500: escape curly braces in f-string prompt template - Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager - Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient - Fix hallucination detector 404: explicitly include router in main.py and app.py - Fix missing error_data in task failure responses - Hide CopilotKit web inspector button - Remove hardcoded fallback suggestions from SmartTypingAssist - Fix stale closure refs in SmartTypingAssist handleTypingChange - Add two-column editor layout, stats bar, section hover menu - Various subscription, billing, and research module improvements
This commit is contained in:
157
backend/add_method.py
Normal file
157
backend/add_method.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
# Add _get_all_historical_usage method to usage_tracking_service.py
|
||||
|
||||
with open('services/subscription/usage_tracking_service.py', 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find where to insert (before get_usage_trends)
|
||||
insert_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
if ' def get_usage_trends(' in line:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
if insert_idx is None:
|
||||
print("Error: Could not find insertion point")
|
||||
exit(1)
|
||||
|
||||
print(f"Inserting at line {insert_idx + 1}")
|
||||
|
||||
# Method to insert
|
||||
new_method = ''' def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': self.pricing_service.get_user_limits(user_id),
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data from UsageSummary
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Map database columns to frontend keys
|
||||
provider_mapping = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface',
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
# Build provider_breakdown for frontend
|
||||
provider_breakdown = {}
|
||||
for db_col, frontend_key in provider_mapping.items():
|
||||
total_provider_calls = sum(getattr(s, db_col, 0) or 0 for s in all_summaries)
|
||||
provider_breakdown[frontend_key] = {
|
||||
'calls': total_provider_calls,
|
||||
'cost': 0,
|
||||
'tokens': 0
|
||||
}
|
||||
|
||||
# Calculate usage_percentages based on limits
|
||||
usage_percentages = {}
|
||||
if limits and limits.get('limits'):
|
||||
# Gemini calls percentage
|
||||
gemini_calls = provider_breakdown.get('gemini', {}).get('calls', 0)
|
||||
gemini_limit = limits.get('limits', {}).get('gemini_calls', 0) or 0
|
||||
if gemini_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (gemini_calls / gemini_limit) * 100
|
||||
|
||||
# HuggingFace calls percentage (from mistral_calls)
|
||||
huggingface_calls = provider_breakdown.get('huggingface', {}).get('calls', 0)
|
||||
huggingface_limit = limits.get('limits', {}).get('mistral_calls', 0) or 0
|
||||
if huggingface_limit > 0:
|
||||
usage_percentages['huggingface_calls'] = (huggingface_calls / huggingface_limit) * 100
|
||||
|
||||
# Cost percentage
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = 'active'
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except:
|
||||
status = str(s.usage_status)
|
||||
if status == 'limit_reached':
|
||||
usage_status = 'limit_reached'
|
||||
break
|
||||
elif status == 'warning' and usage_status != 'limit_reached':
|
||||
usage_status = 'warning'
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': usage_status,
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
# Insert the new method
|
||||
new_lines = lines[:insert_idx] + [new_method] + lines[insert_idx:]
|
||||
|
||||
# Write back
|
||||
with open('services/subscription/usage_tracking_service.py', 'w', encoding='utf-8') as f:
|
||||
f.writelines(new_lines)
|
||||
|
||||
print("Successfully added _get_all_historical_usage method")
|
||||
@@ -5,8 +5,8 @@ Modular utilities for ALwrity backend startup and configuration.
|
||||
|
||||
import os
|
||||
|
||||
# Check podcast mode early to skip heavy imports
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
# Check feature mode early to skip heavy imports
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
|
||||
from .dependency_manager import DependencyManager
|
||||
from .environment_setup import EnvironmentSetup
|
||||
@@ -26,41 +26,25 @@ from .feature_runtime import (
|
||||
)
|
||||
|
||||
# Lazy load OnboardingManager - it triggers heavy imports (aiohttp, etc.)
|
||||
if not _is_podcast:
|
||||
if _is_full_mode:
|
||||
from .onboarding_manager import OnboardingManager
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
else:
|
||||
OnboardingManager = None
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
|
||||
@@ -51,6 +51,13 @@ FEATURE_GROUPS: Dict[str, FeatureGroup] = {
|
||||
"api.content_planning.strategy_copilot:router",
|
||||
),
|
||||
),
|
||||
"blog_writer": FeatureGroup(
|
||||
features=("blog_writer",),
|
||||
routers=(
|
||||
"api.blog_writer.router:router",
|
||||
"api.blog_writer.seo_analysis:router",
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -59,5 +66,6 @@ PROFILE_GROUP_MAP: Dict[str, Tuple[str, ...]] = {
|
||||
"core": ("core",),
|
||||
"podcast": ("core", "podcast"),
|
||||
"youtube": ("core", "youtube"),
|
||||
"blog_writer": ("core", "blog_writer"),
|
||||
"planning": ("core", "content_planning"),
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
|
||||
CORE_ROUTER_REGISTRY = [
|
||||
{"name": "component_logic", "module": "api.component_logic", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog-writer", "youtube"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog_writer", "youtube"}},
|
||||
{"name": "step3_research", "module": "api.onboarding_utils.step3_routes", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "step4_assets", "module": "api.onboarding_utils.step4_asset_routes", "attr": "router", "features": {"all", "core", "podcast"}},
|
||||
{"name": "step4_persona", "module": "api.onboarding_utils.step4_persona_routes_optimized", "attr": "router", "features": {"all", "core"}},
|
||||
@@ -29,31 +29,31 @@ CORE_ROUTER_REGISTRY = [
|
||||
{"name": "linkedin_image", "module": "api.linkedin_image_generation", "attr": "router", "features": {"all", "core", "linkedin"}},
|
||||
{"name": "brainstorm", "module": "api.brainstorm", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "hallucination_detector", "module": "api.hallucination_detector", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "platform_analytics", "module": "routers.platform_analytics", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "bing_insights", "module": "routers.bing_insights", "attr": "router", "features": {"all", "core", "seo"}},
|
||||
{"name": "background_jobs", "module": "routers.background_jobs", "attr": "router", "features": {"all", "core"}},
|
||||
]
|
||||
|
||||
OPTIONAL_ROUTER_REGISTRY = [
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story-writer"}},
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story_writer"}},
|
||||
{"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "persona", "module": "api.persona_routes", "attr": "router", "features": {"all", "persona"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video-studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product-marketing"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video_studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product_marketing"}},
|
||||
{"name": "campaign_creator", "module": "routers.campaign_creator", "attr": "router", "features": {"all"}},
|
||||
{"name": "content_assets", "module": "api.content_assets.router", "attr": "router", "features": {"all"}},
|
||||
{"name": "podcast", "module": "api.podcast.router", "attr": "router", "features": {"all", "podcast"}},
|
||||
|
||||
@@ -7,12 +7,11 @@ The onboarding endpoints are re-exported from a stable module
|
||||
|
||||
import os
|
||||
|
||||
# Check podcast mode early
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
# In podcast mode, don't import heavy onboarding endpoints
|
||||
# In feature-only modes, don't import heavy onboarding endpoints
|
||||
# They trigger heavy dependencies (exa_py, etc.)
|
||||
if _is_podcast:
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
|
||||
if not _is_full_mode:
|
||||
__all__ = []
|
||||
else:
|
||||
from .onboarding_endpoints import (
|
||||
|
||||
@@ -1195,3 +1195,68 @@ async def generate_introductions(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate introductions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Save Complete Blog Asset
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class SaveCompleteBlogAssetRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
seo_title: Optional[str] = None
|
||||
meta_description: Optional[str] = None
|
||||
focus_keyword: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
categories: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post("/save-complete-asset")
|
||||
async def save_complete_blog_asset(
|
||||
request: SaveCompleteBlogAssetRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
"""Save the complete blog content as a single asset in the asset library."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
full_content = f"# {request.title}\n\n{request.content}"
|
||||
|
||||
asset_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=full_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Published Blog: {request.title[:60]}",
|
||||
description=request.meta_description or f"Complete published blog post: {request.title}",
|
||||
prompt=f"SEO Title: {request.seo_title or request.title}\nFocus Keyword: {request.focus_keyword or ''}",
|
||||
tags=["blog", "published"] + [t for t in (request.tags or []) if t],
|
||||
asset_metadata={
|
||||
"status": "published",
|
||||
"focus_keyword": request.focus_keyword,
|
||||
"categories": request.categories,
|
||||
"word_count": len(full_content.split()),
|
||||
},
|
||||
subdirectory="published",
|
||||
file_extension=".md"
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"✅ Complete blog asset saved to library: ID={asset_id}")
|
||||
return {"success": True, "asset_id": asset_id}
|
||||
else:
|
||||
logger.warning("save_and_track_text_content returned None for published blog")
|
||||
return {"success": False, "error": "Failed to save blog asset"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save complete blog asset: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from services.database import SessionLocal, get_session_for_user
|
||||
from services.database import get_session_for_user
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -264,7 +264,7 @@ class TaskManager:
|
||||
raise ValueError("Global target words exceed 1000; medium generation not allowed")
|
||||
|
||||
# Create a sync session for asset saving
|
||||
db_session = SessionLocal()
|
||||
db_session = get_session_for_user(user_id)
|
||||
try:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
@@ -326,6 +326,7 @@ class TaskManager:
|
||||
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
|
||||
self.task_storage[task_id]["status"] = "failed"
|
||||
self.task_storage[task_id]["error"] = str(e)
|
||||
self.task_storage[task_id]["error_data"] = {"error_message": str(e), "error_type": type(e).__name__}
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
|
||||
@@ -202,6 +202,26 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
# Preflight subscription check for Exa
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[Podcast Research] Preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast Research] Preflight check failed: {e}")
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
logger.warning(f"[Podcast Research] Calling Exa search with topic: {request.topic[:100]}...")
|
||||
|
||||
@@ -9,10 +9,13 @@ from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from types import SimpleNamespace
|
||||
from sqlalchemy import text
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.research.tavily_service import TavilyService
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
|
||||
|
||||
@@ -29,6 +32,75 @@ EXA_CATEGORY_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _preflight_check(user_id: str, provider: APIProvider, provider_name: str):
|
||||
"""Check subscription limits before making a research API call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=0,
|
||||
actual_provider_name=provider_name,
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': provider_name, 'usage_info': usage_info or {}
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[CategoryResearch] Preflight check failed for {provider_name}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_research_usage(user_id: str, provider_name: str, cost: float, calls_column: str, cost_column: str):
|
||||
"""Track research API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[CategoryResearch] Could not get DB session for user {user_id}")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
update_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_column} = COALESCE({calls_column}, 0) + 1,
|
||||
{cost_column} = COALESCE({cost_column}, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[CategoryResearch] Tracked {provider_name} usage: user={user_id}, cost=${cost}")
|
||||
|
||||
# Clear dashboard cache so header stats update immediately
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to clear dashboard cache: {cache_err}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Failed to track {provider_name} usage: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class CategoryResearchRequest(BaseModel):
|
||||
category: str
|
||||
keyword: Optional[str] = None
|
||||
@@ -80,9 +152,12 @@ def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopi
|
||||
return topics
|
||||
|
||||
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int) -> CategoryResearchResponse:
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int, user_id: str) -> CategoryResearchResponse:
|
||||
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
|
||||
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.TAVILY, "tavily")
|
||||
|
||||
try:
|
||||
tavily = TavilyService()
|
||||
result = await tavily.search(
|
||||
@@ -102,6 +177,10 @@ async def _search_tavily(category: str, keyword: str, max_results: int) -> Categ
|
||||
topics = _normalize_tavily_results(result.get("results", []))
|
||||
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.001 # basic search = 1 credit
|
||||
_track_research_usage(user_id, "tavily", cost, "tavily_calls", "tavily_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
@@ -117,7 +196,7 @@ async def _search_tavily(category: str, keyword: str, max_results: int) -> Categ
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, user_id: str, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
exa_category = EXA_CATEGORY_MAP.get(category, category)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
|
||||
@@ -133,6 +212,9 @@ async def _search_exa(category: str, keyword: str, max_results: int, website_url
|
||||
from exa_py import Exa
|
||||
exa = Exa(exa_api_key)
|
||||
logger.info(f"[CategoryResearch] Exa client initialized")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.EXA, "exa")
|
||||
|
||||
# Build search parameters
|
||||
search_params = {
|
||||
@@ -189,6 +271,10 @@ async def _search_exa(category: str, keyword: str, max_results: int, website_url
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.005 # Default Exa cost for 1-25 results
|
||||
_track_research_usage(user_id, "exa", cost, "exa_calls", "exa_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
@@ -218,6 +304,7 @@ async def research_by_category(
|
||||
- news, finance: Uses Tavily
|
||||
- research-paper, personal-site: Uses Exa
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
category = request.category.lower()
|
||||
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
|
||||
|
||||
@@ -241,9 +328,9 @@ async def research_by_category(
|
||||
|
||||
try:
|
||||
if provider == "tavily":
|
||||
return await _search_tavily(category, keyword, max_results)
|
||||
return await _search_tavily(category, keyword, max_results, user_id)
|
||||
elif provider == "exa":
|
||||
return await _search_exa(category, keyword, max_results, website_url)
|
||||
return await _search_exa(category, keyword, max_results, user_id, website_url)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unknown provider")
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,7 @@ Podcast Trends Handler
|
||||
Endpoints for fetching Google Trends data relevant to podcast topics.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -13,6 +14,25 @@ from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/trends", tags=["Podcast Trends"])
|
||||
|
||||
# Module-level shared instance (singleton pattern)
|
||||
_trends_service_instance = None
|
||||
_trends_service_lock = None
|
||||
|
||||
|
||||
def get_trends_service():
|
||||
"""Get or create shared GoogleTrendsService instance."""
|
||||
global _trends_service_instance, _trends_service_lock
|
||||
if _trends_service_instance is None:
|
||||
try:
|
||||
from services.research.trends import GoogleTrendsService
|
||||
_trends_service_instance = GoogleTrendsService()
|
||||
_trends_service_lock = asyncio.Lock()
|
||||
logger.info("[Podcast Trends] Created shared GoogleTrendsService instance")
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] Failed to create GoogleTrendsService: {e}")
|
||||
raise
|
||||
return _trends_service_instance
|
||||
|
||||
|
||||
class PodcastTrendsRequest(BaseModel):
|
||||
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
|
||||
@@ -38,7 +58,7 @@ async def get_podcast_trends(
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
try:
|
||||
from services.research.trends import GoogleTrendsService
|
||||
service = get_trends_service()
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] GoogleTrendsService unavailable: {e}")
|
||||
raise HTTPException(
|
||||
@@ -47,11 +67,10 @@ async def get_podcast_trends(
|
||||
)
|
||||
|
||||
try:
|
||||
service = GoogleTrendsService()
|
||||
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
|
||||
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
|
||||
gprop = gprop_map.get(request.source, "")
|
||||
|
||||
|
||||
result = await service.analyze_trends(
|
||||
keywords=request.keywords,
|
||||
timeframe=request.timeframe,
|
||||
@@ -73,7 +92,15 @@ async def get_podcast_trends(
|
||||
# Return error if: has error OR no data (meaning blocked/empty)
|
||||
if has_error and not has_data:
|
||||
error_msg = result.get("error", "")
|
||||
cooldown_active = result.get("cooldown_active", False)
|
||||
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
|
||||
# Provide helpful message during cooldown
|
||||
if cooldown_active:
|
||||
return PodcastTrendsResponse(
|
||||
success=False,
|
||||
data=result,
|
||||
error="Google is rate limiting requests. Try using 'Get Trending Topics' instead, or wait 30 minutes."
|
||||
)
|
||||
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
|
||||
|
||||
# Even if no error but empty data - return error
|
||||
|
||||
@@ -12,7 +12,7 @@ import sqlite3
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from models.subscription_models import UsageAlert
|
||||
from models.subscription_models import UsageAlert, UserSubscription
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from ..dependencies import verify_user_access
|
||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||
@@ -27,7 +27,9 @@ async def get_dashboard_data(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive dashboard data for usage monitoring."""
|
||||
"""Get comprehensive dashboard data for usage monitoring.
|
||||
Returns all-time total + current period usage by default.
|
||||
When billing_period is specified, returns that period's data only."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
@@ -35,17 +37,23 @@ async def get_dashboard_data(
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
|
||||
# Check cache first (skip if billing_period is specified)
|
||||
if not billing_period:
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
# Check cache first (only for default view, skip when a specific period is requested)
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data and not billing_period:
|
||||
return cached_data
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Get current usage stats (for the requested period)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||
# When a specific billing_period is requested, show only that period's data
|
||||
# Otherwise show all-time total + current period usage
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
|
||||
# Get usage trends (last 6 months)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
@@ -76,13 +84,44 @@ async def get_dashboard_data(
|
||||
]
|
||||
|
||||
# Calculate cost projections (only relevant for current month)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
|
||||
# Only project costs if viewing current month
|
||||
is_current_month = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
if is_current_month:
|
||||
# Determine if viewing current period based on subscription, not calendar
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Use subscription's billing period or fallback to calendar
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check if we have data for subscription period or calendar period
|
||||
from models.subscription_models import UsageSummary
|
||||
sub_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
# Determine which period to use for "current"
|
||||
if sub_data_exists:
|
||||
effective_period = sub_period
|
||||
else:
|
||||
# Check calendar period for backward compatibility
|
||||
cal_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
effective_period = calendar_period if cal_data_exists else sub_period
|
||||
|
||||
is_current_period = not billing_period or billing_period == effective_period
|
||||
else:
|
||||
is_current_period = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
|
||||
if is_current_period:
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
else:
|
||||
projected_cost = current_cost # For past months, projected is actual
|
||||
@@ -90,7 +129,8 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -100,9 +140,9 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
@@ -131,7 +171,13 @@ async def get_dashboard_data(
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
@@ -152,7 +198,7 @@ async def get_dashboard_data(
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
@@ -160,7 +206,8 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -170,16 +217,17 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response after successful retry
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
# Cache the response after successful retry (only for default view)
|
||||
if not billing_period:
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
@@ -187,7 +235,8 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(retry_err),
|
||||
"data": {
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
@@ -201,7 +250,8 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"data": {
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
|
||||
@@ -14,13 +14,21 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""
|
||||
Format subscription plan limits for API response.
|
||||
|
||||
Includes _zero_means metadata per field to disambiguate:
|
||||
- 'disabled': 0 means the feature is not available (Free tier)
|
||||
- 'unlimited': 0 means unlimited usage (Enterprise tier)
|
||||
- 'limited': >0 means numerical limit applies
|
||||
|
||||
Args:
|
||||
plan: SubscriptionPlan model instance
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted limits
|
||||
Dictionary with formatted limits and _zero_means metadata
|
||||
"""
|
||||
return {
|
||||
tier = plan.tier.value if hasattr(plan.tier, 'value') else str(plan.tier)
|
||||
is_enterprise = tier == 'enterprise'
|
||||
|
||||
limit_fields = {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
@@ -35,11 +43,43 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
|
||||
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
|
||||
"wavespeed_calls": getattr(plan, 'wavespeed_calls_limit', 0) or 0,
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
"mistral_tokens": plan.mistral_tokens_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
"monthly_cost": plan.monthly_cost_limit,
|
||||
}
|
||||
|
||||
# Build _zero_means metadata: indicates whether 0 means 'disabled' or 'unlimited'
|
||||
zero_means = {}
|
||||
for field, value in limit_fields.items():
|
||||
if field == "monthly_cost":
|
||||
zero_means[field] = "disabled"
|
||||
elif is_enterprise:
|
||||
# Enterprise: 0 means unlimited for all call/token fields
|
||||
zero_means[field] = "unlimited"
|
||||
else:
|
||||
# Free/Basic/Pro: determine per-field
|
||||
# Fields that are 0=disabled on Free tier but 0=unlimited on Basic/Pro
|
||||
call_and_token_fields = {
|
||||
"gemini_calls", "openai_calls", "anthropic_calls", "mistral_calls",
|
||||
"tavily_calls", "serper_calls", "metaphor_calls", "firecrawl_calls",
|
||||
"stability_calls", "video_calls", "image_edit_calls", "audio_calls",
|
||||
"exa_calls", "wavespeed_calls", "ai_text_generation_calls",
|
||||
"gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens",
|
||||
}
|
||||
if field in call_and_token_fields:
|
||||
if value == 0:
|
||||
zero_means[field] = "disabled" if tier == "free" else "unlimited"
|
||||
else:
|
||||
zero_means[field] = "limited"
|
||||
else:
|
||||
zero_means[field] = "limited" if value > 0 else "disabled"
|
||||
|
||||
return {
|
||||
**limit_fields,
|
||||
"_zero_means": zero_means,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
@@ -11,7 +12,6 @@ router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
class SuggestRequest(BaseModel):
|
||||
text: str
|
||||
max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
@@ -38,9 +38,10 @@ assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
|
||||
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
|
||||
async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> SuggestResponse:
|
||||
try:
|
||||
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
|
||||
user_id = current_user.get("id")
|
||||
suggestions = await assistant_service.suggest(req.text, user_id=user_id)
|
||||
return SuggestResponse(
|
||||
success=True,
|
||||
suggestions=[
|
||||
|
||||
289
backend/app.py
289
backend/app.py
@@ -27,11 +27,11 @@ load_dotenv(backend_dir / '.env', override=False)
|
||||
load_dotenv(project_root / '.env', override=False)
|
||||
load_dotenv(override=False)
|
||||
|
||||
# Set LOG_LEVEL early to WARNING to suppress DEBUG persona logs in podcast mode
|
||||
# Set LOG_LEVEL early to WARNING in feature-only modes to suppress DEBUG persona logs
|
||||
import os
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast":
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
os.environ["LOG_LEVEL"] = "WARNING"
|
||||
|
||||
|
||||
print(f"[app.py] Starting... ALWRITY_ENABLED_FEATURES={os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
|
||||
@@ -43,22 +43,21 @@ def get_enabled_features() -> set:
|
||||
return {f.strip() for f in env_value.split(",") if f.strip()}
|
||||
|
||||
|
||||
def _is_full_mode() -> bool:
|
||||
"""Check if running in full mode (all features enabled)."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled
|
||||
|
||||
|
||||
def _is_feature_enabled(feature: str) -> bool:
|
||||
"""Check if a specific feature is enabled (including in 'all' mode)."""
|
||||
enabled = get_enabled_features()
|
||||
return feature in enabled or "all" in enabled
|
||||
|
||||
|
||||
# Print env var IMMEDIATELY at module start
|
||||
print(f"[app.py] ALWRITY_ENABLED_FEATURES at start: {os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
def is_podcast_only_demo_mode() -> bool:
|
||||
"""Check if podcast-only mode is enabled."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "all")
|
||||
enabled = get_enabled_features()
|
||||
result = "podcast" in enabled and "all" not in enabled
|
||||
# Removed debug print - too verbose during startup
|
||||
return result
|
||||
|
||||
|
||||
# Podcast-only check BEFORE heavy imports
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Import onboarding models (after env is loaded, before heavy imports)
|
||||
from models.onboarding import APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData, CompetitorAnalysis
|
||||
@@ -90,28 +89,18 @@ _log_memory_usage()
|
||||
logger.info("app.py: Early memory checkpoint after env load")
|
||||
|
||||
|
||||
# Import modular utilities (skip OnboardingManager import in podcast-only mode)
|
||||
# Import modular utilities (skip OnboardingManager import in feature-only modes)
|
||||
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager
|
||||
if not is_podcast_only_demo_mode():
|
||||
if _is_full_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
|
||||
# Skip monitoring middleware in podcast-only mode to save memory
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Skip monitoring middleware in feature-only modes to save memory
|
||||
if _is_full_mode():
|
||||
from services.subscription import monitoring_middleware
|
||||
else:
|
||||
monitoring_middleware = None
|
||||
|
||||
|
||||
def should_include_non_podcast_features() -> bool:
|
||||
"""Check if non-podcast features should be included."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled or "core" in enabled
|
||||
|
||||
|
||||
# Legacy constant for backwards compatibility
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Set up clean logging for end users
|
||||
from logging_config import setup_clean_logging
|
||||
setup_clean_logging()
|
||||
@@ -119,27 +108,27 @@ setup_clean_logging()
|
||||
# Import middleware
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
# Import component logic endpoints (skip in podcast-only mode - uses seo_analyzer)
|
||||
# Import component logic endpoints (skip in feature-only modes - uses seo_analyzer)
|
||||
component_logic_router = None
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from api.component_logic import router as component_logic_router
|
||||
|
||||
# Import subscription API endpoints
|
||||
from api.subscription import router as subscription_router
|
||||
|
||||
# Import Step 3 onboarding routes (skip in podcast-only mode)
|
||||
# Import Step 3 onboarding routes (skip in feature-only modes)
|
||||
step3_routes = None
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from api.onboarding_utils.step3_routes import router as step3_routes
|
||||
|
||||
# Import SEO tools router (skip in podcast-only mode - uses seo_analyzer)
|
||||
# Import SEO tools router (skip in feature-only modes - uses seo_analyzer)
|
||||
seo_tools_router = None
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from routers.seo_tools import router as seo_tools_router
|
||||
|
||||
# Skip Facebook Writer, LinkedIn, and other non-podcast routes in podcast-only mode
|
||||
# Skip Facebook Writer, LinkedIn, and other non-essential routes in feature-only modes
|
||||
# Also skip other heavy services that trigger PersonaAnalysisService initialization
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from api.facebook_writer.routers import facebook_router
|
||||
from routers.linkedin import router as linkedin_router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
@@ -150,7 +139,7 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
from routers.campaign_creator import router as campaign_creator_router
|
||||
else:
|
||||
# In podcast-only mode, only load essential podcast assets router
|
||||
# In feature-only modes, only load essential assets router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
brainstorm_router = None
|
||||
images_router = None
|
||||
@@ -158,31 +147,31 @@ else:
|
||||
product_marketing_router = None
|
||||
campaign_creator_router = None
|
||||
|
||||
# Import hallucination detector router (skip in podcast-only mode - triggers heavy ML)
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Import hallucination detector router (skip in feature-only modes - triggers heavy ML)
|
||||
if _is_full_mode():
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
else:
|
||||
hallucination_detector_router = None
|
||||
writing_assistant_router = None
|
||||
|
||||
# Import research configuration router (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Import research configuration router (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.research_config import router as research_config_router
|
||||
else:
|
||||
research_config_router = None
|
||||
|
||||
# Import user data endpoints
|
||||
# Import content planning endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Import content planning endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
from api.content_planning.strategy_copilot import router as strategy_copilot_router
|
||||
else:
|
||||
content_planning_router = None
|
||||
strategy_copilot_router = None
|
||||
|
||||
# Import user data endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Import user data endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
from api.user_data import router as user_data_router
|
||||
else:
|
||||
user_data_router = None
|
||||
@@ -197,14 +186,14 @@ from services.startup_health import (
|
||||
|
||||
# Trigger reload for monitoring fix
|
||||
|
||||
# Import OAuth token monitoring routes (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Import OAuth token monitoring routes (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.oauth_token_monitoring_routes import router as oauth_token_monitoring_router
|
||||
else:
|
||||
oauth_token_monitoring_router = None
|
||||
|
||||
# Import SEO Dashboard endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Import SEO Dashboard endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
from api.seo_dashboard import (
|
||||
get_seo_dashboard_data,
|
||||
get_seo_health_score,
|
||||
@@ -318,8 +307,8 @@ router_manager = RouterManager(app)
|
||||
router_group_status: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
onboarding_manager = None
|
||||
# Only create OnboardingManager if NOT in podcast-only mode
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Only create OnboardingManager in full mode
|
||||
if _is_full_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
onboarding_manager = OnboardingManager(app)
|
||||
|
||||
@@ -346,7 +335,8 @@ app.middleware("http")(api_key_injection_middleware)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
health_data = health_checker.basic_health_check()
|
||||
health_data["podcast_only_demo_mode"] = PODCAST_ONLY_DEMO_MODE
|
||||
health_data["feature_mode"] = "single" if not _is_full_mode() else "full"
|
||||
health_data["enabled_features"] = list(get_enabled_features())
|
||||
return health_data
|
||||
|
||||
@app.get("/health/database")
|
||||
@@ -363,7 +353,8 @@ async def comprehensive_health():
|
||||
async def readiness(current_user: dict = Depends(get_current_user)):
|
||||
"""Readiness check that validates tenant DB resolution/session under auth context."""
|
||||
return {
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"startup": get_startup_status(),
|
||||
"tenant": readiness_under_auth_context(current_user),
|
||||
}
|
||||
@@ -395,7 +386,8 @@ async def router_status():
|
||||
status = router_manager.get_router_status()
|
||||
status.update(
|
||||
{
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"router_groups": router_group_status,
|
||||
}
|
||||
)
|
||||
@@ -410,53 +402,19 @@ async def feature_profile_status():
|
||||
@app.get("/api/onboarding/status")
|
||||
async def onboarding_status():
|
||||
"""Get onboarding manager status (or demo-mode disabled state)."""
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
if not _is_full_mode():
|
||||
return {
|
||||
"enabled": False,
|
||||
"status": "disabled",
|
||||
"message": "Onboarding is disabled for podcast-only demo mode.",
|
||||
"demo_mode": "podcast_only",
|
||||
"message": f"Onboarding is disabled in feature-only mode. Enabled features: {list(get_enabled_features())}",
|
||||
"feature_mode": "single",
|
||||
}
|
||||
return onboarding_manager.get_onboarding_status()
|
||||
|
||||
# Include routers using modular utilities
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In podcast-only mode, include only podcast-enabled routers from core registry
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
podcast_routers = [r for r in CORE_ROUTER_REGISTRY if "podcast" in r.get("features", set())]
|
||||
logger.info(f"[PODCAST-ONLY] Found {len(podcast_routers)} podcast routers: {[r['name'] for r in podcast_routers]}")
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in CORE_ROUTER_REGISTRY if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[PODCAST-ONLY] Attempting to load step4_assets for voice cloning")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[PODCAST-ONLY] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[PODCAST-ONLY] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other podcast routers
|
||||
for entry in podcast_routers:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
try:
|
||||
logger.info(f"[PODCAST-ONLY] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[PODCAST-ONLY] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast routers only in podcast-only mode",
|
||||
}
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
else:
|
||||
enabled_features = get_enabled_features()
|
||||
if "all" in enabled_features:
|
||||
# Full mode: load all core and optional routers
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": router_manager.include_core_routers(),
|
||||
"reason": "Full mode",
|
||||
@@ -465,6 +423,72 @@ else:
|
||||
"mounted": router_manager.include_optional_routers(),
|
||||
"reason": "Full mode",
|
||||
}
|
||||
else:
|
||||
# Feature-only mode: load only routers matching enabled features
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
|
||||
# Filter core routers that match any enabled feature
|
||||
matching_core = [
|
||||
r for r in CORE_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
logger.info(
|
||||
f"[FEATURE-MODE] Enabled features: {enabled_features}, "
|
||||
f"matching {len(matching_core)} core routers: {[r['name'] for r in matching_core]}"
|
||||
)
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in matching_core if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Attempting to load step4_assets")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[FEATURE-MODE] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other matching core routers
|
||||
for entry in matching_core:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
if entry.get("name") == "subscription":
|
||||
continue # Loaded separately below
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Load optional routers matching enabled features
|
||||
from alwrity_utils.router_manager import OPTIONAL_ROUTER_REGISTRY
|
||||
matching_optional = [
|
||||
r for r in OPTIONAL_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
for entry in matching_optional:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading optional router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount optional {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Safety net: explicitly include hallucination detector (router_manager may skip silently)
|
||||
if hallucination_detector_router:
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
|
||||
# Log startup summary
|
||||
router_manager.log_startup_summary()
|
||||
@@ -480,8 +504,8 @@ router_group_status["assets_serving"] = {
|
||||
"reason": "Required for podcast media assets",
|
||||
}
|
||||
|
||||
# SEO Dashboard endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
# SEO Dashboard endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
@app.get("/api/seo-dashboard/data")
|
||||
async def seo_dashboard_data():
|
||||
"""Get complete SEO dashboard data."""
|
||||
@@ -619,7 +643,7 @@ if not is_podcast_only_demo_mode():
|
||||
return await analyze_urls_ai(request, current_user)
|
||||
|
||||
# Include platform analytics router
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
# Include Bing Analytics Storage router to expose storage-backed endpoints
|
||||
@@ -644,25 +668,38 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
else:
|
||||
router_group_status["platform_extensions"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
"reason": "Skipped in feature-only mode",
|
||||
}
|
||||
|
||||
# Include Podcast Maker router (always needed for podcast mode)
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[PODCAST] Including podcast_router with prefixes: {podcast_router.routes}")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Always mounted",
|
||||
}
|
||||
# Include Podcast Maker router (only when podcast feature is enabled)
|
||||
if _is_feature_enabled("podcast") and "all" not in get_enabled_features():
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[ROUTER] Including podcast_router")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast feature enabled",
|
||||
}
|
||||
elif "all" in get_enabled_features():
|
||||
# In full mode, podcast is loaded via optional router registry
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Full mode (loaded via registry)",
|
||||
}
|
||||
else:
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": False,
|
||||
"reason": "Podcast feature not enabled",
|
||||
}
|
||||
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
# Include YouTube Creator Studio router
|
||||
from api.youtube.router import router as youtube_router
|
||||
app.include_router(youtube_router, prefix="/api")
|
||||
|
||||
# Include research configuration router
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
if research_config_router:
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
# Include Research Engine router (standalone AI research module)
|
||||
from api.research.router import router as research_engine_router
|
||||
@@ -688,7 +725,7 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
else:
|
||||
router_group_status["advanced_workflows"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
"reason": "Skipped in feature-only mode",
|
||||
}
|
||||
|
||||
# Setup frontend serving using modular utilities
|
||||
@@ -715,20 +752,23 @@ async def startup_event():
|
||||
# Note: Pricing is initialized per-user in services/database.py:init_user_database()
|
||||
# which runs on first database access for each user. No global seeding needed at startup.
|
||||
|
||||
# Skip startup health checks in podcast-only mode to avoid unnecessary DB errors
|
||||
if not is_podcast_only_demo_mode():
|
||||
enabled_features = get_enabled_features()
|
||||
is_single_mode = "all" not in enabled_features
|
||||
|
||||
# Skip startup health checks in feature-only modes to avoid unnecessary DB errors
|
||||
if _is_full_mode():
|
||||
startup_report = run_startup_health_routine(app)
|
||||
if startup_report.get("status") != "healthy":
|
||||
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")
|
||||
else:
|
||||
logger.info("[Podcast] Skipping startup health routine (podcast-only mode)")
|
||||
logger.info(f"[FEATURE-MODE] Skipping startup health routine (features: {enabled_features})")
|
||||
|
||||
# Start task scheduler only if NOT in podcast-only mode
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Start task scheduler only in full mode
|
||||
if _is_full_mode():
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
else:
|
||||
logger.info("[Podcast] Skipping scheduler startup (podcast-only mode)")
|
||||
logger.info(f"[FEATURE-MODE] Skipping scheduler startup (features: {enabled_features})")
|
||||
|
||||
# Check Wix API key configuration
|
||||
wix_api_key = os.getenv('WIX_API_KEY')
|
||||
@@ -740,9 +780,12 @@ async def startup_event():
|
||||
elapsed = time.time() - startup_start
|
||||
logger.info(f"ALwrity backend started successfully in {elapsed:.1f}s")
|
||||
|
||||
# Critical router mount assertions for podcast-only demo mode
|
||||
# Critical router mount assertions for feature-only modes
|
||||
_assert_router_mounted("subscription")
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("podcast"):
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("blog_writer"):
|
||||
_assert_router_mounted("blog_writer")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
# Don't raise - let the server start anyway
|
||||
@@ -757,6 +800,7 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
router_path_indicators = {
|
||||
"subscription": ["/api/subscription/plans", "/api/subscription/preflight"],
|
||||
"podcast": ["/api/podcast/projects", "/api/podcast/"],
|
||||
"blog_writer": ["/api/blog/health", "/api/blog/research/start"],
|
||||
}
|
||||
|
||||
expected_paths = router_path_indicators.get(router_name, [])
|
||||
@@ -767,10 +811,9 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
else:
|
||||
error_msg = f"❌ CRITICAL: Router '{router_name}' is NOT mounted! Expected paths: {expected_paths}"
|
||||
logger.error(error_msg)
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In demo mode, podcast router MUST be mounted
|
||||
if router_name == "podcast":
|
||||
raise RuntimeError(error_msg)
|
||||
# In feature-only mode, only fail if the feature is expected
|
||||
if not _is_full_mode() and _is_feature_enabled(router_name):
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Shutdown event
|
||||
@app.on_event("shutdown")
|
||||
|
||||
@@ -252,6 +252,8 @@ router_manager.include_core_routers()
|
||||
# Safety net: keep subscription routes available even if core inclusion flow changes
|
||||
# in special modes (e.g., demo mode). De-dup is handled by RouterManager.
|
||||
router_manager.include_router_safely(subscription_router, "subscription")
|
||||
# Include hallucination detector explicitly (router_manager may skip silently on import failure)
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
router_manager.include_optional_routers()
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
|
||||
@@ -11,17 +11,30 @@ echo "📦 Checking ALWRITY_ENABLED_FEATURES..."
|
||||
ENABLED_FEATURES="${ALWRITY_ENABLED_FEATURES:-all}"
|
||||
echo "DEBUG: ENABLED_FEATURES='$ENABLED_FEATURES'"
|
||||
|
||||
if [[ "$ENABLED_FEATURES" == "podcast" ]]; then
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
fi
|
||||
case "$ENABLED_FEATURES" in
|
||||
all)
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
;;
|
||||
podcast)
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
;;
|
||||
*)
|
||||
echo "🎯 Feature-limited mode ($ENABLED_FEATURES): Installing requirements..."
|
||||
req_file="requirements-${ENABLED_FEATURES}.txt"
|
||||
if [[ -f "$req_file" ]]; then
|
||||
python -m pip install --no-cache-dir -r "$req_file" --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "⚠️ No feature-specific requirements file found ($req_file), installing full requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
|
||||
# 3. Clean up unnecessary build artifacts
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -26,7 +27,7 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -499,7 +499,7 @@ class DatabaseTaskManager:
|
||||
)
|
||||
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
|
||||
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
|
||||
"""Background task to generate a medium blog using a single structured JSON call."""
|
||||
try:
|
||||
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
|
||||
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
task_id,
|
||||
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
|
||||
user_id,
|
||||
db=self.db
|
||||
)
|
||||
|
||||
|
||||
@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
|
||||
- Ensure engaging, actionable content throughout
|
||||
|
||||
Return JSON format:
|
||||
{
|
||||
{{
|
||||
"title_options": [
|
||||
"Title option 1",
|
||||
"Title option 2",
|
||||
"Title option 3"
|
||||
],
|
||||
"outline": [
|
||||
{
|
||||
{{
|
||||
"heading": "Section heading with primary keyword",
|
||||
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
|
||||
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"target_words": 300,
|
||||
"keywords": ["primary keyword", "secondary keyword"]
|
||||
}
|
||||
}}
|
||||
]
|
||||
}"""
|
||||
}}"""
|
||||
|
||||
def get_outline_schema(self) -> Dict[str, Any]:
|
||||
"""Get the structured JSON schema for outline generation."""
|
||||
|
||||
@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogOutlineSection
|
||||
import json
|
||||
|
||||
|
||||
class SectionEnhancer:
|
||||
@@ -73,14 +73,45 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
enhanced_data = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
enhanced_data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
enhanced_data = json.loads(json_match.group(0))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Section enhancement returned invalid JSON: {e}")
|
||||
return section
|
||||
else:
|
||||
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
|
||||
return section
|
||||
elif isinstance(raw, dict):
|
||||
enhanced_data = raw
|
||||
else:
|
||||
logger.warning(f"Unexpected LLM response type: {type(raw)}")
|
||||
return section
|
||||
|
||||
if 'error' in enhanced_data:
|
||||
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
|
||||
else:
|
||||
return BlogOutlineSection(
|
||||
id=section.id,
|
||||
heading=enhanced_data.get('heading', section.heading),
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts competitor insights and market intelligence from research content.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CompetitorAnalyzer:
|
||||
@@ -22,7 +23,7 @@ class CompetitorAnalyzer:
|
||||
Extract and analyze:
|
||||
1. Top competitors mentioned (companies, brands, platforms)
|
||||
2. Content gaps (what competitors are missing)
|
||||
3. Market opportunities (untapped areas)
|
||||
3. Opportunities (untapped areas)
|
||||
4. Competitive advantages (what makes content unique)
|
||||
5. Market positioning insights
|
||||
6. Industry leaders and their strategies
|
||||
@@ -55,18 +56,38 @@ class CompetitorAnalyzer:
|
||||
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
|
||||
}
|
||||
|
||||
competitor_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
competitor_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
competitor_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
competitor_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
|
||||
logger.error(f"AI competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in competitor_analysis:
|
||||
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
|
||||
|
||||
@@ -63,18 +63,41 @@ class ContentAngleGenerator:
|
||||
"required": ["content_angles"]
|
||||
}
|
||||
|
||||
angles_result = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import json, re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
angles_result = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
angles_result = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
angles_result = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
|
||||
logger.error(f"AI content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in angles_result:
|
||||
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
|
||||
|
||||
if 'content_angles' not in angles_result:
|
||||
raise ValueError(f"Content angles missing from response")
|
||||
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts and analyzes keywords from research content using structured AI respons
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class KeywordAnalyzer:
|
||||
@@ -62,18 +63,38 @@ class KeywordAnalyzer:
|
||||
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
|
||||
}
|
||||
|
||||
keyword_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
keyword_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
keyword_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
keyword_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
|
||||
logger.error(f"AI keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in keyword_analysis:
|
||||
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
|
||||
|
||||
@@ -111,19 +111,22 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
@@ -162,13 +165,15 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (similar to Exa)
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -429,14 +434,16 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
@@ -446,7 +453,8 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
@@ -485,14 +493,16 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -529,7 +539,8 @@ class ResearchService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
|
||||
@@ -135,11 +135,14 @@ class TavilyResearchProvider(BaseProvider):
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -92,6 +92,7 @@ class BlogSEORecommendationApplier:
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -233,7 +233,7 @@ def create_blog_post(
|
||||
|
||||
# BACK TO BASICS MODE: Try simplest possible structure FIRST
|
||||
# Since posting worked before Ricos/SEO, let's test with absolute minimum
|
||||
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
|
||||
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
|
||||
|
||||
wix_logger.reset()
|
||||
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
|
||||
@@ -257,8 +257,7 @@ def create_blog_post(
|
||||
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}]
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -256,17 +256,16 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
@@ -332,7 +331,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -345,7 +343,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
@@ -373,7 +370,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -386,7 +382,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
@@ -442,7 +437,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
@@ -461,7 +455,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
|
||||
@@ -20,13 +20,14 @@ class SemanticHarvesterService:
|
||||
"last_harvest_time": None
|
||||
}
|
||||
|
||||
async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deep crawl a website using Exa AI.
|
||||
|
||||
Args:
|
||||
website_url: The root URL to crawl.
|
||||
limit: Maximum number of pages to retrieve.
|
||||
user_id: Optional user ID for usage tracking and preflight checks.
|
||||
|
||||
Returns:
|
||||
List of pages with content and metadata.
|
||||
@@ -59,6 +60,30 @@ class SemanticHarvesterService:
|
||||
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
|
||||
return self._get_placeholder_data(website_url)
|
||||
|
||||
# Preflight subscription check if user_id provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
|
||||
|
||||
# Use Exa to search for all pages in this domain
|
||||
search_response = self.exa_service.exa.search_and_contents(
|
||||
query=f"site:{website_url}",
|
||||
@@ -82,6 +107,38 @@ class SemanticHarvesterService:
|
||||
})
|
||||
|
||||
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
|
||||
|
||||
# Track Exa usage if user_id provided
|
||||
if user_id and results:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Exa search cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -133,9 +133,9 @@ def edit_image(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
# In podcast-only mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES") == "podcast":
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in podcast mode - allowing operation to continue")
|
||||
# In feature-limited mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
|
||||
@@ -45,6 +45,7 @@ def llm_text_gen(
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -75,7 +76,8 @@ def llm_text_gen(
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
if max_tokens is None:
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
@@ -371,16 +373,27 @@ def llm_text_gen(
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
llm_start = time.time()
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
if json_struct:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
|
||||
@@ -179,6 +179,43 @@ def get_wavespeed_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def _retry_with_increased_tokens(
|
||||
client: "OpenAI",
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
fallback_models: Optional[List[str]],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> Optional[str]:
|
||||
"""Retry the API call with increased max_tokens when JSON parsing fails due to truncation."""
|
||||
max_tokens = min(max_tokens, 16384)
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
text = text.strip() if text else ""
|
||||
if text.startswith("```json"):
|
||||
text = text[7:]
|
||||
if text.startswith("```"):
|
||||
text = text[3:]
|
||||
if text.endswith("```"):
|
||||
text = text[:-3]
|
||||
return text.strip()
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if last_error:
|
||||
logger.error(f"All fallback models failed on retry with increased tokens: {last_error}")
|
||||
return None
|
||||
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_wavespeed_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
@@ -446,24 +483,69 @@ def wavespeed_structured_json_response(
|
||||
raise last_error or Exception("WaveSpeed structured generation failed: all fallback models failed")
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# If response_format returned empty content, retry without it
|
||||
if not response_text:
|
||||
logger.warning("WaveSpeed structured call returned empty content with response_format, retrying without it...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if response is not None:
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# Clean up response text if needed
|
||||
response_text = response_text.strip()
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
parsed_json = json.loads(response_text) if response_text else None
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"❌ JSON parsing failed: {json_err}")
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
# Retry once with increased max_tokens — likely a truncation issue
|
||||
if max_tokens < 16384:
|
||||
logger.warning(f"Retrying with increased max_tokens ({max_tokens} → {max_tokens * 2}) due to JSON parse failure")
|
||||
response_text = _retry_with_increased_tokens(
|
||||
client=client,
|
||||
messages=messages,
|
||||
model=model,
|
||||
fallback_models=fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens * 2,
|
||||
)
|
||||
if response_text:
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON parsed successfully after max_tokens increase")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError:
|
||||
logger.error("❌ JSON parsing failed even after max_tokens increase")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
if response_text:
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -472,8 +554,8 @@ def wavespeed_structured_json_response(
|
||||
return extracted_json
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WaveSpeed API call failed: {e}")
|
||||
@@ -501,14 +583,24 @@ def wavespeed_structured_json_response(
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
# Parse JSON with robust cleaning
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except:
|
||||
# Regex fallback
|
||||
return json.loads(response_text) if response_text else {"error": "Empty response"}
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
raise e
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ from services.database import get_db_session
|
||||
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences
|
||||
from models.persona_models import WritingPersona, PlatformPersona, PersonaAnalysisResult
|
||||
|
||||
def _get_podcast_mode():
|
||||
"""Check if running in podcast-only mode to skip heavy initialization."""
|
||||
def _is_feature_limited_mode():
|
||||
"""Check if running in feature-limited mode to skip heavy initialization."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower()
|
||||
return env_val == "podcast"
|
||||
return env_val not in ("", "all")
|
||||
|
||||
class PersonaAnalysisService:
|
||||
"""Service for analyzing onboarding data and generating writing personas using Gemini AI."""
|
||||
@@ -40,9 +40,9 @@ class PersonaAnalysisService:
|
||||
def __init__(self):
|
||||
"""Initialize the persona analysis service (only once)."""
|
||||
if not self._initialized:
|
||||
# Skip heavy initialization in podcast-only mode
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
# Skip heavy initialization in feature-limited mode
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug(f"PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
@@ -55,8 +55,8 @@ class PersonaAnalysisService:
|
||||
return
|
||||
|
||||
# Check again in case mode changed
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
self._heavy_init_done = True
|
||||
return
|
||||
|
||||
@@ -89,9 +89,9 @@ class PersonaAnalysisService:
|
||||
# Ensure heavy services are initialized
|
||||
self._ensure_heavy_init()
|
||||
|
||||
# Check if heavy init failed (podcast mode)
|
||||
# Check if heavy init failed (feature-limited mode)
|
||||
if not getattr(self, '_heavy_init_done', False):
|
||||
return {"error": "Persona service unavailable in podcast-only mode"}
|
||||
return {"error": "Persona service unavailable in feature-limited mode"}
|
||||
|
||||
try:
|
||||
logger.info(f"Generating persona for user {user_id}")
|
||||
|
||||
@@ -296,6 +296,33 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Exa preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Exa preflight check failed: {e}")
|
||||
|
||||
# Execute Exa search
|
||||
try:
|
||||
@@ -341,6 +368,33 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.TAVILY,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="tavily",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'tavily', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Tavily preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Tavily preflight check failed: {e}")
|
||||
|
||||
# Execute Tavily search
|
||||
try:
|
||||
|
||||
@@ -83,6 +83,30 @@ class DeepCrawlService:
|
||||
tavily_results.append(res)
|
||||
|
||||
logger.info(f"Found {len(tavily_urls)} URLs from Tavily")
|
||||
|
||||
# Track Tavily usage
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Tavily crawl cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET tavily_calls = COALESCE(tavily_calls, 0) + 1,
|
||||
tavily_cost = COALESCE(tavily_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[DeepCrawl] Tracked Tavily crawl usage: user={user_id}, cost=${cost}")
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[DeepCrawl] Failed to track Tavily usage: {track_err}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily crawl failed: {e}")
|
||||
|
||||
|
||||
@@ -49,9 +49,11 @@ except Exception as _patch_err:
|
||||
# Now safe to import pytrends
|
||||
try:
|
||||
from pytrends.request import TrendReq as _TrendReq
|
||||
from pytrends.exceptions import TooManyRequestsError as _TooManyRequestsError
|
||||
PYTrends_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTrends_AVAILABLE = False
|
||||
_TooManyRequestsError = None
|
||||
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
|
||||
|
||||
# Patch 2: pytrends related_topics() and related_queries() use keyword[0]
|
||||
@@ -139,6 +141,8 @@ class GoogleTrendsService:
|
||||
Uses TrendReq with no retries (fail-fast) to avoid hitting CAPTCHA on blocks.
|
||||
429 retry handling (1s, 2s, 4s backoff). Random user-agent is set
|
||||
per instance to reduce fingerprinting.
|
||||
|
||||
Rate limiter is shared across all instances to enforce global rate limiting.
|
||||
"""
|
||||
|
||||
USER_AGENTS = [
|
||||
@@ -150,15 +154,28 @@ class GoogleTrendsService:
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0",
|
||||
]
|
||||
|
||||
# Class-level shared resources (shared across all instances)
|
||||
_shared_rate_limiter = None
|
||||
_shared_cache = None
|
||||
_cache_ttl = timedelta(hours=24)
|
||||
_last_429_time = 0 # Timestamp of last 429 error (Unix epoch)
|
||||
_429_cooldown_period = 1800 # 30 minutes cooldown after 429
|
||||
|
||||
def __init__(self):
|
||||
if not PYTrends_AVAILABLE:
|
||||
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
|
||||
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0)
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self.cache_ttl = timedelta(hours=24)
|
||||
# Initialize shared rate limiter at class level (lazy init)
|
||||
if self.__class__._shared_rate_limiter is None:
|
||||
self.__class__._shared_rate_limiter = RateLimiter(max_calls=1, period=3.0) # 1 call per 3 seconds
|
||||
if self.__class__._shared_cache is None:
|
||||
self.__class__._shared_cache = {}
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, fail-fast, 2s delays)")
|
||||
self.rate_limiter = self.__class__._shared_rate_limiter
|
||||
self.cache = self.__class__._shared_cache
|
||||
self.cache_ttl = self._cache_ttl
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, shared rate limiter, 3s period, shared cache, 30min 429 cooldown)")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
@@ -173,7 +190,7 @@ class GoogleTrendsService:
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Comprehensive trends analysis.
|
||||
Comprehensive trends analysis with retry logic for 429 errors.
|
||||
|
||||
Args:
|
||||
keywords: List of keywords to analyze (1-5)
|
||||
@@ -193,11 +210,97 @@ class GoogleTrendsService:
|
||||
keywords = keywords[:5]
|
||||
|
||||
cache_key = self._build_cache_key(keywords, timeframe, geo)
|
||||
|
||||
# Check if we're in a 429 cooldown period
|
||||
now = time.time()
|
||||
if now - self.__class__._last_429_time < self.__class__._429_cooldown_period:
|
||||
remaining_cooldown = int(self.__class__._429_cooldown_period - (now - self.__class__._last_429_time))
|
||||
logger.warning(
|
||||
f"[Trends] In 429 cooldown period. {remaining_cooldown}s remaining. "
|
||||
f"Returning cached data if available."
|
||||
)
|
||||
cached_data = self._get_from_cache(cache_key, ignore_ttl=True) # Use stale cache
|
||||
if cached_data:
|
||||
logger.info(f"[Trends] Returning stale cached data for {keywords} during cooldown")
|
||||
return {**cached_data, "cached": True, "cooldown_active": True}
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Rate limited by Google. Cooldown active for {remaining_cooldown}s. Try again later."
|
||||
)
|
||||
|
||||
# Check fresh cache
|
||||
cached_data = self._get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached trends data for: {keywords}")
|
||||
return {**cached_data, "cached": True}
|
||||
|
||||
# Retry logic for 429 errors
|
||||
max_retries = 3
|
||||
retry_delays = [30, 60, 120] # Longer delays: 30s, 60s, 120s
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await self._do_analyze_trends(
|
||||
keywords, timeframe, geo, gprop, cache_key, attempt, max_retries
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if this is a 429 error (pytrends raises TooManyRequestsError)
|
||||
is_429 = False
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
is_429 = True
|
||||
else:
|
||||
error_str = str(e).lower()
|
||||
is_429 = "429" in error_str or "rate limit" in error_str or "too many requests" in error_str
|
||||
|
||||
if is_429:
|
||||
# Update the last 429 time for cooldown
|
||||
self.__class__._last_429_time = time.time()
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = retry_delays[attempt]
|
||||
logger.warning(
|
||||
f"[Trends] 429 rate limit hit (attempt {attempt + 1}/{max_retries + 1}), "
|
||||
f"retrying in {delay}s..."
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
# Out of retries - enter cooldown
|
||||
logger.error(
|
||||
f"[Trends] 429 rate limit persisted after {max_retries + 1} attempts. "
|
||||
f"Entering {self.__class__._429_cooldown_period}s cooldown period."
|
||||
)
|
||||
# Try to return stale cache
|
||||
stale_cache = self._get_from_cache(cache_key, ignore_ttl=True)
|
||||
if stale_cache:
|
||||
logger.info(f"[Trends] Returning stale cache after 429 exhaustion for {keywords}")
|
||||
result = {**stale_cache}
|
||||
result["cached"] = True
|
||||
result["cooldown_active"] = True
|
||||
return result
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Google is rate limiting requests. Cooldown active for {self.__class__._429_cooldown_period}s. Try again later."
|
||||
)
|
||||
else:
|
||||
# Non-429 error
|
||||
logger.error(f"Google Trends analysis failed after {attempt + 1} attempts: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
|
||||
# Should not reach here, but just in case
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, "Max retries exceeded")
|
||||
|
||||
async def _do_analyze_trends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str,
|
||||
gprop: str,
|
||||
cache_key: str,
|
||||
attempt: int,
|
||||
max_retries: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal method to perform the actual trends analysis."""
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
total_start = time.monotonic()
|
||||
@@ -207,95 +310,63 @@ class GoogleTrendsService:
|
||||
related_topics: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
related_queries: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
|
||||
try:
|
||||
logger.info(f"[Trends] ===== START analyze_trends ===== keywords={keywords} timeframe={timeframe} geo={geo}")
|
||||
logger.info(
|
||||
f"[Trends] ===== START analyze_trends (attempt {attempt + 1}/{max_retries + 1}) ===== "
|
||||
f"keywords={keywords} timeframe={timeframe} geo={geo}"
|
||||
)
|
||||
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
|
||||
# --- Interest Over Time ---
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
# --- Interest Over Time ONLY (skip others to avoid 429) ---
|
||||
await self.rate_limiter.acquire() # Rate limit check BEFORE each request
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
# Skip other requests to avoid 429 - only fetch interest_over_time for now
|
||||
logger.info(f"[Trends] Skipping other requests to avoid 429 (interest_by_region, related_topics, related_queries)")
|
||||
|
||||
# --- Interest By Region ---
|
||||
ibr_start = time.monotonic()
|
||||
interest_by_region = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_by_region(pytrends)
|
||||
)
|
||||
ibr_ms = int((time.monotonic() - ibr_start) * 1000)
|
||||
logger.info(f"[Trends] interest_by_region took {ibr_ms}ms, returned {len(interest_by_region)} regions")
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
# --- Related Topics ---
|
||||
rt_start = time.monotonic()
|
||||
related_topics = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_topics(pytrends)
|
||||
)
|
||||
rt_ms = int((time.monotonic() - rt_start) * 1000)
|
||||
rt_top = len(related_topics.get("top", []))
|
||||
rt_rising = len(related_topics.get("rising", []))
|
||||
logger.info(f"[Trends] related_topics took {rt_ms}ms, top={rt_top} rising={rt_rising}")
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
# --- Related Queries ---
|
||||
rq_start = time.monotonic()
|
||||
related_queries = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_queries(pytrends)
|
||||
)
|
||||
rq_ms = int((time.monotonic() - rq_start) * 1000)
|
||||
rq_top = len(related_queries.get("top", []))
|
||||
rq_rising = len(related_queries.get("rising", []))
|
||||
logger.info(f"[Trends] related_queries took {rq_ms}ms, top={rq_top} rising={rq_rising}")
|
||||
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
return result
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# TrendReq factory
|
||||
@@ -346,6 +417,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_over_time failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -363,6 +440,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_by_region failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -409,6 +492,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_topics failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -452,6 +541,12 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_queries failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -503,14 +598,18 @@ class GoogleTrendsService:
|
||||
keywords_str = ":".join(sorted(keywords))
|
||||
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
def _get_from_cache(self, cache_key: str, ignore_ttl: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached data. If ignore_ttl=True, return stale data too (for 429 cooldown)."""
|
||||
if cache_key not in self.cache:
|
||||
return None
|
||||
cached_entry = self.cache[cache_key]
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
if not ignore_ttl:
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
result = {**cached_entry}
|
||||
result.pop("cached", None)
|
||||
return result
|
||||
|
||||
@@ -157,10 +157,10 @@ def _check_production_api_key_loading(
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
|
||||
return
|
||||
|
||||
# Also skip in podcast-only mode (no production API keys needed)
|
||||
# Skip when in feature-limited mode (no production API keys needed)
|
||||
enabled_features = os.getenv("ALWRITY_ENABLED_FEATURES", "all").strip().lower()
|
||||
if enabled_features == "podcast":
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in podcast-only mode")
|
||||
if enabled_features and enabled_features not in ("", "all"):
|
||||
_record_check(checks, "production_api_key_loading", True, f"skipped in feature-limited mode: {enabled_features}")
|
||||
return
|
||||
|
||||
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()
|
||||
|
||||
@@ -12,7 +12,7 @@ from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.subscription_models import APIProvider, UsageAlert
|
||||
from models.subscription_models import APIProvider, UsageAlert, UserSubscription
|
||||
|
||||
class SubscriptionErrorType(Enum):
|
||||
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
|
||||
@@ -248,6 +248,18 @@ class SubscriptionExceptionHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get billing period from subscription, fallback to calendar month
|
||||
billing_period = datetime.now().strftime("%Y-%m") # default
|
||||
try:
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == error.user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
if subscription and subscription.current_period_start:
|
||||
billing_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
except:
|
||||
pass # Use default calendar period
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=error.user_id,
|
||||
alert_type="system_error",
|
||||
@@ -256,7 +268,7 @@ class SubscriptionExceptionHandler:
|
||||
title=f"System Error: {error.error_type.value}",
|
||||
message=error.message,
|
||||
severity=error.severity.value,
|
||||
billing_period=datetime.now().strftime("%Y-%m")
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
|
||||
@@ -157,39 +157,38 @@ class LimitValidator:
|
||||
user_tier = limits.get('tier', 'free') if limits else 'free'
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
|
||||
# Use subscription period, not calendar month
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
@@ -448,7 +447,7 @@ class LimitValidator:
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
|
||||
|
||||
|
||||
@@ -67,15 +67,56 @@ class PricingService:
|
||||
self.db.rollback()
|
||||
return True
|
||||
|
||||
def get_current_billing_period(self, user_id: str) -> Optional[str]:
|
||||
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
|
||||
def get_current_billing_period(self, user_id: str) -> str:
|
||||
"""Return current billing period key (YYYY-MM) based on subscription, not calendar.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Ensure subscription is current (advance if auto_renew)
|
||||
self._ensure_subscription_current(subscription)
|
||||
# Continue to use YYYY-MM for summaries
|
||||
|
||||
# Use subscription's billing period, NOT calendar month
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Check if usage data exists for this subscription period
|
||||
from models.subscription_models import UsageSummary
|
||||
usage_exists = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
if usage_exists:
|
||||
return sub_period
|
||||
|
||||
# If no data for subscription period, check for calendar month data
|
||||
# This handles backward compatibility for existing users
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
if calendar_period != sub_period:
|
||||
calendar_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
if calendar_usage:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {sub_period} has no data)")
|
||||
return calendar_period
|
||||
|
||||
return sub_period
|
||||
|
||||
# Fallback: Check if user has any usage summary and use that period
|
||||
from models.subscription_models import UsageSummary
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period}")
|
||||
return latest_summary.billing_period
|
||||
|
||||
# Last fallback to calendar month for free tier / no data
|
||||
return datetime.now().strftime("%Y-%m")
|
||||
|
||||
@classmethod
|
||||
@@ -830,6 +871,7 @@ class PricingService:
|
||||
'serper_calls': plan.serper_calls_limit,
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'exa_calls': getattr(plan, 'exa_calls_limit', 0), # Exa research API
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
REQUIRED_STRIPE_PLAN_KEYS = {
|
||||
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
|
||||
@@ -421,10 +421,6 @@ class StripeService:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
price_id = sub['items']['data'][0]['price']['id']
|
||||
# Map price_id to internal plan_id
|
||||
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
|
||||
# For now, we'll assume the metadata or a lookup.
|
||||
# Ideally, store price_id in SubscriptionPlan table or config.
|
||||
|
||||
# Update DB
|
||||
self._update_user_subscription(
|
||||
@@ -434,6 +430,24 @@ class StripeService:
|
||||
status="active",
|
||||
price_id=price_id
|
||||
)
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")
|
||||
|
||||
# Expire all SQLAlchemy objects to force fresh reads
|
||||
self.db.expire_all()
|
||||
logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkout subscription: {e}")
|
||||
|
||||
@@ -457,11 +471,28 @@ class StripeService:
|
||||
logger.info(f"Payment succeeded for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
# Update period end based on invoice lines period
|
||||
subscription.auto_renew = True
|
||||
# Update period start/end based on invoice lines period
|
||||
if invoice.get('lines'):
|
||||
period_start = invoice['lines']['data'][0]['period']['start']
|
||||
period_end = invoice['lines']['data'][0]['period']['end']
|
||||
subscription.current_period_start = datetime.fromtimestamp(period_start)
|
||||
subscription.current_period_end = datetime.fromtimestamp(period_end)
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after payment success")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after payment success for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after payment success for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
|
||||
subscription_id = invoice.get("subscription")
|
||||
@@ -497,6 +528,12 @@ class StripeService:
|
||||
if status in ["active", "trialing"]:
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
subscription.auto_renew = True
|
||||
# Update period boundaries from Stripe event
|
||||
current_period = subscription_obj.get("current_period", {})
|
||||
if current_period:
|
||||
subscription.current_period_start = datetime.fromtimestamp(current_period.get("start", 0))
|
||||
subscription.current_period_end = datetime.fromtimestamp(current_period.get("end", 0))
|
||||
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
@@ -506,6 +543,20 @@ class StripeService:
|
||||
subscription.auto_renew = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after subscription update")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after subscription update for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after subscription update for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
@@ -610,6 +661,11 @@ class StripeService:
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
# Calculate billing period end based on cycle
|
||||
if billing_cycle == BillingCycle.YEARLY:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
if not subscription:
|
||||
subscription = UserSubscription(
|
||||
@@ -617,7 +673,7 @@ class StripeService:
|
||||
plan_id=plan.id,
|
||||
billing_cycle=billing_cycle,
|
||||
current_period_start=now,
|
||||
current_period_end=now,
|
||||
current_period_end=period_end,
|
||||
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
|
||||
is_active=status == "active",
|
||||
auto_renew=True,
|
||||
@@ -627,6 +683,11 @@ class StripeService:
|
||||
subscription.plan_id = plan.id
|
||||
subscription.billing_cycle = billing_cycle
|
||||
subscription.is_active = status == "active"
|
||||
subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
|
||||
# Reset billing period on upgrade/plan change
|
||||
subscription.current_period_start = now
|
||||
subscription.current_period_end = period_end
|
||||
subscription.auto_renew = True
|
||||
|
||||
subscription.stripe_customer_id = stripe_customer_id
|
||||
subscription.stripe_subscription_id = stripe_subscription_id
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Usage tracking modules package.
|
||||
Split from the monolithic usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from .historical_usage import get_all_historical_usage, get_current_period_usage, get_usage_for_period
|
||||
from .usage_stats import get_user_usage_stats
|
||||
from .usage_trends import get_usage_trends
|
||||
from .limits_enforcement import enforce_usage_limits
|
||||
from .alerts import check_usage_alerts, create_usage_alert
|
||||
|
||||
__all__ = [
|
||||
'get_all_historical_usage',
|
||||
'get_current_period_usage',
|
||||
'get_usage_for_period',
|
||||
'get_user_usage_stats',
|
||||
'get_usage_trends',
|
||||
'enforce_usage_limits',
|
||||
'check_usage_alerts',
|
||||
'create_usage_alert',
|
||||
]
|
||||
101
backend/services/subscription/usage_tracking_modules/alerts.py
Normal file
101
backend/services/subscription/usage_tracking_modules/alerts.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Usage alert functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import UsageAlert, UsageSummary, APIProvider, UsageStatus
|
||||
|
||||
|
||||
def check_usage_alerts(user_id: str, provider: APIProvider,
|
||||
billing_period: str, db: Session, pricing_service):
|
||||
"""Check if usage alerts should be sent."""
|
||||
# Get current usage
|
||||
period_keys = {'billing_period': billing_period, 'lookup_periods': [billing_period]}
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
def create_usage_alert(user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str, db: Session):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Historical usage aggregation functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus
|
||||
|
||||
|
||||
# Shared provider mapping: DB column → frontend key
|
||||
PROVIDER_MAPPING = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface', # HuggingFace stored as mistral
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'tavily_calls': 'tavily',
|
||||
'serper_calls': 'serper',
|
||||
'firecrawl_calls': 'firecrawl',
|
||||
'metaphor_calls': 'metaphor',
|
||||
'stability_calls': 'stability',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
|
||||
def _build_provider_breakdown(summaries: list, mapping: dict) -> dict:
|
||||
"""Build provider_breakdown dict from a list of UsageSummary records."""
|
||||
breakdown = {}
|
||||
for db_col, frontend_key in mapping.items():
|
||||
total = sum(getattr(s, db_col, 0) or 0 for s in summaries)
|
||||
breakdown[frontend_key] = {'calls': total, 'cost': 0, 'tokens': 0}
|
||||
return breakdown
|
||||
|
||||
|
||||
def _build_usage_percentages(provider_breakdown: dict, limits: dict) -> dict:
|
||||
"""Build usage_percentages dict from provider_breakdown and per-period limits."""
|
||||
pcts = {}
|
||||
if not limits or not limits.get('limits'):
|
||||
return pcts
|
||||
|
||||
limit_map = {
|
||||
'gemini_calls': ('gemini', 'gemini_calls'),
|
||||
'huggingface_calls': ('huggingface', 'mistral_calls'),
|
||||
'stability_calls': ('stability', 'stability_calls'),
|
||||
'video_calls': ('video', 'video_calls'),
|
||||
'audio_calls': ('audio', 'audio_calls'),
|
||||
'image_edit_calls': ('image_edit', 'image_edit_calls'),
|
||||
'wavespeed_calls': ('wavespeed', 'wavespeed_calls'),
|
||||
'tavily_calls': ('tavily', 'tavily_calls'),
|
||||
'serper_calls': ('serper', 'serper_calls'),
|
||||
'firecrawl_calls': ('firecrawl', 'firecrawl_calls'),
|
||||
'metaphor_calls': ('metaphor', 'metaphor_calls'),
|
||||
'exa_calls': ('exa', 'exa_calls'),
|
||||
}
|
||||
|
||||
for pct_key, (bk_key, limit_key) in limit_map.items():
|
||||
used = provider_breakdown.get(bk_key, {}).get('calls', 0)
|
||||
limit_val = limits.get('limits', {}).get(limit_key, 0) or 0
|
||||
if limit_val > 0:
|
||||
pcts[pct_key] = (used / limit_val) * 100
|
||||
|
||||
# Cost percentage
|
||||
total_cost = provider_breakdown.get('total_cost', 0)
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
pcts['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
return pcts
|
||||
|
||||
|
||||
def _summaries_usage_status(summaries: list) -> str:
|
||||
"""Derive overall usage_status from a list of summaries."""
|
||||
status = 'active'
|
||||
for s in summaries:
|
||||
try:
|
||||
st = s.usage_status.value
|
||||
except Exception:
|
||||
st = str(s.usage_status)
|
||||
if st == 'limit_reached':
|
||||
return 'limit_reached'
|
||||
if st == 'warning' and status != 'limit_reached':
|
||||
status = 'warning'
|
||||
return status
|
||||
|
||||
|
||||
def _empty_usage_response(billing_period: str, limits: dict) -> Dict[str, Any]:
|
||||
"""Return a zeroed UsageStats-shaped response."""
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_all_historical_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
all_summaries = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
if not all_summaries:
|
||||
return _empty_usage_response('all', limits)
|
||||
|
||||
# Aggregate
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
provider_breakdown = _build_provider_breakdown(all_summaries, PROVIDER_MAPPING)
|
||||
|
||||
# Historical breakdown per period
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': _summaries_usage_status(all_summaries),
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': {}, # misleading for all-time vs per-period limits
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_current_period_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get current billing period usage data with correct per-period limit percentages.
|
||||
|
||||
Returns a UsageStats-shaped dict with provider_breakdown and usage_percentages
|
||||
computed against the plan's per-period limits.
|
||||
"""
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(current_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': current_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_usage_for_period(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get usage data for a specific billing period.
|
||||
|
||||
Returns a UsageStats-shaped dict with that period's provider_breakdown
|
||||
and usage_percentages computed against plan limits.
|
||||
"""
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(billing_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Usage limit enforcement functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import APIProvider
|
||||
from services.subscription.pricing_service import PricingService
|
||||
|
||||
|
||||
def enforce_usage_limits(user_id: str, provider: APIProvider,
|
||||
tokens_requested: int, db: Session,
|
||||
pricing_service: PricingService) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
|
||||
# This would need access to self._enforce_cache
|
||||
# For now, keeping the structure
|
||||
|
||||
result = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
# self._enforce_cache[cache_key] = {
|
||||
# 'result': result,
|
||||
# 'expires_at': now + timedelta(seconds=30)
|
||||
# }
|
||||
|
||||
return tuple(result)
|
||||
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Usage statistics functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus, APIProvider
|
||||
from services.subscription.usage_tracking_modules.historical_usage import get_all_historical_usage, get_usage_for_period
|
||||
|
||||
|
||||
def get_user_usage_stats(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user.
|
||||
When no billing_period is specified, returns ALL historical usage data.
|
||||
When a specific period is given, returns only that period's data."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
# If no billing_period requested, return ALL historical data
|
||||
if not billing_period:
|
||||
return get_all_historical_usage(user_id, db, pricing_service)
|
||||
|
||||
# Return data for the specific billing period
|
||||
return get_usage_for_period(user_id, billing_period, db, pricing_service)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Usage trends functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_usage_trends(user_id: str, months: int, db: Session) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
from services.subscription.usage_tracking_helpers import build_billing_periods, query_usage_summaries, self_heal_summaries_from_logs, build_usage_trends_response
|
||||
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(db, user_id, periods)
|
||||
self_heal_summaries_from_logs(db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
@@ -1,41 +1,60 @@
|
||||
"""
|
||||
Usage Tracking Service
|
||||
Comprehensive tracking of API usage, costs, and subscription limits.
|
||||
Usage Tracking Service - Refactored into modular components.
|
||||
|
||||
This file now serves as a facade that delegates to specialized modules
|
||||
in the usage_tracking_modules package.
|
||||
|
||||
Modules:
|
||||
- historical_usage: Functions for aggregating historical usage data
|
||||
- usage_stats: Functions for getting user usage statistics
|
||||
- usage_trends: Functions for usage trend analysis
|
||||
- limit_enforcement: Functions for enforcing usage limits
|
||||
- alerts: Functions for usage alerts
|
||||
"""
|
||||
|
||||
# Ensure Optional is available in global scope for dynamic imports
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
import json
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
|
||||
from models.subscription_models import (
|
||||
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
|
||||
UserSubscription, UsageStatus
|
||||
APIProvider, UsageStatus, UserSubscription,
|
||||
UsageSummary, APIUsageLog, UsageAlert
|
||||
)
|
||||
from .pricing_service import PricingService
|
||||
from .provider_detection import detect_actual_provider
|
||||
from .usage_tracking_helpers import (
|
||||
build_billing_periods,
|
||||
build_default_usage_percentages,
|
||||
build_empty_usage_response,
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription.usage_tracking_helpers import (
|
||||
build_provider_breakdown,
|
||||
build_usage_trends_response,
|
||||
build_default_usage_percentages,
|
||||
calculate_final_total_cost,
|
||||
maybe_persist_reconciled_costs,
|
||||
build_usage_trends_response,
|
||||
build_billing_periods,
|
||||
query_usage_summaries,
|
||||
reset_usage_summary_counters,
|
||||
self_heal_summaries_from_logs,
|
||||
reset_usage_summary_counters,
|
||||
)
|
||||
# Import clear_dashboard_cache lazily to avoid circular import
|
||||
def _clear_dashboard_cache_for_user(user_id: str):
|
||||
from api.subscription.cache import clear_dashboard_cache as _clear
|
||||
return _clear(user_id)
|
||||
|
||||
from .usage_tracking_modules import (
|
||||
get_all_historical_usage,
|
||||
get_current_period_usage,
|
||||
get_usage_for_period,
|
||||
get_user_usage_stats,
|
||||
get_usage_trends,
|
||||
enforce_usage_limits,
|
||||
check_usage_alerts,
|
||||
create_usage_alert,
|
||||
)
|
||||
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
"""Service for tracking API usage and managing billing information."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
@@ -43,13 +62,14 @@ class UsageTrackingService:
|
||||
# TTL cache (30s) for enforcement results to cut DB chatter
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
||||
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return authoritative billing period lookup keys. Always uses calendar month for consistency."""
|
||||
"""Return authoritative billing period lookup keys. Always uses subscription period for consistency.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
|
||||
# If caller explicitly requested a billing period, use it
|
||||
if billing_period:
|
||||
return {
|
||||
@@ -58,26 +78,125 @@ class UsageTrackingService:
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
}
|
||||
|
||||
# ALWAYS use current calendar month for billing period to ensure consistency
|
||||
# This prevents data loss when subscription spans month boundaries
|
||||
current_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get subscription period if available
|
||||
subscription_period = None
|
||||
if subscription and subscription.current_period_start:
|
||||
subscription_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Get calendar period
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check which period has usage data
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
if subscription_period:
|
||||
# Check if data exists for subscription period
|
||||
sub_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == subscription_period
|
||||
).first()
|
||||
|
||||
if sub_data:
|
||||
# Use subscription period (has data)
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No data for subscription period, check calendar period (backward compatibility)
|
||||
if calendar_period != subscription_period:
|
||||
cal_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
|
||||
if cal_data:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
|
||||
return {
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# No data in either period, use subscription period
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No subscription, check for any existing data
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
|
||||
return {
|
||||
"billing_period": latest_summary.billing_period,
|
||||
"lookup_periods": [latest_summary.billing_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# Last fallback to calendar month for free tier / no subscription
|
||||
return {
|
||||
"billing_period": current_period,
|
||||
"lookup_periods": [current_period],
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# Delegate to modular functions
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
return get_all_historical_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get current billing period usage with correct per-period limit percentages."""
|
||||
return get_current_period_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
|
||||
"""Get usage for a specific billing period."""
|
||||
return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
return get_usage_trends(user_id, months, self.db)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)
|
||||
|
||||
# Keep the track_api_usage method here as it's the core functionality
|
||||
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Track an API usage event and update billing information."""
|
||||
|
||||
try:
|
||||
@@ -165,394 +284,81 @@ class UsageTrackingService:
|
||||
|
||||
# Invalidate dashboard cache so header stats update immediately
|
||||
try:
|
||||
clear_dashboard_cache(user_id)
|
||||
_clear_dashboard_cache_for_user(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
|
||||
logger.warning(f"Failed to clear dashboard cache: {cache_err}")
|
||||
|
||||
return {
|
||||
'usage_logged': True,
|
||||
'cost': cost_data['cost_total'],
|
||||
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
|
||||
'billing_period': billing_period
|
||||
"success": True,
|
||||
"cost": cost_data['cost_total'],
|
||||
"tokens": (tokens_input or 0) + (tokens_output or 0),
|
||||
"billing_period": billing_period
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error tracking API usage: {str(e)}")
|
||||
logger.error(f"Failed to track API usage: {e}")
|
||||
self.db.rollback()
|
||||
return {
|
||||
'usage_logged': False,
|
||||
'error': str(e)
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
|
||||
tokens_used: int, cost: float, billing_period: str,
|
||||
response_time: float, is_error: bool):
|
||||
"""Update the usage summary for a user."""
|
||||
tokens_used: int, cost: float,
|
||||
billing_period: str,
|
||||
response_time: float = 0.0,
|
||||
is_error: bool = False):
|
||||
"""Update or create usage summary for the billing period."""
|
||||
|
||||
# Get or create usage summary
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
# Get or create summary
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=period_keys["billing_period"]
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
)
|
||||
self.db.add(summary)
|
||||
else:
|
||||
logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}")
|
||||
|
||||
# Update provider-specific counters
|
||||
# Update counts
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
|
||||
# Update provider-specific counts
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
setattr(summary, f"{provider_name}_calls", current_calls + 1)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
|
||||
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
|
||||
# Update provider-specific tokens
|
||||
tokens_attr = f"{provider_name}_tokens"
|
||||
if hasattr(summary, tokens_attr):
|
||||
current_tokens = getattr(summary, tokens_attr, 0) or 0
|
||||
setattr(summary, tokens_attr, current_tokens + tokens_used)
|
||||
|
||||
# Update cost
|
||||
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
|
||||
setattr(summary, f"{provider_name}_cost", current_cost + cost)
|
||||
# Update provider-specific cost
|
||||
cost_attr = f"{provider_name}_cost"
|
||||
if hasattr(summary, cost_attr):
|
||||
current_cost = getattr(summary, cost_attr, 0.0) or 0.0
|
||||
setattr(summary, cost_attr, current_cost + cost)
|
||||
|
||||
# Update totals
|
||||
summary.total_calls += 1
|
||||
summary.total_tokens += tokens_used
|
||||
summary.total_cost += cost
|
||||
# Update response time (rolling average)
|
||||
if response_time > 0:
|
||||
current_avg = summary.avg_response_time or 0.0
|
||||
current_calls = summary.total_calls or 1
|
||||
summary.avg_response_time = ((current_avg * (current_calls - 1)) + response_time) / current_calls
|
||||
|
||||
# Update performance metrics
|
||||
if summary.total_calls > 0:
|
||||
# Update average response time
|
||||
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
|
||||
summary.avg_response_time = total_response_time / summary.total_calls
|
||||
|
||||
# Update error rate
|
||||
if is_error:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
else:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
|
||||
# Update usage status based on limits
|
||||
await self._update_usage_status(summary)
|
||||
# Update error rate
|
||||
if is_error:
|
||||
summary.error_count = (summary.error_count or 0) + 1
|
||||
total_calls = summary.total_calls or 1
|
||||
summary.error_rate = (summary.error_count / total_calls) * 100
|
||||
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
async def _update_usage_status(self, summary: UsageSummary):
|
||||
"""Update usage status based on subscription limits."""
|
||||
|
||||
limits = self.pricing_service.get_user_limits(summary.user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check various limits and determine status
|
||||
max_usage_percentage = 0.0
|
||||
|
||||
# Check cost limit
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0)
|
||||
if cost_limit > 0:
|
||||
cost_usage_pct = (summary.total_cost / cost_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
|
||||
|
||||
# Check call limits for each provider
|
||||
for provider in APIProvider:
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
call_usage_pct = (current_calls / call_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
|
||||
|
||||
# Update status based on highest usage percentage
|
||||
if max_usage_percentage >= 100:
|
||||
summary.usage_status = UsageStatus.LIMIT_REACHED
|
||||
elif max_usage_percentage >= 80:
|
||||
summary.usage_status = UsageStatus.WARNING
|
||||
else:
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
|
||||
# Get current usage
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
await self._create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
requested_billing_period = billing_period
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
|
||||
billing_period = period_keys["billing_period"]
|
||||
|
||||
logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}")
|
||||
|
||||
# Get usage summary
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if summary:
|
||||
logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}")
|
||||
else:
|
||||
logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}")
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Get recent alerts
|
||||
alerts = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(10).all()
|
||||
|
||||
if not summary:
|
||||
# If no summary exists for current period, we should initialize it
|
||||
# This handles the "start of month" case where a user logs in but hasn't made calls yet
|
||||
if not requested_billing_period:
|
||||
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
)
|
||||
try:
|
||||
self.db.add(summary)
|
||||
self.db.commit()
|
||||
self.db.refresh(summary)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize summary: {e}")
|
||||
self.db.rollback()
|
||||
# Fallback to zero-struct return if DB write fails
|
||||
pass
|
||||
|
||||
if not summary: # Still no summary after attempt
|
||||
return build_empty_usage_response(
|
||||
billing_period=billing_period,
|
||||
limits=limits,
|
||||
providers=APIProvider,
|
||||
)
|
||||
|
||||
# Provider breakdown - calculate costs first, then use for percentages
|
||||
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
|
||||
provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
|
||||
db=self.db,
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
summary_total_cost = summary.total_cost or 0.0
|
||||
calculated_total_cost, final_total_cost = calculate_final_total_cost(
|
||||
summary_total_cost=summary_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
maybe_persist_reconciled_costs(
|
||||
db=self.db,
|
||||
summary=summary,
|
||||
summary_total_cost=summary_total_cost,
|
||||
calculated_total_cost=calculated_total_cost,
|
||||
final_total_cost=final_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
# Calculate usage percentages - only for Gemini and HuggingFace
|
||||
# Use the calculated costs for accurate percentages
|
||||
usage_percentages = build_default_usage_percentages(APIProvider)
|
||||
if limits:
|
||||
# Gemini
|
||||
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
|
||||
if gemini_call_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100
|
||||
|
||||
# HuggingFace (stored as mistral in database)
|
||||
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
|
||||
if mistral_call_limit > 0:
|
||||
usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100
|
||||
|
||||
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': final_total_cost,
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'alerts': [
|
||||
{
|
||||
'id': alert.id,
|
||||
'type': alert.alert_type,
|
||||
'title': alert.title,
|
||||
'message': alert.message,
|
||||
'severity': alert.severity,
|
||||
'created_at': alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
],
|
||||
'usage_percentages': usage_percentages,
|
||||
'last_updated': summary.updated_at.isoformat()
|
||||
}
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(self.db, user_id, periods)
|
||||
self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._enforce_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
result = self.pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
self._enforce_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id)
|
||||
billing_period = period_keys["billing_period"]
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
try:
|
||||
reset_usage_summary_counters(summary)
|
||||
self.db.commit()
|
||||
|
||||
# Invalidate dashboard cache so header stats update after reset
|
||||
try:
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
|
||||
return {"reset": True, "counters_reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
return {"reset": False, "error": str(e)}
|
||||
|
||||
@@ -2,9 +2,8 @@ import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import requests
|
||||
import httpx
|
||||
from loguru import logger
|
||||
import time
|
||||
import random
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
@@ -61,30 +60,26 @@ class WritingAssistantService:
|
||||
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
|
||||
return True
|
||||
|
||||
async def suggest(self, text: str, max_results: int = 1) -> List[WritingSuggestion]:
|
||||
async def suggest(self, text: str, user_id: str | None = None) -> List[WritingSuggestion]:
|
||||
if not text or len(text.strip()) < 6:
|
||||
return []
|
||||
|
||||
# COST OPTIMIZATION: Use cached/static suggestions for common patterns
|
||||
# This reduces API calls by 90%+ while maintaining usefulness
|
||||
cached_suggestion = self._get_cached_suggestion(text)
|
||||
if cached_suggestion:
|
||||
return [cached_suggestion]
|
||||
|
||||
# COST CONTROL: Check daily usage limits
|
||||
if not self._check_daily_limit():
|
||||
logger.warning("Daily API limit reached for writing assistant")
|
||||
return []
|
||||
|
||||
# Only make expensive API calls for unique, substantial content
|
||||
if len(text.strip()) < 50: # Skip API calls for very short text
|
||||
if len(text.strip()) < 50:
|
||||
return []
|
||||
|
||||
# 1) Find relevant sources via Exa (reduced results for cost)
|
||||
# 1) Find relevant sources via Exa
|
||||
sources = await self._search_sources(text)
|
||||
|
||||
# 2) Generate continuation suggestion via Gemini
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources)
|
||||
# 2) Generate continuation suggestion via LLM grounded in sources
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
|
||||
|
||||
if not suggestion_text:
|
||||
return []
|
||||
@@ -110,12 +105,12 @@ class WritingAssistantService:
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=self.http_timeout_seconds,
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=self.http_timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
|
||||
data = resp.json()
|
||||
@@ -140,8 +135,7 @@ class WritingAssistantService:
|
||||
logger.error(f"WritingAssistant _search_sources error: {e}")
|
||||
raise
|
||||
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
|
||||
# Build compact sources context block
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]], user_id: str | None = None) -> tuple[str, float]:
|
||||
source_blocks: List[str] = []
|
||||
for i, s in enumerate(sources[:5]):
|
||||
excerpt = (s.get("text", "") or "")
|
||||
@@ -149,16 +143,14 @@ class WritingAssistantService:
|
||||
source_blocks.append(
|
||||
f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
|
||||
)
|
||||
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
|
||||
sources_text = "\n\n".join(source_blocks)
|
||||
|
||||
# Provider-agnostic behavior: short continuation with one inline citation hint
|
||||
system_prompt = (
|
||||
"You are an assistive writing continuation bot. "
|
||||
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
|
||||
"Match tone and topic. Prefer concrete, current facts from the provided sources. "
|
||||
"Include exactly one brief citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"User text to continue (do not repeat):\n{text}\n\n"
|
||||
f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
|
||||
@@ -166,13 +158,13 @@ class WritingAssistantService:
|
||||
)
|
||||
|
||||
try:
|
||||
# Inter-call jitter to reduce burst rate limits
|
||||
time.sleep(random.uniform(0.05, 0.15))
|
||||
await asyncio.sleep(random.uniform(0.05, 0.15))
|
||||
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=user_prompt,
|
||||
json_struct=None,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
if isinstance(ai_resp, dict) and ai_resp.get("text"):
|
||||
suggestion = (ai_resp.get("text", "") or "").strip()
|
||||
@@ -180,12 +172,10 @@ class WritingAssistantService:
|
||||
suggestion = (str(ai_resp or "")).strip()
|
||||
if not suggestion:
|
||||
raise Exception("Assistive writer returned empty suggestion")
|
||||
# naive confidence from number of sources present
|
||||
confidence = 0.7 if sources else 0.5
|
||||
confidence = 0.7
|
||||
return suggestion, confidence
|
||||
except Exception as e:
|
||||
logger.error(f"WritingAssistant _generate_continuation error: {e}")
|
||||
# Propagate to ensure frontend does not show stale/generic content
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ def should_bootstrap_linguistic_models() -> bool:
|
||||
}
|
||||
|
||||
# Check if any linguistic-required feature is enabled
|
||||
linguistic_features = {"content_planning", "facebook", "linkedin", "blog-writer", "persona"}
|
||||
linguistic_features = {"content_planning", "facebook", "linkedin", "blog_writer", "persona"}
|
||||
return bool(enabled_features & linguistic_features)
|
||||
|
||||
|
||||
@@ -287,12 +287,16 @@ from alwrity_utils import (
|
||||
def start_backend(enable_reload=False, production_mode=False):
|
||||
"""Start the backend server."""
|
||||
print("==> Starting ALwrity Backend...")
|
||||
podcast_only_demo_mode = os.getenv("ALWRITY_PODCAST_ONLY_DEMO_MODE", os.getenv("PODCAST_ONLY_DEMO_MODE", "false")).lower() in {"1", "true", "yes", "on"}
|
||||
# Check for legacy podcast-only demo mode env vars (backward compat)
|
||||
is_legacy_podcast_mode = os.getenv("ALWRITY_PODCAST_ONLY_DEMO_MODE", os.getenv("PODCAST_ONLY_DEMO_MODE", "false")).lower() in {"1", "true", "yes", "on"}
|
||||
enabled = get_enabled_features()
|
||||
is_feature_limited = "all" not in enabled
|
||||
|
||||
if podcast_only_demo_mode:
|
||||
print("\n" + "=" * 60)
|
||||
print("==> PODCAST-ONLY DEMO MODE ACTIVE")
|
||||
print(" Non-podcast router groups are intentionally skipped.")
|
||||
if is_legacy_podcast_mode or is_feature_limited:
|
||||
mode_label = "legacy podcast-only" if is_legacy_podcast_mode else f"feature-limited ({', '.join(sorted(enabled))})"
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"==> {mode_label.upper()} MODE ACTIVE")
|
||||
print(" Non-matching router groups are intentionally skipped.")
|
||||
print("=" * 60)
|
||||
|
||||
# Set host based on environment and mode
|
||||
@@ -385,12 +389,12 @@ def start_backend(enable_reload=False, production_mode=False):
|
||||
print(f"[DEBUG] Starting uvicorn with host={host} port={port}", flush=True)
|
||||
print("[DEBUG] >>> ABOUT TO CALL UVICORN.RUN() <<<", flush=True)
|
||||
|
||||
# Skip video preflight in podcast-only mode to save memory/time
|
||||
is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
print(f"[DEBUG] Podcast mode check: {is_podcast}", flush=True)
|
||||
# Skip video preflight in feature-limited mode to save memory/time
|
||||
is_feature_limited = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all")
|
||||
print(f"[DEBUG] Feature-limited mode check: {is_feature_limited}", flush=True)
|
||||
|
||||
if is_podcast:
|
||||
print("[DEBUG] Podcast mode - skipping video preflight", flush=True)
|
||||
if is_feature_limited:
|
||||
print("[DEBUG] Feature-limited mode - skipping video preflight", flush=True)
|
||||
else:
|
||||
# Log diagnostics and assert versions (fail fast if misconfigured)
|
||||
try:
|
||||
|
||||
83
backend/temp_method.py
Normal file
83
backend/temp_method.py
Normal file
@@ -0,0 +1,83 @@
|
||||
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
\ \\Get ALL historical usage data aggregated across all billing periods.\\\
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
\billing_period\: \all\,
|
||||
\usage_status\: \active\,
|
||||
\total_calls\: 0,
|
||||
\total_tokens\: 0,
|
||||
\total_cost\: 0.0,
|
||||
\avg_response_time\: 0.0,
|
||||
\error_rate\: 0.0,
|
||||
\limits\: self.pricing_service.get_user_limits(user_id),
|
||||
\provider_breakdown\: {},
|
||||
\usage_percentages\: {},
|
||||
\historical_breakdown\: [],
|
||||
\last_updated\: datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
\billing_period\: s.billing_period,
|
||||
\total_calls\: s.total_calls or 0,
|
||||
\total_tokens\: s.total_tokens or 0,
|
||||
\total_cost\: float(s.total_cost or 0),
|
||||
\usage_status\: status_val,
|
||||
\updated_at\: s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = \active\
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except:
|
||||
status = str(s.usage_status)
|
||||
if status == \limit_reached\:
|
||||
usage_status = \limit_reached\
|
||||
break
|
||||
elif status == \warning\ and usage_status != \limit_reached\:
|
||||
usage_status = \warning\
|
||||
|
||||
return {
|
||||
\billing_period\: \all\,
|
||||
\usage_status\: usage_status,
|
||||
\total_calls\: total_calls,
|
||||
\total_tokens\: total_tokens,
|
||||
\total_cost\: round(total_cost, 2),
|
||||
\avg_response_time\: round(avg_response_time, 2),
|
||||
\error_rate\: round(error_rate, 2),
|
||||
\limits\: limits,
|
||||
\provider_breakdown\: {},
|
||||
\usage_percentages\: {},
|
||||
\historical_breakdown\: historical_breakdown,
|
||||
\last_updated\: datetime.now().isoformat()
|
||||
}
|
||||
Reference in New Issue
Block a user