story writer backend migration complete, Blog writer SEO and story writer backend migration complete, Blog writer SEO and story writer frontend migration complete
This commit is contained in:
@@ -671,4 +671,122 @@ async def rewrite_status(task_id: str):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get rewrite status for {task_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/titles/generate-seo")
|
||||
async def generate_seo_titles(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 5 SEO-optimized blog titles using research and outline data."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from services.blog_writer.outline.seo_title_generator import SEOTitleGenerator
|
||||
from models.blog_models import BlogResearchResponse, BlogOutlineSection
|
||||
|
||||
# Parse request data
|
||||
research_data = request.get('research')
|
||||
outline_data = request.get('outline', [])
|
||||
primary_keywords = request.get('primary_keywords', [])
|
||||
secondary_keywords = request.get('secondary_keywords', [])
|
||||
content_angles = request.get('content_angles', [])
|
||||
search_intent = request.get('search_intent', 'informational')
|
||||
word_count = request.get('word_count', 1500)
|
||||
|
||||
if not research_data:
|
||||
raise HTTPException(status_code=400, detail="Research data is required")
|
||||
|
||||
# Convert to models
|
||||
research = BlogResearchResponse(**research_data)
|
||||
outline = [BlogOutlineSection(**section) for section in outline_data]
|
||||
|
||||
# Generate titles
|
||||
title_generator = SEOTitleGenerator()
|
||||
titles = await title_generator.generate_seo_titles(
|
||||
research=research,
|
||||
outline=outline,
|
||||
primary_keywords=primary_keywords,
|
||||
secondary_keywords=secondary_keywords,
|
||||
content_angles=content_angles,
|
||||
search_intent=search_intent,
|
||||
word_count=word_count,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"titles": titles
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate SEO titles: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/introductions/generate")
|
||||
async def generate_introductions(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 3 varied blog introductions using research, outline, and content."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from services.blog_writer.content.introduction_generator import IntroductionGenerator
|
||||
from models.blog_models import BlogResearchResponse, BlogOutlineSection
|
||||
|
||||
# Parse request data
|
||||
blog_title = request.get('blog_title', '')
|
||||
research_data = request.get('research')
|
||||
outline_data = request.get('outline', [])
|
||||
sections_content = request.get('sections_content', {})
|
||||
primary_keywords = request.get('primary_keywords', [])
|
||||
search_intent = request.get('search_intent', 'informational')
|
||||
|
||||
if not research_data:
|
||||
raise HTTPException(status_code=400, detail="Research data is required")
|
||||
if not blog_title:
|
||||
raise HTTPException(status_code=400, detail="Blog title is required")
|
||||
|
||||
# Convert to models
|
||||
research = BlogResearchResponse(**research_data)
|
||||
outline = [BlogOutlineSection(**section) for section in outline_data]
|
||||
|
||||
# Generate introductions
|
||||
intro_generator = IntroductionGenerator()
|
||||
introductions = await intro_generator.generate_introductions(
|
||||
blog_title=blog_title,
|
||||
research=research,
|
||||
outline=outline,
|
||||
sections_content=sections_content,
|
||||
primary_keywords=primary_keywords,
|
||||
search_intent=search_intent,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"introductions": introductions
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate introductions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1151,6 +1151,118 @@ async def retry_website_analysis(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retry website analysis: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tasks-needing-intervention/{user_id}")
|
||||
async def get_tasks_needing_intervention(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all tasks that need human intervention.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of tasks needing intervention with failure pattern details
|
||||
"""
|
||||
try:
|
||||
# Verify user access
|
||||
if str(current_user.get('id')) != user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
from services.scheduler.core.failure_detection_service import FailureDetectionService
|
||||
detection_service = FailureDetectionService(db)
|
||||
|
||||
tasks = detection_service.get_tasks_needing_intervention(user_id=user_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tasks": tasks,
|
||||
"count": len(tasks)
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tasks needing intervention: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get tasks needing intervention: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tasks/{task_type}/{task_id}/manual-trigger")
|
||||
async def manual_trigger_task(
|
||||
task_type: str,
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Manually trigger a task that is in cool-off or needs intervention.
|
||||
This bypasses the cool-off check and executes the task immediately.
|
||||
|
||||
Args:
|
||||
task_type: Task type (oauth_token_monitoring, website_analysis, gsc_insights, bing_insights)
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Success status and execution result
|
||||
"""
|
||||
try:
|
||||
from services.scheduler.core.task_execution_handler import execute_task_async
|
||||
scheduler = get_scheduler()
|
||||
|
||||
# Load task based on type
|
||||
task = None
|
||||
if task_type == "oauth_token_monitoring":
|
||||
task = db.query(OAuthTokenMonitoringTask).filter(
|
||||
OAuthTokenMonitoringTask.id == task_id
|
||||
).first()
|
||||
elif task_type == "website_analysis":
|
||||
task = db.query(WebsiteAnalysisTask).filter(
|
||||
WebsiteAnalysisTask.id == task_id
|
||||
).first()
|
||||
elif task_type in ["gsc_insights", "bing_insights"]:
|
||||
task = db.query(PlatformInsightsTask).filter(
|
||||
PlatformInsightsTask.id == task_id
|
||||
).first()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown task type: {task_type}")
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Verify user access
|
||||
if str(current_user.get('id')) != task.user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Clear cool-off status and reset failure count
|
||||
task.status = "active"
|
||||
task.consecutive_failures = 0
|
||||
task.failure_pattern = None
|
||||
|
||||
# Execute task manually (bypasses cool-off check)
|
||||
# Task types are registered as: oauth_token_monitoring, website_analysis, gsc_insights, bing_insights
|
||||
await execute_task_async(scheduler, task_type, task, execution_source="manual")
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Manually triggered task {task_id} ({task_type}) for user {task.user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Task triggered successfully",
|
||||
"task": {
|
||||
"id": task.id,
|
||||
"status": task.status,
|
||||
"last_check": task.last_check.isoformat() if task.last_check else None
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error manually triggering task {task_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to trigger task: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/platform-insights/logs/{user_id}")
|
||||
async def get_platform_insights_logs(
|
||||
user_id: str,
|
||||
|
||||
9
backend/api/story_writer/__init__.py
Normal file
9
backend/api/story_writer/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Story Writer API
|
||||
|
||||
API endpoints for story generation functionality.
|
||||
"""
|
||||
|
||||
from .router import router
|
||||
|
||||
__all__ = ['router']
|
||||
70
backend/api/story_writer/cache_manager.py
Normal file
70
backend/api/story_writer/cache_manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Cache Management System for Story Writer API
|
||||
|
||||
Handles story generation cache operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Manages cache operations for story generation data."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the cache manager."""
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
logger.info("[StoryWriter] CacheManager initialized")
|
||||
|
||||
def get_cache_key(self, request_data: Dict[str, Any]) -> str:
|
||||
"""Generate a cache key from request data."""
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
# Create a normalized version of the request for caching
|
||||
cache_data = {
|
||||
"persona": request_data.get("persona", ""),
|
||||
"story_setting": request_data.get("story_setting", ""),
|
||||
"character_input": request_data.get("character_input", ""),
|
||||
"plot_elements": request_data.get("plot_elements", ""),
|
||||
"writing_style": request_data.get("writing_style", ""),
|
||||
"story_tone": request_data.get("story_tone", ""),
|
||||
"narrative_pov": request_data.get("narrative_pov", ""),
|
||||
"audience_age_group": request_data.get("audience_age_group", ""),
|
||||
"content_rating": request_data.get("content_rating", ""),
|
||||
"ending_preference": request_data.get("ending_preference", ""),
|
||||
}
|
||||
|
||||
cache_str = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.md5(cache_str.encode()).hexdigest()
|
||||
|
||||
def get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a cached result if available."""
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"[StoryWriter] Cache hit for key: {cache_key}")
|
||||
return self.cache[cache_key]
|
||||
logger.debug(f"[StoryWriter] Cache miss for key: {cache_key}")
|
||||
return None
|
||||
|
||||
def cache_result(self, cache_key: str, result: Dict[str, Any]):
|
||||
"""Cache a result."""
|
||||
self.cache[cache_key] = result
|
||||
logger.debug(f"[StoryWriter] Cached result for key: {cache_key}")
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear all cached results."""
|
||||
count = len(self.cache)
|
||||
self.cache.clear()
|
||||
logger.info(f"[StoryWriter] Cleared {count} cached entries")
|
||||
return {"status": "success", "message": f"Cleared {count} cached entries"}
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"total_entries": len(self.cache),
|
||||
"cache_keys": list(self.cache.keys())
|
||||
}
|
||||
|
||||
|
||||
# Global cache manager instance
|
||||
cache_manager = CacheManager()
|
||||
1181
backend/api/story_writer/router.py
Normal file
1181
backend/api/story_writer/router.py
Normal file
File diff suppressed because it is too large
Load Diff
251
backend/api/story_writer/task_manager.py
Normal file
251
backend/api/story_writer/task_manager.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Task Management System for Story Writer API
|
||||
|
||||
Handles background task execution, status tracking, and progress updates
|
||||
for story generation operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""Manages background tasks for story generation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the task manager."""
|
||||
self.task_storage: Dict[str, Dict[str, Any]] = {}
|
||||
logger.info("[StoryWriter] TaskManager initialized")
|
||||
|
||||
def cleanup_old_tasks(self):
|
||||
"""Remove tasks older than 1 hour to prevent memory leaks."""
|
||||
current_time = datetime.now()
|
||||
tasks_to_remove = []
|
||||
|
||||
for task_id, task_data in self.task_storage.items():
|
||||
created_at = task_data.get("created_at")
|
||||
if created_at and (current_time - created_at).total_seconds() > 3600: # 1 hour
|
||||
tasks_to_remove.append(task_id)
|
||||
|
||||
for task_id in tasks_to_remove:
|
||||
del self.task_storage[task_id]
|
||||
logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}")
|
||||
|
||||
def create_task(self, task_type: str = "story_generation") -> str:
|
||||
"""Create a new task and return its ID."""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
self.task_storage[task_id] = {
|
||||
"status": "pending",
|
||||
"created_at": datetime.now(),
|
||||
"result": None,
|
||||
"error": None,
|
||||
"progress_messages": [],
|
||||
"task_type": task_type,
|
||||
"progress": 0.0
|
||||
}
|
||||
|
||||
logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})")
|
||||
return task_id
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the status of a task."""
|
||||
self.cleanup_old_tasks()
|
||||
|
||||
if task_id not in self.task_storage:
|
||||
logger.warning(f"[StoryWriter] Task not found: {task_id}")
|
||||
return None
|
||||
|
||||
task = self.task_storage[task_id]
|
||||
response = {
|
||||
"task_id": task_id,
|
||||
"status": task["status"],
|
||||
"progress": task.get("progress", 0.0),
|
||||
"message": task.get("progress_messages", [])[-1] if task.get("progress_messages") else None,
|
||||
"created_at": task["created_at"].isoformat() if task.get("created_at") else None,
|
||||
"updated_at": task.get("updated_at", task.get("created_at")).isoformat() if task.get("updated_at") or task.get("created_at") else None,
|
||||
}
|
||||
|
||||
if task["status"] == "completed" and task.get("result"):
|
||||
response["result"] = task["result"]
|
||||
|
||||
if task["status"] == "failed" and task.get("error"):
|
||||
response["error"] = task["error"]
|
||||
|
||||
return response
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: Optional[float] = None,
|
||||
message: Optional[str] = None,
|
||||
result: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None
|
||||
):
|
||||
"""Update the status of a task."""
|
||||
if task_id not in self.task_storage:
|
||||
logger.warning(f"[StoryWriter] Cannot update non-existent task: {task_id}")
|
||||
return
|
||||
|
||||
task = self.task_storage[task_id]
|
||||
task["status"] = status
|
||||
task["updated_at"] = datetime.now()
|
||||
|
||||
if progress is not None:
|
||||
task["progress"] = progress
|
||||
|
||||
if message:
|
||||
if "progress_messages" not in task:
|
||||
task["progress_messages"] = []
|
||||
task["progress_messages"].append(message)
|
||||
logger.info(f"[StoryWriter] Task {task_id}: {message} (progress: {progress}%)")
|
||||
|
||||
if result is not None:
|
||||
task["result"] = result
|
||||
|
||||
if error is not None:
|
||||
task["error"] = error
|
||||
logger.error(f"[StoryWriter] Task {task_id} error: {error}")
|
||||
|
||||
async def execute_story_generation_task(
|
||||
self,
|
||||
task_id: str,
|
||||
request_data: Dict[str, Any],
|
||||
user_id: str
|
||||
):
|
||||
"""Execute story generation task asynchronously."""
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
service = StoryWriterService()
|
||||
|
||||
try:
|
||||
self.update_task_status(task_id, "processing", progress=0.0, message="Starting story generation...")
|
||||
|
||||
# Step 1: Generate premise
|
||||
self.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
|
||||
premise = service.generate_premise(
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 2: Generate outline
|
||||
self.update_task_status(task_id, "processing", progress=30.0, message="Generating story outline...")
|
||||
outline = service.generate_outline(
|
||||
premise=premise,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 3: Generate story start
|
||||
self.update_task_status(task_id, "processing", progress=50.0, message="Writing story beginning...")
|
||||
story_start = service.generate_story_start(
|
||||
premise=premise,
|
||||
outline=outline,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 4: Continue story
|
||||
self.update_task_status(task_id, "processing", progress=70.0, message="Continuing story generation...")
|
||||
story_text = story_start
|
||||
max_iterations = request_data.get("max_iterations", 10)
|
||||
iteration = 0
|
||||
|
||||
while 'IAMDONE' not in story_text and iteration < max_iterations:
|
||||
iteration += 1
|
||||
progress = 70.0 + (iteration / max_iterations) * 25.0
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=min(progress, 95.0),
|
||||
message=f"Writing continuation {iteration}/{max_iterations}..."
|
||||
)
|
||||
|
||||
continuation = service.continue_story(
|
||||
premise=premise,
|
||||
outline=outline,
|
||||
story_text=story_text,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if continuation:
|
||||
story_text += '\n\n' + continuation
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Empty continuation at iteration {iteration}")
|
||||
break
|
||||
|
||||
# Clean up and finalize
|
||||
final_story = story_text.replace('IAMDONE', '').strip()
|
||||
|
||||
result = {
|
||||
"premise": premise,
|
||||
"outline": outline,
|
||||
"story": final_story,
|
||||
"is_complete": 'IAMDONE' in story_text or iteration >= max_iterations,
|
||||
"iterations": iteration
|
||||
}
|
||||
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Story generation completed!",
|
||||
result=result
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Task {task_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"[StoryWriter] Task {task_id} failed: {error_msg}")
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"Story generation failed: {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
task_manager = TaskManager()
|
||||
@@ -5,6 +5,7 @@ Provides endpoints for subscription management and usage monitoring.
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, func
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
@@ -12,12 +13,14 @@ from functools import lru_cache
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.log_wrapping_service import LogWrappingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
import sqlite3
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import (
|
||||
APIProvider, SubscriptionPlan, UserSubscription, UsageSummary,
|
||||
APIProviderPricing, UsageAlert, SubscriptionTier, BillingCycle, UsageStatus
|
||||
APIProviderPricing, UsageAlert, SubscriptionTier, BillingCycle, UsageStatus,
|
||||
APIUsageLog, SubscriptionRenewalHistory
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/subscription", tags=["subscription"])
|
||||
@@ -525,8 +528,67 @@ async def subscribe_to_plan(
|
||||
).first()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
|
||||
# Track renewal history - capture BEFORE updating subscription
|
||||
previous_period_start = None
|
||||
previous_period_end = None
|
||||
previous_plan_name = None
|
||||
previous_plan_tier = None
|
||||
renewal_type = "new"
|
||||
renewal_count = 0
|
||||
|
||||
# Get usage snapshot BEFORE renewal (capture current state)
|
||||
usage_before_snapshot = None
|
||||
current_period = datetime.utcnow().strftime("%Y-%m")
|
||||
usage_before = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if usage_before:
|
||||
usage_before_snapshot = {
|
||||
"total_calls": usage_before.total_calls or 0,
|
||||
"total_tokens": usage_before.total_tokens or 0,
|
||||
"total_cost": float(usage_before.total_cost) if usage_before.total_cost else 0.0,
|
||||
"gemini_calls": usage_before.gemini_calls or 0,
|
||||
"mistral_calls": usage_before.mistral_calls or 0,
|
||||
"usage_status": usage_before.usage_status.value if hasattr(usage_before.usage_status, 'value') else str(usage_before.usage_status)
|
||||
}
|
||||
|
||||
if existing_subscription:
|
||||
# This is a renewal/update - capture previous subscription state BEFORE updating
|
||||
previous_period_start = existing_subscription.current_period_start
|
||||
previous_period_end = existing_subscription.current_period_end
|
||||
previous_plan = existing_subscription.plan
|
||||
previous_plan_name = previous_plan.name if previous_plan else None
|
||||
previous_plan_tier = previous_plan.tier.value if previous_plan else None
|
||||
|
||||
# Determine renewal type
|
||||
if previous_plan and previous_plan.id == plan_id:
|
||||
# Same plan - this is a renewal
|
||||
renewal_type = "renewal"
|
||||
elif previous_plan:
|
||||
# Different plan - check if upgrade or downgrade
|
||||
tier_order = {"free": 0, "basic": 1, "pro": 2, "enterprise": 3}
|
||||
previous_tier_order = tier_order.get(previous_plan_tier or "free", 0)
|
||||
new_tier_order = tier_order.get(plan.tier.value, 0)
|
||||
if new_tier_order > previous_tier_order:
|
||||
renewal_type = "upgrade"
|
||||
elif new_tier_order < previous_tier_order:
|
||||
renewal_type = "downgrade"
|
||||
else:
|
||||
renewal_type = "renewal" # Same tier, different plan name
|
||||
|
||||
# Get renewal count (how many times this user has renewed)
|
||||
last_renewal = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).order_by(SubscriptionRenewalHistory.created_at.desc()).first()
|
||||
|
||||
if last_renewal:
|
||||
renewal_count = last_renewal.renewal_count + 1
|
||||
else:
|
||||
renewal_count = 1 # First renewal
|
||||
|
||||
# Update existing subscription
|
||||
existing_subscription.plan_id = plan_id
|
||||
existing_subscription.billing_cycle = BillingCycle(billing_cycle)
|
||||
@@ -552,7 +614,30 @@ async def subscribe_to_plan(
|
||||
auto_renew=True
|
||||
)
|
||||
db.add(subscription)
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
# Create renewal history record AFTER subscription update (so we have the new period_end)
|
||||
renewal_history = SubscriptionRenewalHistory(
|
||||
user_id=user_id,
|
||||
plan_id=plan_id,
|
||||
plan_name=plan.name,
|
||||
plan_tier=plan.tier.value,
|
||||
previous_period_start=previous_period_start,
|
||||
previous_period_end=previous_period_end,
|
||||
new_period_start=now,
|
||||
new_period_end=subscription.current_period_end,
|
||||
billing_cycle=BillingCycle(billing_cycle),
|
||||
renewal_type=renewal_type,
|
||||
renewal_count=renewal_count,
|
||||
previous_plan_name=previous_plan_name,
|
||||
previous_plan_tier=previous_plan_tier,
|
||||
usage_before_renewal=usage_before_snapshot, # Usage snapshot captured BEFORE renewal
|
||||
payment_amount=plan.price_yearly if billing_cycle == 'yearly' else plan.price_monthly,
|
||||
payment_status="paid", # Assume paid for now (can be updated if payment processing is added)
|
||||
payment_date=now
|
||||
)
|
||||
db.add(renewal_history)
|
||||
db.commit()
|
||||
|
||||
# Get current usage BEFORE reset for logging
|
||||
@@ -883,4 +968,222 @@ async def get_dashboard_data(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dashboard data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/renewal-history/{user_id}")
|
||||
async def get_renewal_history(
|
||||
user_id: str,
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get subscription renewal history for a user.
|
||||
|
||||
Returns:
|
||||
- List of renewal history records
|
||||
- Total count for pagination
|
||||
"""
|
||||
try:
|
||||
# Verify user can only access their own data
|
||||
if current_user.get('id') != user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).count()
|
||||
|
||||
# Get paginated results, ordered by created_at descending (most recent first)
|
||||
renewals = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).order_by(SubscriptionRenewalHistory.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
# Format renewal history for response
|
||||
renewal_history = []
|
||||
for renewal in renewals:
|
||||
renewal_history.append({
|
||||
'id': renewal.id,
|
||||
'plan_name': renewal.plan_name,
|
||||
'plan_tier': renewal.plan_tier,
|
||||
'previous_period_start': renewal.previous_period_start.isoformat() if renewal.previous_period_start else None,
|
||||
'previous_period_end': renewal.previous_period_end.isoformat() if renewal.previous_period_end else None,
|
||||
'new_period_start': renewal.new_period_start.isoformat() if renewal.new_period_start else None,
|
||||
'new_period_end': renewal.new_period_end.isoformat() if renewal.new_period_end else None,
|
||||
'billing_cycle': renewal.billing_cycle.value if renewal.billing_cycle else None,
|
||||
'renewal_type': renewal.renewal_type,
|
||||
'renewal_count': renewal.renewal_count,
|
||||
'previous_plan_name': renewal.previous_plan_name,
|
||||
'previous_plan_tier': renewal.previous_plan_tier,
|
||||
'usage_before_renewal': renewal.usage_before_renewal,
|
||||
'payment_amount': float(renewal.payment_amount) if renewal.payment_amount else 0.0,
|
||||
'payment_status': renewal.payment_status,
|
||||
'payment_date': renewal.payment_date.isoformat() if renewal.payment_date else None,
|
||||
'created_at': renewal.created_at.isoformat() if renewal.created_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"renewals": renewal_history,
|
||||
"total_count": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting renewal history: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/usage-logs")
|
||||
async def get_usage_logs(
|
||||
limit: int = Query(50, ge=1, le=5000, description="Number of logs to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
provider: Optional[str] = Query(None, description="Filter by provider"),
|
||||
status_code: Optional[int] = Query(None, description="Filter by HTTP status code"),
|
||||
billing_period: Optional[str] = Query(None, description="Filter by billing period (YYYY-MM)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get API usage logs for the current user.
|
||||
|
||||
Query Params:
|
||||
- limit: Number of logs to return (1-500, default: 50)
|
||||
- offset: Pagination offset (default: 0)
|
||||
- provider: Filter by provider (e.g., "gemini", "openai", "huggingface")
|
||||
- status_code: Filter by HTTP status code (e.g., 200 for success, 400+ for errors)
|
||||
- billing_period: Filter by billing period (YYYY-MM format)
|
||||
|
||||
Returns:
|
||||
- List of usage logs with API call details
|
||||
- Total count for pagination
|
||||
"""
|
||||
try:
|
||||
# Get user_id from current_user
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
# Build query
|
||||
query = db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if provider:
|
||||
provider_lower = provider.lower()
|
||||
# Handle special case: huggingface maps to MISTRAL enum in database
|
||||
if provider_lower == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_lower)
|
||||
except ValueError:
|
||||
# Invalid provider, return empty results
|
||||
return {
|
||||
"logs": [],
|
||||
"total_count": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False
|
||||
}
|
||||
query = query.filter(APIUsageLog.provider == provider_enum)
|
||||
|
||||
if status_code is not None:
|
||||
query = query.filter(APIUsageLog.status_code == status_code)
|
||||
|
||||
if billing_period:
|
||||
query = query.filter(APIUsageLog.billing_period == billing_period)
|
||||
|
||||
# Check and wrap logs if necessary (before getting count)
|
||||
wrapping_service = LogWrappingService(db)
|
||||
wrap_result = wrapping_service.check_and_wrap_logs(user_id)
|
||||
if wrap_result.get('wrapped'):
|
||||
logger.info(f"[UsageLogs] Log wrapping completed for user {user_id}: {wrap_result.get('message')}")
|
||||
# Rebuild query after wrapping (in case filters changed)
|
||||
query = db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
)
|
||||
# Reapply filters
|
||||
if provider:
|
||||
provider_lower = provider.lower()
|
||||
if provider_lower == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_lower)
|
||||
except ValueError:
|
||||
return {
|
||||
"logs": [],
|
||||
"total_count": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False
|
||||
}
|
||||
query = query.filter(APIUsageLog.provider == provider_enum)
|
||||
if status_code is not None:
|
||||
query = query.filter(APIUsageLog.status_code == status_code)
|
||||
if billing_period:
|
||||
query = query.filter(APIUsageLog.billing_period == billing_period)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results, ordered by timestamp descending (most recent first)
|
||||
logs = query.order_by(desc(APIUsageLog.timestamp)).offset(offset).limit(limit).all()
|
||||
|
||||
# Format logs for response
|
||||
formatted_logs = []
|
||||
for log in logs:
|
||||
# Determine status based on status_code
|
||||
status = 'success' if 200 <= log.status_code < 300 else 'failed'
|
||||
|
||||
# Handle provider display name - ALL MISTRAL enum logs are actually HuggingFace
|
||||
# (HuggingFace always maps to MISTRAL enum in the database)
|
||||
provider_display = log.provider.value if log.provider else None
|
||||
if provider_display == "mistral":
|
||||
# All MISTRAL provider logs are HuggingFace calls
|
||||
provider_display = "huggingface"
|
||||
|
||||
formatted_logs.append({
|
||||
'id': log.id,
|
||||
'timestamp': log.timestamp.isoformat() if log.timestamp else None,
|
||||
'provider': provider_display,
|
||||
'model_used': log.model_used,
|
||||
'endpoint': log.endpoint,
|
||||
'method': log.method,
|
||||
'tokens_input': log.tokens_input or 0,
|
||||
'tokens_output': log.tokens_output or 0,
|
||||
'tokens_total': log.tokens_total or 0,
|
||||
'cost_input': float(log.cost_input) if log.cost_input else 0.0,
|
||||
'cost_output': float(log.cost_output) if log.cost_output else 0.0,
|
||||
'cost_total': float(log.cost_total) if log.cost_total else 0.0,
|
||||
'response_time': float(log.response_time) if log.response_time else 0.0,
|
||||
'status_code': log.status_code,
|
||||
'status': status,
|
||||
'error_message': log.error_message,
|
||||
'billing_period': log.billing_period,
|
||||
'retry_count': log.retry_count or 0,
|
||||
'is_aggregated': log.endpoint == "[AGGREGATED]" # Flag to indicate aggregated log
|
||||
})
|
||||
|
||||
return {
|
||||
"logs": formatted_logs,
|
||||
"total_count": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage logs: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
@@ -498,7 +498,15 @@ async def get_test_authorization_url(state: Optional[str] = None) -> Dict[str, s
|
||||
if not wix_service.client_id:
|
||||
logger.warning("TEST: Wix Client ID not configured, returning mock URL")
|
||||
return {
|
||||
"url": "https://www.wix.com/oauth/access?client_id=YOUR_CLIENT_ID&redirect_uri=http://localhost:3000/wix/callback&response_type=code&scope=BLOG.CREATE-DRAFT,BLOG.PUBLISH,MEDIA.MANAGE&code_challenge=test&code_challenge_method=S256",
|
||||
"url": (
|
||||
"https://www.wix.com/oauth/access?client_id=YOUR_CLIENT_ID"
|
||||
"&redirect_uri=http://localhost:3000/wix/callback"
|
||||
"&response_type=code&scope="
|
||||
"BLOG.CREATE-DRAFT,BLOG.PUBLISH-POST,BLOG.READ-CATEGORY,"
|
||||
"BLOG.CREATE-CATEGORY,BLOG.READ-TAG,BLOG.CREATE-TAG,"
|
||||
"MEDIA.SITE_MEDIA_FILES_IMPORT"
|
||||
"&code_challenge=test&code_challenge_method=S256"
|
||||
),
|
||||
"state": state or "test_state",
|
||||
"message": "WIX_CLIENT_ID not configured. Please set it in your .env file to get a real authorization URL."
|
||||
}
|
||||
@@ -573,9 +581,19 @@ async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
- Derives member_id server-side (required by Wix for third-party apps)
|
||||
"""
|
||||
try:
|
||||
access_token = payload.get("access_token")
|
||||
if not access_token:
|
||||
# Normalize access_token from payload (could be string, dict, or other format)
|
||||
from services.integrations.wix.utils import normalize_token_string
|
||||
raw_access_token = payload.get("access_token")
|
||||
if not raw_access_token:
|
||||
raise HTTPException(status_code=400, detail="Missing access_token")
|
||||
|
||||
# Normalize token to string (handles dict with accessToken.value, int, etc.)
|
||||
access_token = normalize_token_string(raw_access_token)
|
||||
if not access_token:
|
||||
# Fallback: try to convert to string directly
|
||||
access_token = str(raw_access_token).strip()
|
||||
if not access_token or access_token == "None":
|
||||
raise HTTPException(status_code=400, detail="Invalid access_token format")
|
||||
|
||||
# Derive current member id from token (try local decode first, then API fallback)
|
||||
member_id = wix_service.extract_member_id_from_access_token(access_token)
|
||||
|
||||
Reference in New Issue
Block a user