Research component integration, Copilotkit implementation, SEO copilotkit implementation, Wix SEO metadata complete, Wix SEO metadata review

This commit is contained in:
ajaysi
2025-11-03 16:01:44 +05:30
parent de4328175d
commit e69107b07c
94 changed files with 9748 additions and 1565 deletions

View File

@@ -185,10 +185,20 @@ async def get_research_status(task_id: str) -> Dict[str, Any]:
# Outline Endpoints
@router.post("/outline/start")
async def start_outline_generation(request: BlogOutlineRequest) -> Dict[str, Any]:
async def start_outline_generation(
request: BlogOutlineRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Start an outline generation operation and return a task ID for polling."""
try:
task_id = task_manager.start_outline_task(request)
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
task_id = task_manager.start_outline_task(request, user_id)
return {"task_id": task_id, "status": "started"}
except Exception as e:
logger.error(f"Failed to start outline generation: {e}")
@@ -272,12 +282,22 @@ async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
@router.post("/content/start")
async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]:
async def start_content_generation(
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Start full content generation and return a task id for polling.
Accepts a payload compatible with MediumBlogGenerateRequest to minimize duplication.
"""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
# Map dict to MediumBlogGenerateRequest for reuse
from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo
sections = [MediumSectionOutline(**s) for s in request.get("sections", [])]
@@ -293,7 +313,7 @@ async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]:
globalTargetWords=request.get("globalTargetWords", 1000),
researchKeywords=request.get("researchKeywords") or request.get("keywords"),
)
task_id = task_manager.start_content_generation_task(req)
task_id = task_manager.start_content_generation_task(req, user_id)
return {"task_id": task_id, "status": "started"}
except Exception as e:
logger.error(f"Failed to start content generation: {e}")
@@ -307,6 +327,51 @@ async def content_generation_status(task_id: str) -> Dict[str, Any]:
status = await task_manager.get_task_status(task_id)
if status is None:
raise HTTPException(status_code=404, detail="Task not found")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
error_status = status.get('error_status', 429)
if not isinstance(error_data, dict):
logger.warning(f"Content generation task {task_id} error_data not dict: {error_data}")
error_data = {'error': str(error_data)}
# Determine provider and usage info
stored_error_message = status.get('error', error_data.get('error'))
provider = error_data.get('provider', 'unknown')
usage_info = error_data.get('usage_info')
if not usage_info:
usage_info = {
'provider': provider,
'message': stored_error_message,
'error_type': error_data.get('error_type', 'unknown')
}
# Include any known fields from error_data
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
if key in error_data:
usage_info[key] = error_data[key]
# Build error message for detail
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
# Log the subscription error with all context
logger.warning(f"Content generation task {task_id} failed with subscription error {error_status}: {error_msg}")
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
from fastapi.responses import JSONResponse
return JSONResponse(
status_code=error_status,
content={
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
'message': error_msg,
'provider': provider,
'usage_info': usage_info
}
)
return status
except HTTPException:
raise
@@ -499,14 +564,24 @@ async def get_outline_cache_entries(limit: int = 20):
# ---------------------------
@router.post("/generate/medium/start")
async def start_medium_generation(request: MediumBlogGenerateRequest):
async def start_medium_generation(
request: MediumBlogGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Start medium-length blog generation (≤1000 words) and return a task id."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
# Simple server-side guard
if (request.globalTargetWords or 1000) > 1000:
raise HTTPException(status_code=400, detail="Global target words exceed 1000; use per-section generation")
task_id = task_manager.start_medium_generation_task(request)
task_id = task_manager.start_medium_generation_task(request, user_id)
return {"task_id": task_id, "status": "started"}
except HTTPException:
raise
@@ -522,6 +597,51 @@ async def medium_generation_status(task_id: str):
status = await task_manager.get_task_status(task_id)
if status is None:
raise HTTPException(status_code=404, detail="Task not found")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
error_status = status.get('error_status', 429)
if not isinstance(error_data, dict):
logger.warning(f"Medium generation task {task_id} error_data not dict: {error_data}")
error_data = {'error': str(error_data)}
# Determine provider and usage info
stored_error_message = status.get('error', error_data.get('error'))
provider = error_data.get('provider', 'unknown')
usage_info = error_data.get('usage_info')
if not usage_info:
usage_info = {
'provider': provider,
'message': stored_error_message,
'error_type': error_data.get('error_type', 'unknown')
}
# Include any known fields from error_data
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
if key in error_data:
usage_info[key] = error_data[key]
# Build error message for detail
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
# Log the subscription error with all context
logger.warning(f"Medium generation task {task_id} failed with subscription error {error_status}: {error_msg}")
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
from fastapi.responses import JSONResponse
return JSONResponse(
status_code=error_status,
content={
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
'message': error_msg,
'provider': provider,
'usage_info': usage_info
}
)
return status
except HTTPException:
raise

View File

@@ -5,7 +5,7 @@ Provides API endpoint for analyzing blog content SEO with parallel processing
and CopilotKit integration for real-time progress updates.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from typing import Dict, Any, Optional
from loguru import logger
@@ -13,6 +13,7 @@ from datetime import datetime
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
from services.blog_writer.core.blog_writer_service import BlogWriterService
from middleware.auth_middleware import get_current_user
router = APIRouter(prefix="/api/blog-writer/seo", tags=["Blog SEO Analysis"])
@@ -56,7 +57,10 @@ blog_writer_service = BlogWriterService()
@router.post("/analyze", response_model=SEOAnalysisResponse)
async def analyze_blog_seo(request: SEOAnalysisRequest):
async def analyze_blog_seo(
request: SEOAnalysisRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""
Analyze blog content for SEO optimization
@@ -69,6 +73,7 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
Args:
request: SEOAnalysisRequest containing blog content and research data
current_user: Authenticated user from middleware
Returns:
SEOAnalysisResponse with comprehensive analysis results
@@ -76,6 +81,14 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
try:
logger.info(f"Starting SEO analysis for blog content")
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validate request
if not request.blog_content or not request.blog_content.strip():
raise HTTPException(status_code=400, detail="Blog content is required")
@@ -91,7 +104,8 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
analysis_results = await seo_analyzer.analyze_blog_content(
blog_content=request.blog_content,
research_data=request.research_data,
blog_title=request.blog_title
blog_title=request.blog_title,
user_id=user_id
)
# Check for errors
@@ -131,7 +145,10 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
@router.post("/analyze-with-progress")
async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
async def analyze_blog_seo_with_progress(
request: SEOAnalysisRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""
Analyze blog content for SEO with real-time progress updates
@@ -140,6 +157,7 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
Args:
request: SEOAnalysisRequest containing blog content and research data
current_user: Authenticated user from middleware
Returns:
Generator yielding progress updates and final results
@@ -147,6 +165,14 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
try:
logger.info(f"Starting SEO analysis with progress for blog content")
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validate request
if not request.blog_content or not request.blog_content.strip():
raise HTTPException(status_code=400, detail="Blog content is required")
@@ -209,7 +235,9 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
# Perform actual analysis
analysis_results = await seo_analyzer.analyze_blog_content(
blog_content=request.blog_content,
research_data=request.research_data
research_data=request.research_data,
blog_title=request.blog_title,
user_id=user_id
)
# Final result

View File

@@ -88,8 +88,12 @@ class TaskManager:
response["error"] = task["error"]
if "error_status" in task:
response["error_status"] = task["error_status"]
logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_status={task['error_status']} in response")
if "error_data" in task:
response["error_data"] = task["error_data"]
logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_data with keys: {list(task['error_data'].keys()) if isinstance(task['error_data'], dict) else 'not-dict'}")
else:
logger.warning(f"[TaskManager] get_task_status for {task_id}: Task failed but no error_data found. Task keys: {list(task.keys())}")
return response
@@ -127,29 +131,33 @@ class TaskManager:
asyncio.create_task(self._run_research_task(task_id, request, user_id))
return task_id
def start_outline_task(self, request: BlogOutlineRequest) -> str:
def start_outline_task(self, request: BlogOutlineRequest, user_id: str) -> str:
"""Start an outline generation operation and return a task ID."""
task_id = self.create_task("outline")
# Start the outline generation operation in the background
asyncio.create_task(self._run_outline_generation_task(task_id, request))
asyncio.create_task(self._run_outline_generation_task(task_id, request, user_id))
return task_id
def start_medium_generation_task(self, request: MediumBlogGenerateRequest) -> str:
def start_medium_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str:
"""Start a medium (≤1000 words) full-blog generation task."""
task_id = self.create_task("medium_generation")
asyncio.create_task(self._run_medium_generation_task(task_id, request))
asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id))
return task_id
def start_content_generation_task(self, request: MediumBlogGenerateRequest) -> str:
def start_content_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str:
"""Start content generation (full blog via sections) with provider parity.
Internally reuses medium generator pipeline for now but tracked under
distinct task_type 'content_generation' and same polling contract.
Args:
request: Content generation request
user_id: User ID (required for subscription checks and usage tracking)
"""
task_id = self.create_task("content_generation")
asyncio.create_task(self._run_medium_generation_task(task_id, request))
asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id))
return task_id
async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str):
@@ -205,7 +213,7 @@ class TaskManager:
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = "Research completed with unknown status"
async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest):
async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest, user_id: str):
"""Background task to run outline generation and update status with progress messages."""
try:
# Update status to running
@@ -215,21 +223,31 @@ class TaskManager:
# Send initial progress message
await self.update_progress(task_id, "🧩 Starting outline generation...")
# Run the actual outline generation with progress updates
result = await self.service.generate_outline_with_progress(request, task_id)
# Run the actual outline generation with progress updates (pass user_id for subscription checks)
result = await self.service.generate_outline_with_progress(request, task_id, user_id)
# Update status to completed
await self.update_progress(task_id, f"✅ Outline generated successfully! Created {len(result.outline)} sections with {len(result.title_options)} title options.")
self.task_storage[task_id]["status"] = "completed"
self.task_storage[task_id]["result"] = result.dict()
except HTTPException as http_error:
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
error_detail = http_error.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = http_error.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
except Exception as e:
await self.update_progress(task_id, f"❌ Outline generation failed: {str(e)}")
# Update status to failed
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e)
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
"""Background task to generate a medium blog using a single structured JSON call."""
try:
self.task_storage[task_id]["status"] = "running"
@@ -245,6 +263,7 @@ class TaskManager:
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
request,
task_id,
user_id
)
if not result or not getattr(result, "sections", None):
@@ -263,10 +282,38 @@ class TaskManager:
self.task_storage[task_id]["result"] = result.dict()
await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.")
except Exception as e:
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
except HTTPException as http_error:
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
logger.info(f"[TaskManager] Caught HTTPException in medium generation task {task_id}: status={http_error.status_code}, detail={http_error.detail}")
error_detail = http_error.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e)
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = http_error.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
logger.info(f"[TaskManager] Stored error_status={http_error.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}")
except Exception as e:
# Check if this is an HTTPException that got wrapped (can happen in async tasks)
# HTTPException has status_code and detail attributes
logger.info(f"[TaskManager] Caught Exception in medium generation task {task_id}: type={type(e).__name__}, has_status_code={hasattr(e, 'status_code')}, has_detail={hasattr(e, 'detail')}")
if hasattr(e, 'status_code') and hasattr(e, 'detail'):
# This is an HTTPException that was caught as generic Exception
logger.info(f"[TaskManager] Detected HTTPException in Exception handler: status={e.status_code}, detail={e.detail}")
error_detail = e.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = e.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
logger.info(f"[TaskManager] Stored error_status={e.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}")
else:
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e)
# Global task manager instance

View File

@@ -12,6 +12,7 @@ from functools import lru_cache
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from services.subscription.schema_utils import ensure_subscription_plan_columns
from middleware.auth_middleware import get_current_user
from models.subscription_models import (
APIProvider, SubscriptionPlan, UserSubscription, UsageSummary,
@@ -79,6 +80,8 @@ async def get_subscription_plans(
"""Get all available subscription plans."""
try:
# Ensure required columns exist (handles environments without migrations applied yet)
ensure_subscription_plan_columns(db)
plans = db.query(SubscriptionPlan).filter(
SubscriptionPlan.is_active == True
).order_by(SubscriptionPlan.price_monthly).all()
@@ -137,6 +140,7 @@ async def get_user_subscription(
raise HTTPException(status_code=403, detail="Access denied")
try:
ensure_subscription_plan_columns(db)
subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
@@ -234,6 +238,7 @@ async def get_subscription_status(
raise HTTPException(status_code=403, detail="Access denied")
try:
ensure_subscription_plan_columns(db)
subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
@@ -346,6 +351,7 @@ async def subscribe_to_plan(
raise HTTPException(status_code=403, detail="Access denied")
try:
ensure_subscription_plan_columns(db)
plan_id = subscription_data.get('plan_id')
billing_cycle = subscription_data.get('billing_cycle', 'monthly')
@@ -427,11 +433,16 @@ async def subscribe_to_plan(
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
# Clear subscription limits cache to force refresh on next check
# IMPORTANT: Do this BEFORE resetting usage to ensure cache is cleared first
try:
from services.subscription import PricingService
# Clear cache for this specific user (class-level cache shared across all instances)
cleared_count = PricingService.clear_user_cache(user_id)
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
# Also expire all SQLAlchemy objects to force fresh reads
db.expire_all()
logger.info(f" 🔄 Expired all SQLAlchemy objects to force fresh reads")
except Exception as cache_err:
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
@@ -441,12 +452,22 @@ async def subscribe_to_plan(
usage_service = UsageTrackingService(db)
reset_result = await usage_service.reset_current_billing_period(user_id)
# Re-query usage summary from DB after reset to get fresh data
# Force commit to ensure reset is persisted
db.commit()
# Expire all SQLAlchemy objects to force fresh reads
db.expire_all()
# Re-query usage summary from DB after reset to get fresh data (fresh query)
usage_after = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# Refresh the usage object if found to ensure we have latest data
if usage_after:
db.refresh(usage_after)
if reset_result.get('reset'):
logger.info(f" ✅ Usage counters RESET successfully")
if usage_after:
@@ -635,6 +656,7 @@ async def get_dashboard_data(
"""Get comprehensive dashboard data for usage monitoring."""
try:
ensure_subscription_plan_columns(db)
# Serve from short TTL cache to avoid hammering DB on bursts
import time
now = time.time()

View File

@@ -535,15 +535,33 @@ async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
if not member_id:
raise HTTPException(status_code=400, detail="Unable to resolve member_id from token")
# Extract SEO metadata if provided
seo_metadata = payload.get("seo_metadata")
# Extract category/tag IDs or names
# Can be either:
# - IDs: List of UUID strings
# - Names: List of name strings (will be looked up/created)
category_ids = payload.get("category_ids") or payload.get("category_names")
tag_ids = payload.get("tag_ids") or payload.get("tag_names")
# If SEO metadata has categories/tags but they weren't explicitly provided, use them
if seo_metadata:
if not category_ids and seo_metadata.get("blog_categories"):
category_ids = seo_metadata.get("blog_categories")
if not tag_ids and seo_metadata.get("blog_tags"):
tag_ids = seo_metadata.get("blog_tags")
result = wix_service.create_blog_post(
access_token=access_token,
title=payload.get("title") or "Untitled",
content=payload.get("content") or "",
cover_image_url=payload.get("cover_image_url"),
category_ids=payload.get("category_ids") or None,
tag_ids=payload.get("tag_ids") or None,
category_ids=category_ids,
tag_ids=tag_ids,
publish=bool(payload.get("publish", True)),
member_id=member_id,
seo_metadata=seo_metadata,
)
return {

View File

@@ -11,6 +11,7 @@ import asyncio
from datetime import datetime
from services.subscription import monitoring_middleware
# Import modular utilities
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager, OnboardingManager

View File

@@ -0,0 +1,17 @@
-- Add EXA to subscription plans
ALTER TABLE subscription_plans
ADD COLUMN exa_calls_limit INT DEFAULT 0;
-- Add EXA to usage summaries
ALTER TABLE usage_summaries
ADD COLUMN exa_calls INT DEFAULT 0;
ALTER TABLE usage_summaries
ADD COLUMN exa_cost FLOAT DEFAULT 0.0;
-- Update default limits for existing plans
UPDATE subscription_plans SET exa_calls_limit = 100 WHERE tier = 'free';
UPDATE subscription_plans SET exa_calls_limit = 500 WHERE tier = 'basic';
UPDATE subscription_plans SET exa_calls_limit = 2000 WHERE tier = 'pro';
UPDATE subscription_plans SET exa_calls_limit = 0 WHERE tier = 'enterprise';

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from enum import Enum
class PersonaInfo(BaseModel):
@@ -50,6 +51,51 @@ class GroundingMetadata(BaseModel):
web_search_queries: List[str] = []
class ResearchMode(str, Enum):
"""Research modes for different depth levels."""
BASIC = "basic"
COMPREHENSIVE = "comprehensive"
TARGETED = "targeted"
class SourceType(str, Enum):
"""Types of sources to include in research."""
WEB = "web"
ACADEMIC = "academic"
NEWS = "news"
INDUSTRY = "industry"
EXPERT = "expert"
class DateRange(str, Enum):
"""Date range filters for research."""
LAST_WEEK = "last_week"
LAST_MONTH = "last_month"
LAST_3_MONTHS = "last_3_months"
LAST_6_MONTHS = "last_6_months"
LAST_YEAR = "last_year"
ALL_TIME = "all_time"
class ResearchProvider(str, Enum):
"""Research provider options."""
GOOGLE = "google" # Gemini native grounding
EXA = "exa" # Exa neural search
class ResearchConfig(BaseModel):
"""Configuration for research execution."""
mode: ResearchMode = ResearchMode.BASIC
provider: ResearchProvider = ResearchProvider.GOOGLE
date_range: Optional[DateRange] = None
source_types: List[SourceType] = []
max_sources: int = 10
include_statistics: bool = True
include_expert_quotes: bool = True
include_competitors: bool = True
include_trends: bool = True
class BlogResearchRequest(BaseModel):
keywords: List[str]
topic: Optional[str] = None
@@ -58,6 +104,8 @@ class BlogResearchRequest(BaseModel):
tone: Optional[str] = None
word_count_target: Optional[int] = 1500
persona: Optional[PersonaInfo] = None
research_mode: Optional[ResearchMode] = ResearchMode.BASIC
config: Optional[ResearchConfig] = None
class BlogResearchResponse(BaseModel):

View File

@@ -34,6 +34,7 @@ class APIProvider(enum.Enum):
METAPHOR = "metaphor"
FIRECRAWL = "firecrawl"
STABILITY = "stability"
EXA = "exa"
class BillingCycle(enum.Enum):
MONTHLY = "monthly"
@@ -66,6 +67,7 @@ class SubscriptionPlan(Base):
metaphor_calls_limit = Column(Integer, default=0)
firecrawl_calls_limit = Column(Integer, default=0)
stability_calls_limit = Column(Integer, default=0) # Image generation
exa_calls_limit = Column(Integer, default=0) # Exa neural search
# Token Limits (for LLM providers)
gemini_tokens_limit = Column(Integer, default=0)
@@ -182,6 +184,7 @@ class UsageSummary(Base):
metaphor_calls = Column(Integer, default=0)
firecrawl_calls = Column(Integer, default=0)
stability_calls = Column(Integer, default=0)
exa_calls = Column(Integer, default=0)
# Token Usage
gemini_tokens = Column(Integer, default=0)
@@ -199,6 +202,7 @@ class UsageSummary(Base):
metaphor_cost = Column(Float, default=0.0)
firecrawl_cost = Column(Float, default=0.0)
stability_cost = Column(Float, default=0.0)
exa_cost = Column(Float, default=0.0)
# Totals
total_calls = Column(Integer, default=0)

View File

@@ -38,6 +38,7 @@ aiohttp>=3.9.0
# Data processing
pandas>=2.0.0
numpy>=1.24.0
markdown>=3.5.0
# SEO Analysis dependencies
advertools>=0.14.0

View File

@@ -3,7 +3,7 @@ Script to update Basic plan subscription limits for testing rate limits and rene
Updates:
- LLM Calls (all providers): 10 calls (was 500-1000)
- LLM Tokens (all providers): 2000 tokens (was 200k-1M)
- LLM Tokens (all providers): 5000 tokens (increased from 2000 to support research workflow)
- Images: 5 images (was 50)
This script updates the SubscriptionPlan table, which automatically applies to all users
@@ -69,11 +69,11 @@ def update_basic_plan_limits():
basic_plan.anthropic_calls_limit = 200
basic_plan.mistral_calls_limit = 500
# Update all LLM provider token limits to 2000
basic_plan.gemini_tokens_limit = 2000
basic_plan.openai_tokens_limit = 2000
basic_plan.anthropic_tokens_limit = 2000
basic_plan.mistral_tokens_limit = 2000
# Update all LLM provider token limits to 20000 (increased from 5000 for better stability)
basic_plan.gemini_tokens_limit = 20000
basic_plan.openai_tokens_limit = 20000
basic_plan.anthropic_tokens_limit = 20000
basic_plan.mistral_tokens_limit = 20000
# Update image generation limit to 5
basic_plan.stability_calls_limit = 5
@@ -83,7 +83,7 @@ def update_basic_plan_limits():
logger.info("\n📝 New Basic plan limits:")
logger.info(f" LLM Calls (all providers): 10")
logger.info(f" LLM Tokens (all providers): 2000")
logger.info(f" LLM Tokens (all providers): 20000 (increased from 5000)")
logger.info(f" Images: 5")
# Count and get affected users
@@ -118,7 +118,7 @@ def update_basic_plan_limits():
# New limits - use unified AI text generation limit if available
new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit
new_token_limit = basic_plan.gemini_tokens_limit # 2000
new_token_limit = basic_plan.gemini_tokens_limit # 5000 (increased from 2000)
new_image_limit = basic_plan.stability_calls_limit # 5
for sub in user_subscriptions:
@@ -253,7 +253,7 @@ if __name__ == "__main__":
logger.info("="*60)
logger.info("This will update Basic plan limits for testing rate limits:")
logger.info(" - LLM Calls: 10 (all providers)")
logger.info(" - LLM Tokens: 2000 (all providers)")
logger.info(" - LLM Tokens: 20000 (all providers, increased from 5000)")
logger.info(" - Images: 5")
logger.info("="*60)

View File

@@ -8,6 +8,7 @@ import time
import json
from typing import Dict, Any, List
from loguru import logger
from fastapi import HTTPException
from models.blog_models import (
MediumBlogGenerateRequest,
@@ -25,8 +26,20 @@ class MediumBlogGenerator:
def __init__(self):
self.cache = persistent_content_cache
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call."""
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call.
Args:
req: Medium blog generation request
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
import time
start = time.time()
@@ -156,7 +169,7 @@ class MediumBlogGenerator:
- Use language that resonates with {audience}
- Maintain consistent voice that reflects this persona's expertise
"""
prompt = (
f"Write blog content for the following sections. Each section should be {req.globalTargetWords or 1000} words total, distributed across all sections.\n\n"
f"Blog Title: {req.title}\n\n"
@@ -176,11 +189,20 @@ class MediumBlogGenerator:
f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}"
)
ai_resp = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=system,
)
try:
ai_resp = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=system,
user_id=user_id
)
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) to preserve error details
raise
except Exception as llm_error:
# Wrap other errors
logger.error(f"AI generation failed: {llm_error}")
raise Exception(f"AI generation failed: {str(llm_error)}")
# Check for errors in AI response
if not ai_resp or ai_resp.get("error"):

View File

@@ -105,13 +105,20 @@ class BlogWriterService:
return await self.research_service.research_with_progress(request, task_id, user_id)
# Outline Methods
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
"""Generate AI-powered outline from research data."""
return await self.outline_service.generate_outline(request)
async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
"""Generate AI-powered outline from research data.
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
return await self.outline_service.generate_outline(request, user_id)
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
"""Generate outline with real-time progress updates."""
return await self.outline_service.generate_outline_with_progress(request, task_id)
return await self.outline_service.generate_outline_with_progress(request, task_id, user_id)
async def refine_outline(self, request: BlogOutlineRefineRequest) -> BlogOutlineResponse:
"""Refine outline with HITL operations."""
@@ -334,9 +341,17 @@ class BlogWriterService:
# TODO: Move to content module
return BlogPublishResponse(success=True, platform=request.platform, url="https://example.com/post")
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call."""
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id)
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call.
Args:
req: Medium blog generation request
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
"""
if not user_id:
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id, user_id)
async def analyze_flow_basic(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze flow metrics for entire blog using single AI call (cost-effective)."""

View File

@@ -42,10 +42,20 @@ class OutlineGenerator:
self.response_processor = ResponseProcessor()
self.parallel_processor = ParallelProcessor(self.source_mapper, self.grounding_engine)
async def generate(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
async def generate(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
"""
Generate AI-powered outline using research results
Generate AI-powered outline using research results.
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
# Extract research insights
research = request.research
primary_keywords = research.keyword_analysis.get('primary', [])
@@ -68,15 +78,15 @@ class OutlineGenerator:
# Define schema with proper property ordering (critical for Gemini API)
outline_schema = self.prompt_builder.get_outline_schema()
# Generate outline using structured JSON response with retry logic
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema)
# Generate outline using structured JSON response with retry logic (user_id required)
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id)
# Convert to BlogOutlineSection objects
outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
# Run parallel processing for speed optimization
# Run parallel processing for speed optimization (user_id required)
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing_async(
outline_sections, research
outline_sections, research, user_id
)
# Enhance sections with grounding insights
@@ -85,9 +95,9 @@ class OutlineGenerator:
mapped_sections, research.grounding_metadata, grounding_insights
)
# Optimize outline for better flow, SEO, and engagement
# Optimize outline for better flow, SEO, and engagement (user_id required)
logger.info("Optimizing outline for better flow and engagement...")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
# Rebalance word counts for optimal distribution
target_words = request.word_count or 1500
@@ -118,10 +128,21 @@ class OutlineGenerator:
research_coverage=research_coverage
)
async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
"""
Outline generation method with progress updates for real-time feedback.
Args:
request: Outline generation request with research data
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
from api.blog_writer.task_manager import task_manager
# Extract research insights
@@ -150,17 +171,17 @@ class OutlineGenerator:
await task_manager.update_progress(task_id, "🔄 Making AI request to generate structured outline...")
# Generate outline using structured JSON response with retry logic
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, task_id)
# Generate outline using structured JSON response with retry logic (user_id required for subscription checks)
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id, task_id)
await task_manager.update_progress(task_id, "📝 Processing outline structure and validating sections...")
# Convert to BlogOutlineSection objects
outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
# Run parallel processing for speed optimization
# Run parallel processing for speed optimization (user_id required for subscription checks)
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing(
outline_sections, research, task_id
outline_sections, research, user_id, task_id
)
# Enhance sections with grounding insights (depends on both previous tasks)
@@ -169,9 +190,9 @@ class OutlineGenerator:
mapped_sections, research.grounding_metadata, grounding_insights
)
# Optimize outline for better flow, SEO, and engagement
# Optimize outline for better flow, SEO, and engagement (user_id required for subscription checks)
await task_manager.update_progress(task_id, "🎯 Optimizing outline for better flow and engagement...")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
# Rebalance word counts for optimal distribution
await task_manager.update_progress(task_id, "⚖️ Rebalancing word count distribution...")

View File

@@ -13,8 +13,23 @@ from models.blog_models import BlogOutlineSection
class OutlineOptimizer:
"""Optimizes outlines for better flow, SEO, and engagement."""
async def optimize(self, outline: List[BlogOutlineSection], focus: str = "general optimization") -> List[BlogOutlineSection]:
"""Optimize entire outline for better flow, SEO, and engagement."""
async def optimize(self, outline: List[BlogOutlineSection], focus: str, user_id: str) -> List[BlogOutlineSection]:
"""Optimize entire outline for better flow, SEO, and engagement.
Args:
outline: List of outline sections to optimize
focus: Optimization focus (e.g., "general optimization")
user_id: User ID (required for subscription checks and usage tracking)
Returns:
List of optimized outline sections
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline optimization (subscription checks and usage tracking)")
outline_text = "\n".join([f"{i+1}. {s.heading}" for i, s in enumerate(outline)])
optimization_prompt = f"""Optimize this blog outline for better flow, engagement, and SEO:
@@ -67,7 +82,8 @@ Return JSON format:
optimized_data = llm_text_gen(
prompt=optimization_prompt,
json_struct=optimization_schema,
system_prompt=None
system_prompt=None,
user_id=user_id
)
# Handle the new schema format with "outline" wrapper

View File

@@ -29,11 +29,21 @@ class OutlineService:
self.outline_optimizer = OutlineOptimizer()
self.section_enhancer = SectionEnhancer()
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
"""
Stage 2: Content Planning with AI-generated outline using research results
Uses Gemini with research data to create comprehensive, SEO-optimized outline
Stage 2: Content Planning with AI-generated outline using research results.
Uses Gemini with research data to create comprehensive, SEO-optimized outline.
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
# Extract cache parameters - use original user keywords for consistent caching
keywords = request.research.original_keywords or request.research.keyword_analysis.get('primary', [])
industry = getattr(request.persona, 'industry', 'general') if request.persona else 'general'
@@ -56,9 +66,9 @@ class OutlineService:
logger.info(f"Using cached outline for keywords: {keywords}")
return BlogOutlineResponse(**cached_result)
# Generate new outline if not cached
# Generate new outline if not cached (user_id required)
logger.info(f"Generating new outline for keywords: {keywords}")
result = await self.outline_generator.generate(request)
result = await self.outline_generator.generate(request, user_id)
# Cache the result
persistent_outline_cache.cache_outline(
@@ -73,7 +83,7 @@ class OutlineService:
return result
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
"""
Outline generation method with progress updates for real-time feedback.
"""
@@ -104,7 +114,7 @@ class OutlineService:
# Generate new outline if not cached
logger.info(f"Generating new outline for keywords: {keywords} (with progress updates)")
result = await self.outline_generator.generate_with_progress(request, task_id)
result = await self.outline_generator.generate_with_progress(request, task_id, user_id)
# Cache the result
persistent_outline_cache.cache_outline(

View File

@@ -17,18 +17,25 @@ class ParallelProcessor:
self.source_mapper = source_mapper
self.grounding_engine = grounding_engine
async def run_parallel_processing(self, outline_sections, research, task_id: str = None) -> Tuple[Any, Any]:
async def run_parallel_processing(self, outline_sections, research, user_id: str, task_id: str = None) -> Tuple[Any, Any]:
"""
Run source mapping and grounding insights extraction in parallel.
Args:
outline_sections: List of outline sections to process
research: Research data object
user_id: User ID (required for subscription checks and usage tracking)
task_id: Optional task ID for progress updates
Returns:
Tuple of (mapped_sections, grounding_insights)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
if task_id:
from api.blog_writer.task_manager import task_manager
await task_manager.update_progress(task_id, "⚡ Running parallel processing for maximum speed...")
@@ -37,7 +44,7 @@ class ParallelProcessor:
# Run these tasks in parallel to save time
source_mapping_task = asyncio.create_task(
self._run_source_mapping(outline_sections, research, task_id)
self._run_source_mapping(outline_sections, research, task_id, user_id)
)
grounding_insights_task = asyncio.create_task(
@@ -52,22 +59,29 @@ class ParallelProcessor:
return mapped_sections, grounding_insights
async def run_parallel_processing_async(self, outline_sections, research) -> Tuple[Any, Any]:
async def run_parallel_processing_async(self, outline_sections, research, user_id: str) -> Tuple[Any, Any]:
"""
Run parallel processing without progress updates (for non-progress methods).
Args:
outline_sections: List of outline sections to process
research: Research data object
user_id: User ID (required for subscription checks and usage tracking)
Returns:
Tuple of (mapped_sections, grounding_insights)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
logger.info("Running parallel processing for maximum speed...")
# Run these tasks in parallel to save time
source_mapping_task = asyncio.create_task(
self._run_source_mapping_async(outline_sections, research)
self._run_source_mapping_async(outline_sections, research, user_id)
)
grounding_insights_task = asyncio.create_task(
@@ -82,12 +96,12 @@ class ParallelProcessor:
return mapped_sections, grounding_insights
async def _run_source_mapping(self, outline_sections, research, task_id):
async def _run_source_mapping(self, outline_sections, research, task_id, user_id: str):
"""Run source mapping in parallel."""
if task_id:
from api.blog_writer.task_manager import task_manager
await task_manager.update_progress(task_id, "🔗 Applying intelligent source-to-section mapping...")
return self.source_mapper.map_sources_to_sections(outline_sections, research)
return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
async def _run_grounding_insights_extraction(self, research, task_id):
"""Run grounding insights extraction in parallel."""
@@ -96,10 +110,10 @@ class ParallelProcessor:
await task_manager.update_progress(task_id, "🧠 Extracting grounding metadata insights...")
return self.grounding_engine.extract_contextual_insights(research.grounding_metadata)
async def _run_source_mapping_async(self, outline_sections, research):
async def _run_source_mapping_async(self, outline_sections, research, user_id: str):
"""Run source mapping in parallel (async version without progress updates)."""
logger.info("Applying intelligent source-to-section mapping...")
return self.source_mapper.map_sources_to_sections(outline_sections, research)
return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
async def _run_grounding_insights_extraction_async(self, research):
"""Run grounding insights extraction in parallel (async version without progress updates)."""

View File

@@ -18,8 +18,21 @@ class ResponseProcessor:
"""Initialize the response processor."""
pass
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], task_id: str = None) -> Dict[str, Any]:
"""Generate outline with retry logic for API failures."""
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], user_id: str, task_id: str = None) -> Dict[str, Any]:
"""Generate outline with retry logic for API failures.
Args:
prompt: The prompt for outline generation
schema: JSON schema for structured response
user_id: User ID (required for subscription checks and usage tracking)
task_id: Optional task ID for progress updates
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
from services.llm_providers.main_text_generation import llm_text_gen
from api.blog_writer.task_manager import task_manager
@@ -34,7 +47,8 @@ class ResponseProcessor:
outline_data = llm_text_gen(
prompt=prompt,
json_struct=schema,
system_prompt=None
system_prompt=None,
user_id=user_id
)
# Log response for debugging

View File

@@ -12,8 +12,23 @@ from models.blog_models import BlogOutlineSection
class SectionEnhancer:
"""Enhances individual outline sections using AI."""
async def enhance(self, section: BlogOutlineSection, focus: str = "general improvement") -> BlogOutlineSection:
"""Enhance a section using AI with research context."""
async def enhance(self, section: BlogOutlineSection, focus: str, user_id: str) -> BlogOutlineSection:
"""Enhance a section using AI with research context.
Args:
section: Outline section to enhance
focus: Enhancement focus (e.g., "general improvement")
user_id: User ID (required for subscription checks and usage tracking)
Returns:
Enhanced outline section
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for section enhancement (subscription checks and usage tracking)")
enhancement_prompt = f"""
Enhance the following blog section to make it more engaging, comprehensive, and valuable:
@@ -61,7 +76,8 @@ class SectionEnhancer:
enhanced_data = llm_text_gen(
prompt=enhancement_prompt,
json_struct=enhancement_schema,
system_prompt=None
system_prompt=None,
user_id=user_id
)
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:

View File

@@ -52,7 +52,8 @@ class SourceToSectionMapper:
def map_sources_to_sections(
self,
sections: List[BlogOutlineSection],
research_data: BlogResearchResponse
research_data: BlogResearchResponse,
user_id: str
) -> List[BlogOutlineSection]:
"""
Map research sources to outline sections using intelligent algorithms.
@@ -60,10 +61,17 @@ class SourceToSectionMapper:
Args:
sections: List of outline sections to map sources to
research_data: Research data containing sources and metadata
user_id: User ID (required for subscription checks and usage tracking)
Returns:
List of outline sections with intelligently mapped sources
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for source mapping (subscription checks and usage tracking)")
if not sections or not research_data.sources:
logger.warning("No sections or sources to map")
return sections
@@ -73,8 +81,8 @@ class SourceToSectionMapper:
# Step 1: Algorithmic mapping
mapping_results = self._algorithmic_source_mapping(sections, research_data)
# Step 2: AI validation and improvement (single prompt)
validated_mapping = self._ai_validate_mapping(mapping_results, research_data)
# Step 2: AI validation and improvement (single prompt, user_id required for subscription checks)
validated_mapping = self._ai_validate_mapping(mapping_results, research_data, user_id)
# Step 3: Apply validated mapping to sections
mapped_sections = self._apply_mapping_to_sections(sections, validated_mapping)
@@ -261,7 +269,8 @@ class SourceToSectionMapper:
def _ai_validate_mapping(
self,
mapping_results: Dict[str, List[Tuple[ResearchSource, float]]],
research_data: BlogResearchResponse
research_data: BlogResearchResponse,
user_id: str
) -> Dict[str, List[Tuple[ResearchSource, float]]]:
"""
Use AI to validate and improve the algorithmic mapping results.
@@ -269,18 +278,25 @@ class SourceToSectionMapper:
Args:
mapping_results: Algorithmic mapping results
research_data: Research data for context
user_id: User ID (required for subscription checks and usage tracking)
Returns:
AI-validated and improved mapping results
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for AI validation (subscription checks and usage tracking)")
try:
logger.info("Starting AI validation of source-to-section mapping...")
# Build AI validation prompt
validation_prompt = self._build_validation_prompt(mapping_results, research_data)
# Get AI validation response
validation_response = self._get_ai_validation_response(validation_prompt)
# Get AI validation response (user_id required for subscription checks)
validation_response = self._get_ai_validation_response(validation_prompt, user_id)
# Parse and apply AI validation results
validated_mapping = self._parse_validation_response(validation_response, mapping_results, research_data)
@@ -548,23 +564,31 @@ Analyze the mapping and provide your recommendations.
return prompt
def _get_ai_validation_response(self, prompt: str) -> str:
def _get_ai_validation_response(self, prompt: str, user_id: str) -> str:
"""
Get AI validation response using LLM provider.
Args:
prompt: Validation prompt
user_id: User ID (required for subscription checks and usage tracking)
Returns:
AI validation response
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for AI validation response (subscription checks and usage tracking)")
try:
from services.llm_providers.main_text_generation import llm_text_gen
response = llm_text_gen(
prompt=prompt,
json_struct=None,
system_prompt=None
system_prompt=None,
user_id=user_id
)
return response

View File

@@ -13,11 +13,17 @@ from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter
from .base_provider import ResearchProvider as BaseResearchProvider
from .google_provider import GoogleResearchProvider
from .exa_provider import ExaResearchProvider
__all__ = [
'ResearchService',
'KeywordAnalyzer',
'CompetitorAnalyzer',
'ContentAngleGenerator',
'ResearchDataFilter'
'ResearchDataFilter',
'BaseResearchProvider',
'GoogleResearchProvider',
'ExaResearchProvider',
]

View File

@@ -0,0 +1,37 @@
"""
Base Research Provider Interface
Abstract base class for research provider implementations.
Ensures consistency across different research providers (Google, Exa, etc.)
"""
from abc import ABC, abstractmethod
from typing import Dict, Any
class ResearchProvider(ABC):
"""Abstract base class for research providers."""
@abstractmethod
async def search(
self,
prompt: str,
topic: str,
industry: str,
target_audience: str,
config: Any, # ResearchConfig
user_id: str
) -> Dict[str, Any]:
"""Execute research and return raw results."""
pass
@abstractmethod
def get_provider_enum(self):
"""Return APIProvider enum for subscription tracking."""
pass
@abstractmethod
def estimate_tokens(self) -> int:
"""Estimate token usage for pre-flight validation."""
pass

View File

@@ -0,0 +1,188 @@
"""
Exa Research Provider
Neural search implementation using Exa API for high-quality, citation-rich research.
"""
from exa_py import Exa
import os
from loguru import logger
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider
class ExaResearchProvider(BaseProvider):
"""Exa neural search provider."""
def __init__(self):
self.api_key = os.getenv("EXA_API_KEY")
if not self.api_key:
raise RuntimeError("EXA_API_KEY not configured")
self.exa = Exa(self.api_key)
logger.info("✅ Exa Research Provider initialized")
async def search(self, prompt, topic, industry, target_audience, config, user_id):
"""Execute Exa neural search and return standardized results."""
# Build Exa query
query = f"{topic} {industry} {target_audience}"
# Map source types to Exa categories
category = self._map_source_type_to_category(config.source_types)
logger.info(f"[Exa Research] Executing search: {query}")
# Execute Exa search
results = self.exa.search_and_contents(
query,
type="auto",
category=category,
num_results=min(config.max_sources, 25),
contents={
'text': {'max_characters': 1000},
'summary': {'query': f"Key insights about {topic}"},
'highlights': {
'num_sentences': 2,
'highlights_per_url': 3
}
}
)
# Transform to standardized format
sources = self._transform_sources(results.results)
content = self._aggregate_content(results.results)
search_type = getattr(results, 'resolvedSearchType', 'neural') if hasattr(results, 'resolvedSearchType') else 'neural'
# Get cost if available
cost = 0.005 # Default Exa cost for 1-25 results
if hasattr(results, 'costDollars'):
if hasattr(results.costDollars, 'total'):
cost = results.costDollars.total
logger.info(f"[Exa Research] Search completed: {len(sources)} sources, type: {search_type}")
return {
'sources': sources,
'content': content,
'search_type': search_type,
'provider': 'exa',
'search_queries': [query],
'cost': {'total': cost}
}
def get_provider_enum(self):
"""Return EXA provider enum for subscription tracking."""
return APIProvider.EXA
def estimate_tokens(self) -> int:
"""Estimate token usage for Exa (not token-based)."""
return 0 # Exa is per-search, not token-based
def _map_source_type_to_category(self, source_types):
"""Map SourceType enum to Exa category parameter."""
if not source_types:
return None
category_map = {
'research paper': 'research paper',
'news': 'news',
'web': 'personal site',
'industry': 'company',
'expert': 'linkedin profile'
}
for st in source_types:
if st.value in category_map:
return category_map[st.value]
return None
def _transform_sources(self, results):
"""Transform Exa results to ResearchSource format."""
sources = []
for idx, result in enumerate(results):
source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')
sources.append({
'title': result.title if hasattr(result, 'title') else '',
'url': result.url if hasattr(result, 'url') else '',
'excerpt': self._get_excerpt(result),
'credibility_score': 0.85, # Exa results are high quality
'published_at': result.publishedDate if hasattr(result, 'publishedDate') else None,
'index': idx,
'source_type': source_type,
'content': result.text if hasattr(result, 'text') else '',
'highlights': result.highlights if hasattr(result, 'highlights') else [],
'summary': result.summary if hasattr(result, 'summary') else ''
})
return sources
def _get_excerpt(self, result):
"""Extract excerpt from Exa result."""
if hasattr(result, 'text') and result.text:
return result.text[:500]
elif hasattr(result, 'summary') and result.summary:
return result.summary
return ''
def _determine_source_type(self, url):
"""Determine source type from URL."""
if not url:
return 'web'
url_lower = url.lower()
if 'arxiv.org' in url_lower or 'research' in url_lower:
return 'academic'
elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com']):
return 'news'
elif 'linkedin.com' in url_lower:
return 'expert'
else:
return 'web'
def _aggregate_content(self, results):
"""Aggregate content from Exa results for LLM analysis."""
content_parts = []
for idx, result in enumerate(results):
if hasattr(result, 'summary') and result.summary:
content_parts.append(f"Source {idx + 1}: {result.summary}")
elif hasattr(result, 'text') and result.text:
content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")
return "\n\n".join(content_parts)
def track_exa_usage(self, user_id: str, cost: float):
"""Track Exa API usage after successful call."""
from services.database import get_db
from services.subscription import PricingService
from sqlalchemy import text
db = next(get_db())
try:
pricing_service = PricingService(db)
current_period = pricing_service.get_current_billing_period(user_id)
# Update exa_calls and exa_cost via SQL UPDATE
update_query = text("""
UPDATE usage_summaries
SET exa_calls = COALESCE(exa_calls, 0) + 1,
exa_cost = COALESCE(exa_cost, 0) + :cost,
total_calls = total_calls + 1,
total_cost = total_cost + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_query, {
'cost': cost,
'user_id': user_id,
'period': current_period
})
db.commit()
logger.info(f"[Exa] Tracked usage: user={user_id}, cost=${cost}")
except Exception as e:
logger.error(f"[Exa] Failed to track usage: {e}")
db.rollback()
finally:
db.close()

View File

@@ -0,0 +1,40 @@
"""
Google Research Provider
Wrapper for Gemini native Google Search grounding to match base provider interface.
"""
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider
from loguru import logger
class GoogleResearchProvider(BaseProvider):
"""Google research provider using Gemini native grounding."""
def __init__(self):
self.gemini = GeminiGroundedProvider()
async def search(self, prompt, topic, industry, target_audience, config, user_id):
"""Call Gemini grounding with pre-flight validation."""
logger.info(f"[Google Research] Executing search for topic: {topic}")
result = await self.gemini.generate_grounded_content(
prompt=prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True
)
return result
def get_provider_enum(self):
"""Return GEMINI provider enum for subscription tracking."""
return APIProvider.GEMINI
def estimate_tokens(self) -> int:
"""Estimate token usage for Google grounding."""
return 1200 # Conservative estimate

View File

@@ -16,6 +16,9 @@ from models.blog_models import (
GroundingChunk,
GroundingSupport,
Citation,
ResearchConfig,
ResearchMode,
ResearchProvider,
)
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
from fastapi import HTTPException
@@ -24,6 +27,7 @@ from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter
from .research_strategies import get_strategy_for_mode
class ResearchService:
@@ -44,7 +48,6 @@ class ResearchService:
Includes intelligent caching for exact keyword matches.
"""
try:
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from services.cache.research_cache import research_cache
topic = request.topic or ", ".join(request.keywords)
@@ -79,62 +82,104 @@ class ResearchService:
# Cache miss - proceed with API call
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
gemini = GeminiGroundedProvider()
blog_writer_logger.log_operation_start("research_api_call", api_name="research", operation="research")
# Single comprehensive research prompt - Gemini handles Google Search automatically
research_prompt = f"""
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including:
1. Current trends and insights (2024-2025)
2. Key statistics and data points with sources
3. Industry expert opinions and quotes
4. Recent developments and news
5. Market analysis and forecasts
6. Best practices and case studies
7. Keyword analysis: primary, secondary, and long-tail opportunities
8. Competitor analysis: top players and content gaps
9. Content angle suggestions: 5 compelling angles for blog posts
Focus on factual, up-to-date information from credible sources.
Include specific data points, percentages, and recent developments.
Structure your response with clear sections for each analysis area.
"""
# Determine research mode and get appropriate strategy
research_mode = request.research_mode or ResearchMode.BASIC
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
strategy = get_strategy_for_mode(research_mode)
# Single Gemini call with native Google Search grounding - no fallbacks
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
import time
api_start_time = time.time()
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
api_duration_ms = (time.time() - api_start_time) * 1000
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
# Log API call performance
blog_writer_logger.log_api_call(
"gemini_grounded",
"generate_grounded_content",
api_duration_ms,
token_usage=gemini_result.get("token_usage", {}),
content_length=len(gemini_result.get("content", ""))
)
# Build research prompt based on strategy
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
# Extract sources from grounding metadata
sources = self._extract_sources_from_grounding(gemini_result)
# Route to appropriate provider
if config.provider == ResearchProvider.EXA:
# Exa research workflow
from .exa_provider import ExaResearchProvider
from services.subscription.preflight_validator import validate_exa_research_operations
from services.database import get_db
from services.subscription import PricingService
import os
import time
# Pre-flight validation
db_val = next(get_db())
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
finally:
db_val.close()
# Execute Exa search
api_start_time = time.time()
try:
exa_provider = ExaResearchProvider()
raw_result = await exa_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
api_duration_ms = (time.time() - api_start_time) * 1000
# Track usage
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
exa_provider.track_exa_usage(user_id, cost)
# Log API call performance
blog_writer_logger.log_api_call(
"exa_search",
"search_and_contents",
api_duration_ms,
token_usage={},
content_length=len(raw_result.get('content', ''))
)
# Extract content for downstream analysis
content = raw_result.get('content', '')
sources = raw_result.get('sources', [])
search_widget = "" # Exa doesn't provide search widgets
search_queries = raw_result.get('search_queries', [])
grounding_metadata = None # Exa doesn't provide grounding metadata
except RuntimeError as e:
if "EXA_API_KEY not configured" in str(e):
logger.warning("Exa not configured, falling back to Google")
config.provider = ResearchProvider.GOOGLE
# Continue to Google flow below
raw_result = None
else:
raise
if config.provider != ResearchProvider.EXA:
# Google research (existing flow) or fallback from Exa
from .google_provider import GoogleResearchProvider
import time
api_start_time = time.time()
google_provider = GoogleResearchProvider()
gemini_result = await google_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
api_duration_ms = (time.time() - api_start_time) * 1000
# Log API call performance
blog_writer_logger.log_api_call(
"gemini_grounded",
"generate_grounded_content",
api_duration_ms,
token_usage=gemini_result.get("token_usage", {}),
content_length=len(gemini_result.get("content", ""))
)
# Extract sources and content
sources = self._extract_sources_from_grounding(gemini_result)
content = gemini_result.get("content", "")
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Extract grounding metadata for detailed UI display
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Extract search widget and queries for UI display
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
# Continue with common analysis (same for both providers)
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
@@ -261,7 +306,6 @@ class ResearchService:
Research method with progress updates for real-time feedback.
"""
try:
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from services.cache.research_cache import research_cache
from services.cache.persistent_research_cache import persistent_research_cache
from api.blog_writer.task_manager import task_manager
@@ -293,66 +337,100 @@ class ResearchService:
logger.info(f"Returning cached research result for keywords: {request.keywords}")
return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider)
# User ID validation
if not user_id:
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
gemini = GeminiGroundedProvider()
# Single comprehensive research prompt - Gemini handles Google Search automatically
research_prompt = f"""
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including:
1. Current trends and insights (2024-2025)
2. Key statistics and data points with sources
3. Industry expert opinions and quotes
4. Recent developments and news
5. Market analysis and forecasts
6. Best practices and case studies
7. Keyword analysis: primary, secondary, and long-tail opportunities
8. Competitor analysis: top players and content gaps
9. Content angle suggestions: 5 compelling angles for blog posts
Focus on factual, up-to-date information from credible sources.
Include specific data points, percentages, and recent developments.
Structure your response with clear sections for each analysis area.
"""
# Determine research mode and get appropriate strategy
research_mode = request.research_mode or ResearchMode.BASIC
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
strategy = get_strategy_for_mode(research_mode)
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
# Single Gemini call with native Google Search grounding - no fallbacks
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
try:
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
except HTTPException as http_error:
# Re-raise HTTPException so it can be properly handled by task manager
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise # Re-raise HTTPException to preserve status code and error details
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
# Extract sources from grounding metadata
sources = self._extract_sources_from_grounding(gemini_result)
# Build research prompt based on strategy
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
# Extract grounding metadata for detailed UI display
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Extract search widget and queries for UI display
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
# Route to appropriate provider
if config.provider == ResearchProvider.EXA:
# Exa research workflow
from .exa_provider import ExaResearchProvider
from services.subscription.preflight_validator import validate_exa_research_operations
from services.database import get_db
from services.subscription import PricingService
import os
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
# Pre-flight validation
db_val = next(get_db())
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
except HTTPException as http_error:
logger.error(f"Subscription limit exceeded for Exa research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise
finally:
db_val.close()
# Execute Exa search
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
try:
exa_provider = ExaResearchProvider()
raw_result = await exa_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
# Track usage
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
exa_provider.track_exa_usage(user_id, cost)
# Extract content for downstream analysis
content = raw_result.get('content', '')
sources = raw_result.get('sources', [])
search_widget = "" # Exa doesn't provide search widgets
search_queries = raw_result.get('search_queries', [])
grounding_metadata = None # Exa doesn't provide grounding metadata
except RuntimeError as e:
if "EXA_API_KEY not configured" in str(e):
logger.warning("Exa not configured, falling back to Google")
await task_manager.update_progress(task_id, "⚠️ Exa not configured, falling back to Google Search")
config.provider = ResearchProvider.GOOGLE
# Continue to Google flow below
else:
raise
if config.provider != ResearchProvider.EXA:
# Google research (existing flow)
from .google_provider import GoogleResearchProvider
await task_manager.update_progress(task_id, "🌐 Connecting to Google Search grounding...")
google_provider = GoogleResearchProvider()
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
try:
gemini_result = await google_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
except HTTPException as http_error:
logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
# Extract sources and content
sources = self._extract_sources_from_grounding(gemini_result)
content = gemini_result.get("content", "")
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Continue with common analysis (same for both providers)
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)

View File

@@ -0,0 +1,234 @@
"""
Research Strategy Pattern Implementation
Different strategies for executing research based on depth and focus.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any
from loguru import logger
from models.blog_models import BlogResearchRequest, ResearchMode, ResearchConfig
from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
class ResearchStrategy(ABC):
"""Base class for research strategies."""
def __init__(self):
self.keyword_analyzer = KeywordAnalyzer()
self.competitor_analyzer = CompetitorAnalyzer()
self.content_angle_generator = ContentAngleGenerator()
@abstractmethod
def build_research_prompt(
self,
topic: str,
industry: str,
target_audience: str,
config: ResearchConfig
) -> str:
"""Build the research prompt for the strategy."""
pass
@abstractmethod
def get_mode(self) -> ResearchMode:
"""Return the research mode this strategy handles."""
pass
class BasicResearchStrategy(ResearchStrategy):
"""Basic research strategy - keyword focused, minimal analysis."""
def get_mode(self) -> ResearchMode:
return ResearchMode.BASIC
def build_research_prompt(
self,
topic: str,
industry: str,
target_audience: str,
config: ResearchConfig
) -> str:
"""Build basic research prompt focused on keywords and quick insights."""
prompt = f"""You are a professional blog content strategist researching for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"
Provide analysis in this EXACT format:
## CURRENT TRENDS (2024-2025)
- [Trend 1 with specific data and source URL]
- [Trend 2 with specific data and source URL]
- [Trend 3 with specific data and source URL]
## KEY STATISTICS
- [Statistic 1: specific number/percentage with source URL]
- [Statistic 2: specific number/percentage with source URL]
- [Statistic 3: specific number/percentage with source URL]
- [Statistic 4: specific number/percentage with source URL]
- [Statistic 5: specific number/percentage with source URL]
## PRIMARY KEYWORDS
1. "{topic}" (main keyword)
2. [Variation 1]
3. [Variation 2]
## SECONDARY KEYWORDS
[5 related keywords for blog content]
## CONTENT ANGLES (Top 5)
1. [Angle 1: specific unique approach]
2. [Angle 2: specific unique approach]
3. [Angle 3: specific unique approach]
4. [Angle 4: specific unique approach]
5. [Angle 5: specific unique approach]
REQUIREMENTS:
- Cite EVERY claim with authoritative source URLs
- Use 2024-2025 data when available
- Include specific numbers, dates, examples
- Focus on actionable blog insights for {target_audience}"""
return prompt.strip()
class ComprehensiveResearchStrategy(ResearchStrategy):
"""Comprehensive research strategy - full analysis with all components."""
def get_mode(self) -> ResearchMode:
return ResearchMode.COMPREHENSIVE
def build_research_prompt(
self,
topic: str,
industry: str,
target_audience: str,
config: ResearchConfig
) -> str:
"""Build comprehensive research prompt with all analysis components."""
date_filter = f"\nDate Focus: {config.date_range.value.replace('_', ' ')}" if config.date_range else ""
source_filter = f"\nPriority Sources: {', '.join([s.value for s in config.source_types])}" if config.source_types else ""
prompt = f"""You are a senior blog content strategist conducting comprehensive research for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"{date_filter}{source_filter}
Provide COMPLETE analysis in this EXACT format:
## TRENDS AND INSIGHTS (2024-2025)
[5-7 trends with specific data, numbers, and source URLs]
## KEY STATISTICS
[7-10 statistics with exact numbers, percentages, dates, and source URLs]
## EXPERT OPINIONS
[4-5 expert quotes with full attribution and source URLs]
## RECENT DEVELOPMENTS
[5-7 recent news/developments with dates and source URLs]
## MARKET ANALYSIS
[3-5 market insights with data points and source URLs]
## BEST PRACTICES & CASE STUDIES
[3-5 examples with specific outcomes/metrics and source URLs]
## KEYWORD ANALYSIS
Primary Keywords: [3 main variations]
Secondary Keywords: [7-10 related keywords]
Long-Tail Opportunities: [5-7 specific search phrases]
## COMPETITOR ANALYSIS
Top Competitors: [5 competitors with brief descriptions]
Content Gaps: [5 topics competitors are missing]
Competitive Advantages: [5 unique angles we can own]
## CONTENT ANGLES (Exactly 5)
1. [Unique angle with reasoning and target benefit]
2. [Unique angle with reasoning and target benefit]
3. [Unique angle with reasoning and target benefit]
4. [Unique angle with reasoning and target benefit]
5. [Unique angle with reasoning and target benefit]
VERIFICATION REQUIREMENTS:
- Minimum 2 authoritative sources per major claim
- Prioritize: Industry publications > Research papers > News > Blogs
- 2024-2025 data strongly preferred
- All numbers must include context (timeframe, sample size, methodology)
- Every recommendation must be actionable for {target_audience}"""
return prompt.strip()
class TargetedResearchStrategy(ResearchStrategy):
"""Targeted research strategy - focused on specific aspects."""
def get_mode(self) -> ResearchMode:
return ResearchMode.TARGETED
def build_research_prompt(
self,
topic: str,
industry: str,
target_audience: str,
config: ResearchConfig
) -> str:
"""Build targeted research prompt based on config preferences."""
sections = []
if config.include_trends:
sections.append("""## CURRENT TRENDS
[3-5 trends with data and source URLs]""")
if config.include_statistics:
sections.append("""## KEY STATISTICS
[5-7 statistics with numbers and source URLs]""")
if config.include_expert_quotes:
sections.append("""## EXPERT OPINIONS
[3-4 expert quotes with attribution and source URLs]""")
if config.include_competitors:
sections.append("""## COMPETITOR ANALYSIS
Top Competitors: [3-5]
Content Gaps: [3-5]""")
# Always include keywords and angles
sections.append("""## KEYWORD ANALYSIS
Primary: [2-3 variations]
Secondary: [5-7 keywords]
Long-Tail: [3-5 phrases]""")
sections.append("""## CONTENT ANGLES (3-5)
[Unique blog angles with reasoning]""")
sections_str = "\n\n".join(sections)
prompt = f"""You are a blog content strategist conducting targeted research for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"
Provide focused analysis in this EXACT format:
{sections_str}
REQUIREMENTS:
- Cite all claims with authoritative source URLs
- Include specific numbers, dates, examples
- Focus on actionable insights for {target_audience}
- Use 2024-2025 data when available"""
return prompt.strip()
def get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy:
"""Factory function to get the appropriate strategy for a mode."""
strategy_map = {
ResearchMode.BASIC: BasicResearchStrategy,
ResearchMode.COMPREHENSIVE: ComprehensiveResearchStrategy,
ResearchMode.TARGETED: TargetedResearchStrategy,
}
strategy_class = strategy_map.get(mode, BasicResearchStrategy)
return strategy_class()

View File

@@ -2,4 +2,14 @@
Wix integration modular services package.
"""
from services.integrations.wix.seo import build_seo_data
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
from services.integrations.wix.blog_publisher import create_blog_post
__all__ = [
'build_seo_data',
'markdown_to_html',
'convert_via_wix_api',
'create_blog_post',
]

View File

@@ -20,6 +20,40 @@ class WixBlogService:
return h
def create_draft_post(self, access_token: str, payload: Dict[str, Any], extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
# Log the exact payload being sent for debugging
import json
logger.warning(f"📤 Sending to Wix Blog API:")
logger.warning(f" Endpoint: {self.base_url}/blog/v3/draft-posts")
logger.warning(f" Payload top-level keys: {list(payload.keys())}")
if 'draftPost' in payload:
dp = payload['draftPost']
logger.warning(f" draftPost keys: {list(dp.keys())}")
if 'richContent' in dp:
rc = dp['richContent']
logger.warning(f" richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
if isinstance(rc, dict) and 'nodes' in rc:
nodes = rc['nodes']
logger.warning(f" richContent.nodes count: {len(nodes) if isinstance(nodes, list) else 'N/A'}")
# Inspect first LIST_ITEM node if any
for i, node in enumerate(nodes[:10]):
if isinstance(node, dict) and node.get('type') == 'LIST_ITEM':
logger.warning(f" Found LIST_ITEM at index {i}:")
logger.warning(f" Keys: {list(node.keys())}")
logger.warning(f" Has listItemData: {'listItemData' in node}")
if 'listItemData' in node:
logger.warning(f" listItemData type: {type(node['listItemData'])}, value: {node['listItemData']}")
if 'nodes' in node:
nested = node['nodes']
logger.warning(f" Nested nodes count: {len(nested) if isinstance(nested, list) else 'N/A'}")
for j, n_node in enumerate(nested[:3]):
if isinstance(n_node, dict):
logger.warning(f" Nested node {j}: type={n_node.get('type')}, keys={list(n_node.keys())}")
if n_node.get('type') == 'PARAGRAPH' and 'paragraphData' in n_node:
logger.warning(f" paragraphData type: {type(n_node['paragraphData'])}, value: {n_node['paragraphData']}")
break # Only inspect first LIST_ITEM
logger.warning(f" Full Payload JSON (first 8000 chars):\n{json.dumps(payload, indent=2, ensure_ascii=False)[:8000]}...")
response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=self.headers(access_token, extra_headers), json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,716 @@
"""
Blog Post Publisher for Wix
Handles blog post creation, validation, and publishing to Wix.
"""
import json
import uuid
import requests
import jwt
from typing import Dict, Any, Optional, List
from loguru import logger
from services.integrations.wix.blog import WixBlogService
from services.integrations.wix.content import convert_content_to_ricos
from services.integrations.wix.ricos_converter import convert_via_wix_api
from services.integrations.wix.seo import build_seo_data
def validate_ricos_content(ricos_content: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate and normalize Ricos document structure.
Args:
ricos_content: Ricos document dict
Returns:
Validated and normalized Ricos document
"""
# Validate Ricos document structure before using
if not ricos_content or not isinstance(ricos_content, dict):
logger.error("Invalid Ricos content - not a dict")
raise ValueError("Failed to convert content to valid Ricos format")
if 'type' not in ricos_content:
ricos_content['type'] = 'DOCUMENT'
logger.debug("Added missing richContent type 'DOCUMENT'")
if ricos_content.get('type') != 'DOCUMENT':
logger.warning(f"richContent type expected 'DOCUMENT', got {ricos_content.get('type')}, correcting")
ricos_content['type'] = 'DOCUMENT'
if 'id' not in ricos_content or not isinstance(ricos_content.get('id'), str):
ricos_content['id'] = str(uuid.uuid4())
logger.debug("Added missing richContent id")
if 'nodes' not in ricos_content:
logger.warning("Ricos document missing 'nodes' field, adding empty nodes array")
ricos_content['nodes'] = []
logger.debug(f"Ricos document structure: nodes={len(ricos_content.get('nodes', []))}")
# Validate richContent is a proper object with nodes array
# Per Wix API: richContent must be a RichContent object with nodes array
if not isinstance(ricos_content, dict):
raise ValueError(f"richContent must be a dict object, got {type(ricos_content)}")
# Ensure nodes array exists and is valid
if 'nodes' not in ricos_content:
logger.warning("richContent missing 'nodes', adding empty array")
ricos_content['nodes'] = []
if not isinstance(ricos_content['nodes'], list):
raise ValueError(f"richContent.nodes must be a list, got {type(ricos_content['nodes'])}")
# Recursive function to validate and fix nodes at any depth
def validate_node_recursive(node: Dict[str, Any], path: str = "root") -> None:
"""
Recursively validate a node and all its nested children, ensuring:
1. All required data fields exist for each node type
2. All 'nodes' arrays are proper lists
3. No None values in critical fields
"""
if not isinstance(node, dict):
logger.error(f"{path}: Node is not a dict: {type(node)}")
return
# Ensure type and id exist
if 'type' not in node:
logger.error(f"{path}: Missing 'type' field - REQUIRED")
node['type'] = 'PARAGRAPH' # Default fallback
if 'id' not in node:
node['id'] = str(uuid.uuid4())
logger.debug(f"{path}: Added missing 'id'")
node_type = node.get('type')
# CRITICAL: Per Wix API schema, data fields like paragraphData, bulletedListData, etc.
# are OPTIONAL and should be OMITTED entirely when empty, not included as {}
# Only validate fields that have required properties
# Special handling: Remove listItemData if it exists (not in Wix API schema)
if node_type == 'LIST_ITEM' and 'listItemData' in node:
logger.debug(f"{path}: Removing incorrect listItemData field from LIST_ITEM")
del node['listItemData']
# Only validate HEADING nodes - they require headingData with level property
if node_type == 'HEADING':
if 'headingData' not in node or not isinstance(node.get('headingData'), dict):
logger.warning(f"{path} (HEADING): Missing headingData, adding default level 1")
node['headingData'] = {'level': 1}
elif 'level' not in node['headingData']:
logger.warning(f"{path} (HEADING): Missing level in headingData, adding default")
node['headingData']['level'] = 1
# TEXT nodes must have textData
if node_type == 'TEXT':
if 'textData' not in node or not isinstance(node.get('textData'), dict):
logger.error(f"{path} (TEXT): Missing/invalid textData - node will be problematic")
node['textData'] = {'text': '', 'decorations': []}
# LINK and IMAGE nodes must have their data fields
if node_type == 'LINK' and ('linkData' not in node or not isinstance(node.get('linkData'), dict)):
logger.error(f"{path} (LINK): Missing/invalid linkData - node will be problematic")
if node_type == 'IMAGE' and ('imageData' not in node or not isinstance(node.get('imageData'), dict)):
logger.error(f"{path} (IMAGE): Missing/invalid imageData - node will be problematic")
# Remove None values from any data fields that exist (Wix API rejects None)
for data_field in ['headingData', 'paragraphData', 'blockquoteData', 'bulletedListData',
'orderedListData', 'textData', 'linkData', 'imageData']:
if data_field in node and isinstance(node[data_field], dict):
data_value = node[data_field]
keys_to_remove = [k for k, v in data_value.items() if v is None]
if keys_to_remove:
logger.debug(f"{path} ({node_type}): Removing None values from {data_field}: {keys_to_remove}")
for key in keys_to_remove:
del data_value[key]
# Ensure 'nodes' field exists for container nodes
container_types = ['HEADING', 'PARAGRAPH', 'BLOCKQUOTE', 'LIST_ITEM', 'LINK',
'BULLETED_LIST', 'ORDERED_LIST']
if node_type in container_types:
if 'nodes' not in node:
logger.warning(f"{path} ({node_type}): Missing 'nodes' field, adding empty array")
node['nodes'] = []
elif not isinstance(node['nodes'], list):
logger.error(f"{path} ({node_type}): Invalid 'nodes' field (not a list), fixing")
node['nodes'] = []
# Recursively validate all nested nodes
for nested_idx, nested_node in enumerate(node['nodes']):
nested_path = f"{path}.nodes[{nested_idx}]"
validate_node_recursive(nested_node, nested_path)
# Validate all top-level nodes recursively
for idx, node in enumerate(ricos_content['nodes']):
validate_node_recursive(node, f"nodes[{idx}]")
# Ensure documentStyle exists and is a dict (required by Wix API when provided)
if 'metadata' not in ricos_content or not isinstance(ricos_content.get('metadata'), dict):
ricos_content['metadata'] = {'version': 1, 'id': str(uuid.uuid4())}
logger.debug("Added default metadata to richContent")
else:
ricos_content['metadata'].setdefault('version', 1)
ricos_content['metadata'].setdefault('id', str(uuid.uuid4()))
if 'documentStyle' not in ricos_content or not isinstance(ricos_content.get('documentStyle'), dict):
ricos_content['documentStyle'] = {
'paragraph': {
'decorations': [],
'nodeStyle': {},
'lineHeight': '1.5'
}
}
logger.debug("Added default documentStyle to richContent")
logger.debug(f"✅ Validated richContent: {len(ricos_content['nodes'])} nodes, has_metadata={bool(ricos_content.get('metadata'))}, has_documentStyle={bool(ricos_content.get('documentStyle'))}")
return ricos_content
def validate_payload_no_none(obj, path=""):
"""Recursively validate that no None values exist in the payload"""
if obj is None:
raise ValueError(f"Found None value at path: {path}")
if isinstance(obj, dict):
for key, value in obj.items():
validate_payload_no_none(value, f"{path}.{key}" if path else key)
elif isinstance(obj, list):
for idx, item in enumerate(obj):
validate_payload_no_none(item, f"{path}[{idx}]" if path else f"[{idx}]")
def create_blog_post(
blog_service: WixBlogService,
access_token: str,
title: str,
content: str,
member_id: str,
cover_image_url: str = None,
category_ids: List[str] = None,
tag_ids: List[str] = None,
publish: bool = True,
seo_metadata: Dict[str, Any] = None,
import_image_func = None,
lookup_categories_func = None,
lookup_tags_func = None,
base_url: str = 'https://www.wixapis.com'
) -> Dict[str, Any]:
"""
Create and optionally publish a blog post on Wix
Args:
blog_service: WixBlogService instance
access_token: Valid access token
title: Blog post title
content: Blog post content (markdown)
member_id: Required for third-party apps - the member ID of the post author
cover_image_url: Optional cover image URL
category_ids: Optional list of category IDs or names
tag_ids: Optional list of tag IDs or names
publish: Whether to publish immediately or save as draft
seo_metadata: Optional SEO metadata dict
import_image_func: Function to import images (optional)
lookup_categories_func: Function to lookup/create categories (optional)
lookup_tags_func: Function to lookup/create tags (optional)
base_url: Wix API base URL
Returns:
Created blog post information
"""
if not member_id:
raise ValueError("memberId is required for third-party apps creating blog posts")
headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
# Build valid Ricos rich content
# Ensure content is not empty
if not content or not content.strip():
content = "This is a post from ALwrity."
logger.warning("⚠️ Content was empty, using default text")
# Try Wix API first (more reliable), fall back to custom parser
ricos_content = None
try:
logger.warning("🔄 Attempting to convert markdown to Ricos via Wix API...")
ricos_content = convert_via_wix_api(content, access_token, base_url)
logger.warning(f"✅ Wix API conversion successful. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
except Exception as e:
logger.warning(f"⚠️ Wix Ricos API conversion failed: {e}. Falling back to custom parser...")
# Fall back to custom parser
ricos_content = convert_content_to_ricos(content, None)
logger.warning(f"✅ Custom parser conversion complete. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
# Validate Ricos content
ricos_content = validate_ricos_content(ricos_content)
# Minimal payload per Wix docs: title, memberId, and richContent
# CRITICAL: Only include fields that have valid values (no None, no empty strings for required fields)
blog_data = {
'draftPost': {
'title': str(title).strip() if title else "Untitled",
'memberId': str(member_id).strip(), # Required for third-party apps (validated above)
'richContent': ricos_content, # Must be a valid Ricos document object
},
'publish': bool(publish),
'fieldsets': ['URL'] # Simplified fieldsets
}
# Add excerpt only if content exists and is not empty (avoid None or empty strings)
excerpt = (content or '').strip()[:200] if content else None
if excerpt and len(excerpt) > 0:
blog_data['draftPost']['excerpt'] = str(excerpt)
# Add cover image if provided
if cover_image_url and import_image_func:
try:
media_id = import_image_func(access_token, cover_image_url, f'Cover: {title}')
# Ensure media_id is a string and not None
if media_id and isinstance(media_id, str):
blog_data['draftPost']['media'] = {
'wixMedia': {
'image': {'id': str(media_id).strip()}
},
'displayed': True,
'custom': True
}
else:
logger.warning(f"Invalid media_id type or value: {type(media_id)}, skipping media")
except Exception as e:
logger.warning(f"Failed to import cover image: {e}")
# Handle categories - can be either IDs (list of strings) or names (for lookup)
category_ids_to_use = None
if category_ids:
# Check if these are IDs (UUIDs) or names
if isinstance(category_ids, list) and len(category_ids) > 0:
# Assume IDs if first item looks like UUID (has hyphens and is long)
first_item = str(category_ids[0])
if '-' in first_item and len(first_item) > 30:
category_ids_to_use = category_ids
elif lookup_categories_func:
# These are names, need to lookup/create
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
category_ids_to_use = lookup_categories_func(
access_token, category_ids, extra_headers if extra_headers else None
)
# Handle tags - can be either IDs (list of strings) or names (for lookup)
tag_ids_to_use = None
if tag_ids:
# Check if these are IDs (UUIDs) or names
if isinstance(tag_ids, list) and len(tag_ids) > 0:
# Assume IDs if first item looks like UUID (has hyphens and is long)
first_item = str(tag_ids[0])
if '-' in first_item and len(first_item) > 30:
tag_ids_to_use = tag_ids
elif lookup_tags_func:
# These are names, need to lookup/create
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
tag_ids_to_use = lookup_tags_func(
access_token, tag_ids, extra_headers if extra_headers else None
)
# Add categories if we have IDs (must be non-empty list of strings)
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
if category_ids_to_use and isinstance(category_ids_to_use, list) and len(category_ids_to_use) > 0:
# Filter out None, empty strings, and ensure all are valid UUID strings
valid_category_ids = [str(cid).strip() for cid in category_ids_to_use if cid and str(cid).strip()]
if valid_category_ids:
blog_data['draftPost']['categoryIds'] = valid_category_ids
logger.debug(f"Added {len(valid_category_ids)} category IDs")
else:
logger.warning("All category IDs were invalid, not including categoryIds in payload")
# Add tags if we have IDs (must be non-empty list of strings)
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
if tag_ids_to_use and isinstance(tag_ids_to_use, list) and len(tag_ids_to_use) > 0:
# Filter out None, empty strings, and ensure all are valid UUID strings
valid_tag_ids = [str(tid).strip() for tid in tag_ids_to_use if tid and str(tid).strip()]
if valid_tag_ids:
blog_data['draftPost']['tagIds'] = valid_tag_ids
logger.debug(f"Added {len(valid_tag_ids)} tag IDs")
else:
logger.warning("All tag IDs were invalid, not including tagIds in payload")
# Build SEO data from metadata if provided
# TESTING: Skip SEO data temporarily to confirm richContent fix
test_skip_seo = True
if test_skip_seo:
logger.warning("🧪 TESTING: Skipping SEO data to isolate richContent vs seoData issue")
seo_data = None
elif seo_metadata:
logger.warning(f"📊 Building SEO data from metadata. Keys: {list(seo_metadata.keys())}")
seo_data = build_seo_data(seo_metadata, title)
if seo_data:
# Log detailed SEO structure
logger.warning(f"📋 SEO data built: {len(seo_data.get('tags', []))} tags, {len(seo_data.get('settings', {}).get('keywords', []))} keywords")
# Log each SEO tag for debugging (key ones only to avoid too much output)
if seo_data.get('tags'):
for idx, tag in enumerate(seo_data['tags'][:3]): # First 3 tags only
tag_type = tag.get('type')
if tag_type == 'title':
logger.warning(f" SEO tag {idx+1}: type={tag_type}, children={str(tag.get('children', ''))[:50]}...")
else:
props = tag.get('props', {})
content_preview = str(props.get('content', props.get('href', props.get('name', ''))))[:50]
logger.warning(f" SEO tag {idx+1}: type={tag_type}, props={list(props.keys())}, content={content_preview}...")
if len(seo_data['tags']) > 3:
logger.warning(f" ... and {len(seo_data['tags']) - 3} more SEO tags")
blog_data['draftPost']['seoData'] = seo_data
logger.warning(f"✅ Added seoData to blog post with {len(seo_data.get('tags', []))} tags")
else:
logger.warning("⚠️ SEO data was empty after building - check build_seo_data function")
# Add SEO slug if provided (separate field from seoData)
if seo_metadata and seo_metadata.get('url_slug'):
blog_data['draftPost']['seoSlug'] = str(seo_metadata.get('url_slug')).strip()
logger.warning(f"✅ Added SEO slug: {blog_data['draftPost']['seoSlug']}")
if test_skip_seo:
logger.warning("⚠️ SEO data skipped for testing - will add back once richContent is confirmed working")
elif not seo_metadata:
logger.warning("⚠️ No SEO metadata provided to create_blog_post")
# Log the payload structure for debugging (without sensitive data)
logger.warning(f"📝 Creating blog post with title: '{title}'")
logger.warning(f"📋 Draft post fields: {list(blog_data['draftPost'].keys())}")
# Detailed SEO logging
if 'seoData' in blog_data['draftPost']:
seo_data_debug = blog_data['draftPost']['seoData']
logger.warning(f"📊 SEO data in payload: {len(seo_data_debug.get('tags', []))} tags, {len(seo_data_debug.get('settings', {}).get('keywords', []))} keywords")
# Log sample SEO tags (first 2 only to avoid too much output)
if seo_data_debug.get('tags'):
logger.warning("📋 SEO Tags sample:")
for i, tag in enumerate(seo_data_debug['tags'][:2]): # First 2 tags
logger.warning(f" Tag {i+1}: type={tag.get('type')}, custom={tag.get('custom')}, disabled={tag.get('disabled')}")
if len(seo_data_debug['tags']) > 2:
logger.warning(f" ... and {len(seo_data_debug['tags']) - 2} more tags")
if seo_data_debug.get('settings', {}).get('keywords'):
keywords_list = [k.get('term') for k in seo_data_debug['settings']['keywords'][:3]]
logger.warning(f"🔑 Keywords: {keywords_list}")
# Log FULL seoData structure for debugging
import json
try:
seo_json = json.dumps(seo_data_debug, indent=2, ensure_ascii=False)
logger.warning(f"📄 FULL seoData JSON:\n{seo_json[:2000]}...") # First 2000 chars
except Exception as e:
logger.error(f"Failed to serialize seoData: {e}")
else:
logger.warning("⚠️ No seoData in draft post payload!")
try:
# Add wix-site-id header if we can extract it from token
extra_headers = {}
try:
token_str = str(access_token)
if token_str and token_str.startswith('OauthNG.JWS.'):
jwt_part = token_str[12:]
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
data_payload = payload.get('data', {})
if isinstance(data_payload, str):
try:
data_payload = json.loads(data_payload)
except:
pass
instance_data = data_payload.get('instance', {})
meta_site_id = instance_data.get('metaSiteId')
if isinstance(meta_site_id, str) and meta_site_id:
extra_headers['wix-site-id'] = meta_site_id
headers['wix-site-id'] = meta_site_id
except Exception as e:
logger.debug(f"Could not extract site ID from token: {e}")
# Make the API call
logger.warning(f"🚀 Calling Wix API: POST /blog/v3/draft-posts")
logger.warning(f"📦 Payload: title='{blog_data['draftPost'].get('title')}', has_seoData={'seoData' in blog_data['draftPost']}, has_richContent={'richContent' in blog_data['draftPost']}")
# Validate payload structure before sending
draft_post = blog_data.get('draftPost', {})
if not isinstance(draft_post, dict):
raise ValueError("draftPost must be a dict object")
# Validate richContent structure
if 'richContent' in draft_post:
rc = draft_post['richContent']
if not isinstance(rc, dict):
raise ValueError(f"richContent must be a dict, got {type(rc)}")
if 'nodes' not in rc:
raise ValueError("richContent missing 'nodes' field")
if not isinstance(rc['nodes'], list):
raise ValueError(f"richContent.nodes must be a list, got {type(rc['nodes'])}")
logger.debug(f"✅ richContent validation passed: {len(rc.get('nodes', []))} nodes")
# Validate seoData structure if present
if 'seoData' in draft_post:
seo = draft_post['seoData']
if not isinstance(seo, dict):
raise ValueError(f"seoData must be a dict, got {type(seo)}")
if 'tags' in seo and not isinstance(seo['tags'], list):
raise ValueError(f"seoData.tags must be a list, got {type(seo.get('tags'))}")
if 'settings' in seo and not isinstance(seo['settings'], dict):
raise ValueError(f"seoData.settings must be a dict, got {type(seo.get('settings'))}")
logger.debug(f"✅ seoData validation passed: {len(seo.get('tags', []))} tags")
# Final validation: Ensure no None values in any nested objects
# Wix API rejects None values and expects proper types
try:
validate_payload_no_none(blog_data, "blog_data")
logger.debug("✅ Payload validation passed: No None values found")
except ValueError as e:
logger.error(f"❌ Payload validation failed: {e}")
raise
# Log full payload structure for debugging (sanitized)
logger.warning(f"📦 Full payload structure validation:")
logger.warning(f" - draftPost type: {type(draft_post)}")
logger.warning(f" - draftPost keys: {list(draft_post.keys())}")
logger.warning(f" - richContent type: {type(draft_post.get('richContent'))}")
if 'richContent' in draft_post:
rc = draft_post['richContent']
logger.warning(f" - richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
logger.warning(f" - richContent.nodes type: {type(rc.get('nodes'))}, count: {len(rc.get('nodes', []))}")
logger.warning(f" - richContent.metadata type: {type(rc.get('metadata'))}")
logger.warning(f" - richContent.documentStyle type: {type(rc.get('documentStyle'))}")
logger.warning(f" - seoData type: {type(draft_post.get('seoData'))}")
if 'seoData' in draft_post:
seo = draft_post['seoData']
logger.warning(f" - seoData keys: {list(seo.keys()) if isinstance(seo, dict) else 'N/A'}")
logger.warning(f" - seoData.tags type: {type(seo.get('tags'))}, count: {len(seo.get('tags', []))}")
logger.warning(f" - seoData.settings type: {type(seo.get('settings'))}")
if 'categoryIds' in draft_post:
logger.warning(f" - categoryIds type: {type(draft_post.get('categoryIds'))}, count: {len(draft_post.get('categoryIds', []))}")
if 'tagIds' in draft_post:
logger.warning(f" - tagIds type: {type(draft_post.get('tagIds'))}, count: {len(draft_post.get('tagIds', []))}")
# Log a sample of the payload JSON to see exact structure (first 2000 chars)
try:
import json
payload_json = json.dumps(blog_data, indent=2, ensure_ascii=False)
logger.warning(f"📄 Payload JSON preview (first 3000 chars):\n{payload_json[:3000]}...")
# Also log a deep structure inspection of richContent.nodes (first few nodes)
if 'richContent' in blog_data['draftPost']:
nodes = blog_data['draftPost']['richContent'].get('nodes', [])
if nodes:
logger.warning(f"🔍 Inspecting first 5 richContent.nodes:")
for i, node in enumerate(nodes[:5]):
logger.warning(f" Node {i+1}: type={node.get('type')}, keys={list(node.keys())}")
# Check for any None values in node
for key, value in node.items():
if value is None:
logger.error(f" ⚠️ Node {i+1}.{key} is None!")
elif isinstance(value, dict):
for k, v in value.items():
if v is None:
logger.error(f" ⚠️ Node {i+1}.{key}.{k} is None!")
# Deep check: if it's a list-type node, inspect list items
if node.get('type') in ['BULLETED_LIST', 'ORDERED_LIST']:
list_items = node.get('nodes', [])
if list_items:
logger.warning(f" List has {len(list_items)} items, checking first LIST_ITEM:")
first_item = list_items[0]
logger.warning(f" LIST_ITEM keys: {list(first_item.keys())}")
# Verify listItemData is NOT present (correct per Wix API spec)
if 'listItemData' in first_item:
logger.error(f" ❌ LIST_ITEM incorrectly has listItemData!")
else:
logger.debug(f" ✅ LIST_ITEM correctly has no listItemData")
# Check nested PARAGRAPH nodes
nested_nodes = first_item.get('nodes', [])
if nested_nodes:
logger.warning(f" LIST_ITEM has {len(nested_nodes)} nested nodes")
for n_idx, n_node in enumerate(nested_nodes[:2]):
logger.warning(f" Nested node {n_idx+1}: type={n_node.get('type')}, keys={list(n_node.keys())}")
except Exception as e:
logger.warning(f"Could not serialize payload for logging: {e}")
# Note: All node validation is done by validate_ricos_content() which runs earlier
# The recursive validation ensures all required data fields are present at any depth
# Final deep validation: Serialize and deserialize to catch any JSON-serialization issues
# This will raise an error if there are any objects that can't be serialized
try:
import json
test_json = json.dumps(blog_data, ensure_ascii=False)
test_parsed = json.loads(test_json)
logger.debug("✅ Payload JSON serialization test passed")
except (TypeError, ValueError) as e:
logger.error(f"❌ Payload JSON serialization failed: {e}")
raise ValueError(f"Payload contains non-serializable data: {e}")
# Final check: Ensure documentStyle and metadata are valid objects (not None, not empty strings)
rc = blog_data['draftPost']['richContent']
if 'documentStyle' in rc:
doc_style = rc['documentStyle']
if doc_style is None or doc_style == "":
logger.warning("⚠️ documentStyle is None or empty string, removing it")
del rc['documentStyle']
elif not isinstance(doc_style, dict):
logger.warning(f"⚠️ documentStyle is not a dict ({type(doc_style)}), removing it")
del rc['documentStyle']
if 'metadata' in rc:
metadata = rc['metadata']
if metadata is None or metadata == "":
logger.warning("⚠️ metadata is None or empty string, removing it")
del rc['metadata']
elif not isinstance(metadata, dict):
logger.warning(f"⚠️ metadata is not a dict ({type(metadata)}), removing it")
del rc['metadata']
# Check for any None values in critical nested structures
def check_none_in_dict(d, path=""):
"""Recursively check for None values that shouldn't be there"""
issues = []
if isinstance(d, dict):
for key, value in d.items():
current_path = f"{path}.{key}" if path else key
if value is None:
# Some fields can legitimately be None, but most shouldn't
if key not in ['decorations', 'nodeStyle', 'props']:
issues.append(current_path)
elif isinstance(value, dict):
issues.extend(check_none_in_dict(value, current_path))
elif isinstance(value, list):
for i, item in enumerate(value):
if item is None:
issues.append(f"{current_path}[{i}]")
elif isinstance(item, dict):
issues.extend(check_none_in_dict(item, f"{current_path}[{i}]"))
return issues
none_issues = check_none_in_dict(blog_data['draftPost']['richContent'])
if none_issues:
logger.error(f"❌ Found None values in richContent at: {none_issues[:10]}") # Limit to first 10
# Remove None values from critical paths
for issue_path in none_issues[:5]: # Fix first 5
parts = issue_path.split('.')
try:
obj = blog_data['draftPost']['richContent']
for part in parts[:-1]:
if '[' in part:
key, idx = part.split('[')
idx = int(idx.rstrip(']'))
obj = obj[key][idx]
else:
obj = obj[part]
final_key = parts[-1]
if '[' in final_key:
key, idx = final_key.split('[')
idx = int(idx.rstrip(']'))
obj[key][idx] = {}
else:
obj[final_key] = {}
logger.warning(f"Fixed None value at {issue_path}")
except:
pass
# Log the final payload structure one more time before sending
logger.warning(f"📤 Final payload ready - draftPost keys: {list(blog_data['draftPost'].keys())}")
logger.warning(f"📤 RichContent nodes count: {len(blog_data['draftPost']['richContent'].get('nodes', []))}")
logger.warning(f"📤 RichContent has metadata: {bool(blog_data['draftPost']['richContent'].get('metadata'))}")
logger.warning(f"📤 RichContent has documentStyle: {bool(blog_data['draftPost']['richContent'].get('documentStyle'))}")
# Try sending WITHOUT SEO data first to isolate the issue
test_without_seo = False # Disabled - listItemData issue fixed
if test_without_seo and 'seoData' in blog_data['draftPost']:
logger.warning("🧪 TESTING WITHOUT SEO DATA to isolate issue...")
# Clone the payload without SEO data
test_payload_no_seo = {
'draftPost': {
'title': blog_data['draftPost']['title'],
'memberId': blog_data['draftPost']['memberId'],
'richContent': blog_data['draftPost']['richContent'],
'excerpt': blog_data['draftPost'].get('excerpt', '')
},
'publish': False,
'fieldsets': ['URL']
}
try:
logger.warning("🧪 Attempting without SEO data...")
test_result = blog_service.create_draft_post(access_token, test_payload_no_seo, extra_headers or None)
logger.warning(f"✅ WITHOUT SEO DATA SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
logger.error("⚠️⚠️⚠️ ISSUE IS WITH SEO DATA STRUCTURE!")
# If this succeeds, don't send the full payload, just return this result
return test_result
except Exception as e:
logger.warning(f"❌ WITHOUT SEO DATA ALSO FAILED: {e}")
logger.warning("⚠️ Issue is NOT with SEO data, continuing with full payload...")
# Try sending with minimal structure first to isolate the issue
# Create a test payload with just required fields
minimal_test = False # Set to True to test with minimal payload
if minimal_test:
logger.warning("🧪 TESTING WITH MINIMAL PAYLOAD (title + memberId + simple richContent)")
test_payload = {
'draftPost': {
'title': blog_data['draftPost']['title'],
'memberId': blog_data['draftPost']['memberId'],
'richContent': {
'nodes': [
{
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': [
{
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': 'Test paragraph',
'decorations': []
}
}
],
'paragraphData': {}
}
],
'metadata': {'version': 1, 'id': str(uuid.uuid4())},
'documentStyle': {}
}
},
'publish': False,
'fieldsets': ['URL']
}
logger.warning("🧪 Attempting minimal payload first...")
try:
test_result = blog_service.create_draft_post(access_token, test_payload, extra_headers or None)
logger.warning(f"✅ MINIMAL PAYLOAD SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
logger.warning("⚠️ Issue is with complex content, not basic structure")
except Exception as e:
logger.error(f"❌ MINIMAL PAYLOAD ALSO FAILED: {e}")
logger.error("⚠️ Issue is with basic structure or permissions")
result = blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
# Log response
draft_post = result.get('draftPost', {})
logger.warning(f"✅ Blog post created successfully! Post ID: {draft_post.get('id', 'N/A')}")
# Check if SEO data was preserved in response
if 'seoData' in draft_post:
seo_response = draft_post['seoData']
logger.warning(f"✅ SEO data confirmed in response: {len(seo_response.get('tags', []))} tags, {len(seo_response.get('settings', {}).get('keywords', []))} keywords")
else:
logger.warning("⚠️ No seoData in response - it may have been filtered out by Wix API")
logger.warning(f"📋 Response fields: {list(draft_post.keys())}")
return result
except requests.RequestException as e:
logger.error(f"Failed to create blog post: {e}")
if hasattr(e, 'response') and e.response is not None:
logger.error(f"Response body: {e.response.text}")
raise

View File

@@ -1,58 +1,460 @@
import re
import uuid
from typing import Any, Dict, List
def parse_markdown_inline(text: str) -> List[Dict[str, Any]]:
"""
Parse inline markdown formatting (bold, italic, links) into Ricos text nodes.
Returns a list of text nodes with decorations.
Handles: **bold**, *italic*, [links](url), `code`, and combinations.
"""
if not text:
return [{
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {'text': '', 'decorations': []}
}]
nodes = []
# Process text character by character to handle nested/adjacent formatting
# This is more robust than regex for complex cases
i = 0
current_text = ''
current_decorations = []
while i < len(text):
# Check for bold **text** (must come before single * check)
if i < len(text) - 1 and text[i:i+2] == '**':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
# Find closing **
end_bold = text.find('**', i + 2)
if end_bold != -1:
bold_text = text[i + 2:end_bold]
# Recursively parse the bold text for nested formatting
bold_nodes = parse_markdown_inline(bold_text)
# Add BOLD decoration to all text nodes within
for node in bold_nodes:
if node['type'] == 'TEXT':
node_decorations = node['textData'].get('decorations', []).copy()
if 'BOLD' not in node_decorations:
node_decorations.append('BOLD')
node['textData']['decorations'] = node_decorations
nodes.append(node)
i = end_bold + 2
continue
# Check for link [text](url)
elif text[i] == '[':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find matching ]
link_end = text.find(']', i)
if link_end != -1 and link_end < len(text) - 1 and text[link_end + 1] == '(':
link_text = text[i + 1:link_end]
url_start = link_end + 2
url_end = text.find(')', url_start)
if url_end != -1:
url = text[url_start:url_end]
# Create link node
link_node_id = str(uuid.uuid4())
text_node_id = str(uuid.uuid4())
link_text_nodes = parse_markdown_inline(link_text)
# Wrap link text in LINK node
nodes.append({
'id': link_node_id,
'type': 'LINK',
'nodes': link_text_nodes if link_text_nodes else [{
'id': text_node_id,
'type': 'TEXT',
'textData': {'text': link_text, 'decorations': []}
}],
'linkData': {
'link': {
'url': url,
'target': '_blank'
}
}
})
i = url_end + 1
continue
# Check for code `text`
elif text[i] == '`':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find closing `
code_end = text.find('`', i + 1)
if code_end != -1:
code_text = text[i + 1:code_end]
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': code_text,
'decorations': ['CODE']
}
})
i = code_end + 1
continue
# Check for italic *text* (only if not part of **)
elif text[i] == '*' and (i == 0 or text[i-1] != '*') and (i == len(text) - 1 or text[i+1] != '*'):
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find closing * (but not **)
italic_end = text.find('*', i + 1)
if italic_end != -1:
# Make sure it's not part of **
if italic_end == len(text) - 1 or text[italic_end + 1] != '*':
italic_text = text[i + 1:italic_end]
italic_nodes = parse_markdown_inline(italic_text)
# Add ITALIC decoration
for node in italic_nodes:
if node['type'] == 'TEXT':
node_decorations = node['textData'].get('decorations', []).copy()
if 'ITALIC' not in node_decorations:
node_decorations.append('ITALIC')
node['textData']['decorations'] = node_decorations
nodes.append(node)
i = italic_end + 1
continue
# Regular character
current_text += text[i]
i += 1
# Add any remaining text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
# If no nodes created, return single plain text node
if not nodes:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': text,
'decorations': []
}
})
return nodes
def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str, Any]:
"""
Convert simple markdown-like text into minimal valid Ricos JSON.
Convert markdown content into valid Ricos JSON format.
Supports headings, paragraphs, lists, bold, italic, links, and images.
"""
paragraphs = content.split('\n\n')
if not content:
content = "This is a post from ALwrity."
nodes = []
import uuid
for paragraph in paragraphs:
text = paragraph.strip()
if not text:
lines = content.split('\n')
i = 0
while i < len(lines):
line = lines[i].strip()
if not line:
i += 1
continue
node_id = str(uuid.uuid4())
text_node_id = str(uuid.uuid4())
if text.startswith('#'):
level = len(text) - len(text.lstrip('#'))
heading_text = text.lstrip('# ').strip()
# Check for headings
if line.startswith('#'):
level = len(line) - len(line.lstrip('#'))
heading_text = line.lstrip('# ').strip()
text_nodes = parse_markdown_inline(heading_text)
nodes.append({
'id': node_id,
'type': 'HEADING',
'nodes': [{
'id': text_node_id,
'type': 'TEXT',
'textData': {
'text': heading_text,
'decorations': []
}
}],
'headingData': { 'level': min(level, 6) }
'nodes': text_nodes,
'headingData': {'level': min(level, 6)}
})
else:
nodes.append({
'id': node_id,
i += 1
# Check for blockquotes
elif line.startswith('>'):
quote_text = line.lstrip('> ').strip()
# Continue reading consecutive blockquote lines
quote_lines = [quote_text]
i += 1
while i < len(lines) and lines[i].strip().startswith('>'):
quote_lines.append(lines[i].strip().lstrip('> ').strip())
i += 1
quote_content = ' '.join(quote_lines)
text_nodes = parse_markdown_inline(quote_content)
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
paragraph_node = {
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': [{
'id': text_node_id,
'type': 'TEXT',
'textData': {
'text': text,
'decorations': []
}
}],
'nodes': text_nodes,
'paragraphData': {}
})
}
blockquote_node = {
'id': node_id,
'type': 'BLOCKQUOTE',
'nodes': [paragraph_node],
'blockquoteData': {}
}
nodes.append(blockquote_node)
# Check for unordered lists (handle both '- ' and '* ' markers)
elif (line.startswith('- ') or line.startswith('* ') or
(line.startswith('-') and len(line) > 1 and line[1] != '-') or
(line.startswith('*') and len(line) > 1 and line[1] != '*')):
list_items = []
list_marker = '- ' if line.startswith('-') else '* '
# Process list items
while i < len(lines):
current_line = lines[i].strip()
# Check if this is a list item
is_list_item = (current_line.startswith('- ') or current_line.startswith('* ') or
(current_line.startswith('-') and len(current_line) > 1 and current_line[1] != '-') or
(current_line.startswith('*') and len(current_line) > 1 and current_line[1] != '*'))
if not is_list_item:
break
# Extract item text (handle both '- ' and '-item' formats)
if current_line.startswith('- ') or current_line.startswith('* '):
item_text = current_line[2:].strip()
elif current_line.startswith('-'):
item_text = current_line[1:].strip()
elif current_line.startswith('*'):
item_text = current_line[1:].strip()
else:
item_text = current_line
list_items.append(item_text)
i += 1
# Check for nested items (indented with 2+ spaces)
while i < len(lines):
next_line = lines[i]
# Must be indented and be a list marker
if next_line.startswith(' ') and (next_line.strip().startswith('- ') or
next_line.strip().startswith('* ') or
(next_line.strip().startswith('-') and len(next_line.strip()) > 1) or
(next_line.strip().startswith('*') and len(next_line.strip()) > 1)):
nested_text = next_line.strip()
if nested_text.startswith('- ') or nested_text.startswith('* '):
nested_text = nested_text[2:].strip()
elif nested_text.startswith('-'):
nested_text = nested_text[1:].strip()
elif nested_text.startswith('*'):
nested_text = nested_text[1:].strip()
list_items.append(nested_text)
i += 1
else:
break
# Build list items with proper formatting
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM
# NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema
# Wix API: omit empty data objects, don't include them as {}
list_node_items = []
for item_text in list_items:
item_node_id = str(uuid.uuid4())
text_nodes = parse_markdown_inline(item_text)
paragraph_node = {
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
list_item_node = {
'id': item_node_id,
'type': 'LIST_ITEM',
'nodes': [paragraph_node]
}
list_node_items.append(list_item_node)
bulleted_list_node = {
'id': node_id,
'type': 'BULLETED_LIST',
'nodes': list_node_items,
'bulletedListData': {}
}
nodes.append(bulleted_list_node)
# Check for ordered lists
elif re.match(r'^\d+\.\s+', line):
list_items = []
while i < len(lines) and re.match(r'^\d+\.\s+', lines[i].strip()):
item_text = re.sub(r'^\d+\.\s+', '', lines[i].strip())
list_items.append(item_text)
i += 1
# Check for nested items
while i < len(lines) and lines[i].strip().startswith(' ') and re.match(r'^\s+\d+\.\s+', lines[i].strip()):
nested_text = re.sub(r'^\s+\d+\.\s+', '', lines[i].strip())
list_items.append(nested_text)
i += 1
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM
# NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema
# Wix API: omit empty data objects, don't include them as {}
list_node_items = []
for item_text in list_items:
item_node_id = str(uuid.uuid4())
text_nodes = parse_markdown_inline(item_text)
paragraph_node = {
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
list_item_node = {
'id': item_node_id,
'type': 'LIST_ITEM',
'nodes': [paragraph_node]
}
list_node_items.append(list_item_node)
ordered_list_node = {
'id': node_id,
'type': 'ORDERED_LIST',
'nodes': list_node_items,
'orderedListData': {}
}
nodes.append(ordered_list_node)
# Check for images
elif line.startswith('!['):
img_match = re.match(r'!\[([^\]]*)\]\(([^)]+)\)', line)
if img_match:
alt_text = img_match.group(1)
img_url = img_match.group(2)
nodes.append({
'id': node_id,
'type': 'IMAGE',
'nodes': [],
'imageData': {
'image': {
'src': {'url': img_url},
'altText': alt_text
},
'containerData': {
'alignment': 'CENTER',
'width': {'size': 'CONTENT'}
}
}
})
i += 1
# Regular paragraph
else:
# Collect consecutive non-empty lines as paragraph content
para_lines = [line]
i += 1
while i < len(lines):
next_line = lines[i].strip()
if not next_line:
break
# Stop if next line is a special markdown element
if (next_line.startswith('#') or
next_line.startswith('- ') or
next_line.startswith('* ') or
next_line.startswith('>') or
next_line.startswith('![') or
re.match(r'^\d+\.\s+', next_line)):
break
para_lines.append(next_line)
i += 1
para_text = ' '.join(para_lines)
text_nodes = parse_markdown_inline(para_text)
# Only add paragraph if there are text nodes
if text_nodes:
paragraph_node = {
'id': node_id,
'type': 'PARAGRAPH',
'nodes': text_nodes,
'paragraphData': {}
}
nodes.append(paragraph_node)
# Ensure at least one node exists
# Wix API: omit empty data objects, don't include them as {}
if not nodes:
fallback_paragraph = {
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': [{
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': content[:500] if content else "This is a post from ALwrity.",
'decorations': []
}
}],
'paragraphData': {}
}
nodes.append(fallback_paragraph)
return {
'type': 'DOCUMENT',
'id': str(uuid.uuid4()),
'nodes': nodes,
'metadata': { 'version': 1, 'id': str(uuid.uuid4()) },
'metadata': {'version': 1, 'id': str(uuid.uuid4())},
'documentStyle': {
'paragraph': { 'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5' }
'paragraph': {'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5'}
}
}

View File

@@ -7,6 +7,12 @@ class WixMediaService:
self.base_url = base_url
def import_image(self, access_token: str, image_url: str, display_name: str) -> Dict[str, Any]:
"""
Import external image to Wix Media Manager.
Official endpoint: https://www.wixapis.com/site-media/v1/files/import
Reference: https://dev.wix.com/docs/rest/assets/media/media-manager/files/import-file
"""
headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json',
@@ -16,7 +22,9 @@ class WixMediaService:
'mediaType': 'IMAGE',
'displayName': display_name,
}
response = requests.post(f"{self.base_url}/media/v1/files/import", headers=headers, json=payload)
# Correct endpoint per Wix API documentation
endpoint = f"{self.base_url}/site-media/v1/files/import"
response = requests.post(endpoint, headers=headers, json=payload)
response.raise_for_status()
return response.json()

View File

@@ -0,0 +1,277 @@
"""
Ricos Document Converter for Wix
Converts markdown content to Wix Ricos JSON format using either:
1. Wix's official Ricos Documents API (preferred)
2. Custom markdown parser (fallback)
"""
import json
import requests
import jwt
from typing import Dict, Any, Optional
from loguru import logger
def markdown_to_html(markdown_content: str) -> str:
"""
Convert markdown content to HTML.
Uses a simple markdown parser for basic conversion.
Args:
markdown_content: Markdown content to convert
Returns:
HTML string
"""
try:
# Try using markdown library if available
import markdown
html = markdown.markdown(markdown_content, extensions=['fenced_code', 'tables'])
return html
except ImportError:
# Fallback: Simple regex-based conversion for basic markdown
logger.warning("markdown library not available, using basic markdown-to-HTML conversion")
import re
if not markdown_content or not markdown_content.strip():
return "<p>This is a post from ALwrity.</p>"
lines = markdown_content.split('\n')
result = []
in_list = False
list_type = None # 'ul' or 'ol'
in_code_block = False
code_block_content = []
i = 0
while i < len(lines):
line = lines[i].strip()
# Handle code blocks first
if line.startswith('```'):
if not in_code_block:
in_code_block = True
code_block_content = []
i += 1
continue
else:
in_code_block = False
result.append(f'<pre><code>{"\n".join(code_block_content)}</code></pre>')
code_block_content = []
i += 1
continue
if in_code_block:
code_block_content.append(lines[i])
i += 1
continue
# Close any open lists
if in_list and not (line.startswith('- ') or line.startswith('* ') or re.match(r'^\d+\.\s+', line)):
result.append(f'</{list_type}>')
in_list = False
list_type = None
if not line:
i += 1
continue
# Headers
if line.startswith('###'):
result.append(f'<h3>{line[3:].strip()}</h3>')
elif line.startswith('##'):
result.append(f'<h2>{line[2:].strip()}</h2>')
elif line.startswith('#'):
result.append(f'<h1>{line[1:].strip()}</h1>')
# Lists
elif line.startswith('- ') or line.startswith('* '):
if not in_list or list_type != 'ul':
if in_list:
result.append(f'</{list_type}>')
result.append('<ul>')
in_list = True
list_type = 'ul'
# Process inline formatting in list item
item_text = line[2:].strip()
item_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', item_text)
item_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', item_text)
result.append(f'<li>{item_text}</li>')
elif re.match(r'^\d+\.\s+', line):
if not in_list or list_type != 'ol':
if in_list:
result.append(f'</{list_type}>')
result.append('<ol>')
in_list = True
list_type = 'ol'
# Process inline formatting in list item
match = re.match(r'^\d+\.\s+(.*)', line)
if match:
item_text = match.group(1)
item_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', item_text)
item_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', item_text)
result.append(f'<li>{item_text}</li>')
# Blockquotes
elif line.startswith('>'):
quote_text = line[1:].strip()
quote_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', quote_text)
quote_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', quote_text)
result.append(f'<blockquote><p>{quote_text}</p></blockquote>')
# Regular paragraphs
else:
para_text = line
# Process inline formatting
para_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', para_text)
para_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', para_text)
para_text = re.sub(r'\[([^\]]+)\]\(([^\)]+)\)', r'<a href="\2">\1</a>', para_text)
para_text = re.sub(r'`([^`]+)`', r'<code>\1</code>', para_text)
result.append(f'<p>{para_text}</p>')
i += 1
# Close any open lists
if in_list:
result.append(f'</{list_type}>')
# Ensure we have at least one paragraph
if not result:
result.append('<p>This is a post from ALwrity.</p>')
html = '\n'.join(result)
logger.debug(f"Converted {len(markdown_content)} chars markdown to {len(html)} chars HTML")
return html
def convert_via_wix_api(markdown_content: str, access_token: str, base_url: str = 'https://www.wixapis.com') -> Dict[str, Any]:
"""
Convert markdown to Ricos using Wix's official Ricos Documents API.
Uses HTML format for better reliability (per Wix documentation, HTML is fully supported).
Reference: https://dev.wix.com/docs/api-reference/assets/rich-content/ricos-documents/convert-to-ricos-document
Args:
markdown_content: Markdown content to convert (will be converted to HTML)
access_token: Wix access token
base_url: Wix API base URL (default: https://www.wixapis.com)
Returns:
Ricos JSON document
"""
# Validate content is not empty
markdown_stripped = markdown_content.strip() if markdown_content else ""
if not markdown_stripped:
logger.error("Markdown content is empty or whitespace-only")
raise ValueError("Content cannot be empty for Wix Ricos API conversion")
logger.debug(f"Converting markdown to HTML: input_length={len(markdown_stripped)} chars")
# Convert markdown to HTML for better reliability with Wix API
# HTML format is more structured and less prone to parsing errors
html_content = markdown_to_html(markdown_stripped)
# Validate HTML content is not empty - CRITICAL for Wix API
html_stripped = html_content.strip() if html_content else ""
if not html_stripped or len(html_stripped) == 0:
logger.error(f"HTML conversion produced empty content! Markdown length: {len(markdown_stripped)}")
logger.error(f"Markdown sample: {markdown_stripped[:500]}...")
logger.error(f"HTML result: '{html_content}' (type: {type(html_content)})")
# Fallback: use a minimal valid HTML if conversion failed
html_content = "<p>Content from ALwrity blog writer.</p>"
logger.warning("Using fallback HTML due to empty conversion result")
else:
html_content = html_stripped
logger.debug(f"✅ Converted markdown to HTML: {len(html_content)} chars, preview: {html_content[:200]}...")
headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
# Add wix-site-id if available from token
try:
token_str = str(access_token)
if token_str and token_str.startswith('OauthNG.JWS.'):
jwt_part = token_str[12:]
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
data_payload = payload.get('data', {})
if isinstance(data_payload, str):
try:
data_payload = json.loads(data_payload)
except:
pass
instance_data = data_payload.get('instance', {})
meta_site_id = instance_data.get('metaSiteId')
if isinstance(meta_site_id, str) and meta_site_id:
headers['wix-site-id'] = meta_site_id
except Exception as e:
logger.debug(f"Could not extract site ID from token: {e}")
# Call Wix Ricos Documents API: Convert to Ricos Document
# Official endpoint: https://www.wixapis.com/ricos/v1/ricos-document/convert/to-ricos
# Reference: https://dev.wix.com/docs/rest/assets/rich-content/ricos-documents/convert-to-ricos-document
endpoint = f"{base_url}/ricos/v1/ricos-document/convert/to-ricos"
# Ensure HTML content is not empty or just whitespace
html_stripped = html_content.strip() if html_content else ""
if not html_stripped or len(html_stripped) == 0:
logger.error(f"HTML content is empty after conversion. Markdown length: {len(markdown_content)}")
logger.error(f"Markdown preview (first 500 chars): {markdown_content[:500] if markdown_content else 'N/A'}")
raise ValueError(f"HTML content cannot be empty. Original markdown had {len(markdown_content)} characters.")
# Payload structure per Wix API: html/markdown/plainText field at root, optional plugins
payload = {
'html': html_stripped, # Direct field, not nested in options
'plugins': [] # Optional: empty array uses default plugins
}
logger.warning(f"📤 Sending to Wix Ricos API: html_length={len(payload['html'])}, plugins_count={len(payload['plugins'])}")
logger.debug(f"HTML preview (first 300 chars): {html_stripped[:300]}...")
try:
# Log the exact payload being sent (for debugging)
logger.warning(f"📤 Wix Ricos API Request:")
logger.warning(f" Endpoint: {endpoint}")
logger.warning(f" Payload keys: {list(payload.keys())}")
logger.warning(f" HTML length: {len(payload.get('html', ''))}")
logger.warning(f" Plugins: {payload.get('plugins', [])}")
logger.debug(f" Full payload (first 500 chars of HTML): {str(payload)[:500]}")
response = requests.post(
endpoint,
headers=headers,
json=payload,
timeout=30
)
response.raise_for_status()
result = response.json()
# Extract the ricos document from response
# Response structure: { "document": { "nodes": [...], "metadata": {...}, "documentStyle": {...} } }
ricos_document = result.get('document')
if not ricos_document:
# Fallback: try other possible response fields
ricos_document = result.get('ricosDocument') or result.get('ricos') or result
if not ricos_document:
logger.error(f"Unexpected response structure from Wix API: {list(result.keys())}")
logger.error(f"Response: {result}")
raise ValueError("Wix API did not return a valid Ricos document")
logger.warning(f"✅ Successfully converted HTML to Ricos via Wix API: {len(ricos_document.get('nodes', []))} nodes")
return ricos_document
except requests.RequestException as e:
logger.error(f"❌ Wix Ricos API conversion failed: {e}")
if hasattr(e, 'response') and e.response is not None:
logger.error(f" Response status: {e.response.status_code}")
logger.error(f" Response headers: {dict(e.response.headers)}")
try:
error_body = e.response.json()
logger.error(f" Response JSON: {error_body}")
except:
logger.error(f" Response text: {e.response.text}")
logger.error(f" Request payload was: {json.dumps(payload, indent=2)[:1000]}...") # First 1000 chars
raise

View File

@@ -0,0 +1,300 @@
"""
SEO Data Builder for Wix Blog Posts
Builds Wix-compatible seoData objects from ALwrity SEO metadata.
"""
from typing import Dict, Any, Optional
from loguru import logger
def build_seo_data(seo_metadata: Dict[str, Any], default_title: str = None) -> Optional[Dict[str, Any]]:
"""
Build Wix seoData object from our SEO metadata format.
Args:
seo_metadata: SEO metadata dict with fields like:
- seo_title: SEO optimized title
- meta_description: Meta description
- focus_keyword: Main keyword
- blog_tags: List of tag strings (for keywords)
- open_graph: Open Graph data dict
- canonical_url: Canonical URL
default_title: Fallback title if seo_title not provided
Returns:
Wix seoData object with settings.keywords and tags array, or None if empty
"""
seo_data = {
'settings': {
'keywords': []
},
'tags': []
}
# Build keywords array
keywords_list = []
# Add main keyword (focus_keyword) if provided
focus_keyword = seo_metadata.get('focus_keyword')
if focus_keyword:
keywords_list.append({
'term': str(focus_keyword),
'isMain': True
})
# Add additional keywords from blog_tags or other sources
blog_tags = seo_metadata.get('blog_tags', [])
if isinstance(blog_tags, list):
for tag in blog_tags:
tag_str = str(tag).strip()
if tag_str and tag_str != focus_keyword: # Don't duplicate main keyword
keywords_list.append({
'term': tag_str,
'isMain': False
})
# Add social hashtags as keywords if available
social_hashtags = seo_metadata.get('social_hashtags', [])
if isinstance(social_hashtags, list):
for hashtag in social_hashtags:
# Remove # if present
hashtag_str = str(hashtag).strip().lstrip('#')
if hashtag_str and hashtag_str != focus_keyword:
keywords_list.append({
'term': hashtag_str,
'isMain': False
})
seo_data['settings']['keywords'] = keywords_list
# Validate keywords list is not empty (or ensure at least one keyword exists)
if not seo_data['settings']['keywords']:
logger.warning("No keywords found in SEO metadata, adding empty keywords array")
# Build tags array (meta tags, Open Graph, etc.)
tags_list = []
# Meta description
meta_description = seo_metadata.get('meta_description')
if meta_description:
tags_list.append({
'type': 'meta',
'props': {
'name': 'description',
'content': str(meta_description)
},
'custom': True,
'disabled': False
})
# SEO title - 'title' type uses 'children' field, not 'props.content'
seo_title = seo_metadata.get('seo_title') or default_title
if seo_title:
tags_list.append({
'type': 'title',
'children': str(seo_title), # Title tags use 'children', not 'props.content'
'custom': True,
'disabled': False
})
# Open Graph tags
open_graph = seo_metadata.get('open_graph', {})
if isinstance(open_graph, dict):
# OG Title
og_title = open_graph.get('title') or seo_title
if og_title:
tags_list.append({
'type': 'meta',
'props': {
'property': 'og:title',
'content': str(og_title)
},
'custom': True,
'disabled': False
})
# OG Description
og_description = open_graph.get('description') or meta_description
if og_description:
tags_list.append({
'type': 'meta',
'props': {
'property': 'og:description',
'content': str(og_description)
},
'custom': True,
'disabled': False
})
# OG Image
og_image = open_graph.get('image')
if og_image:
# Skip base64 images for OG tags (Wix needs URLs)
if isinstance(og_image, str) and (og_image.startswith('http://') or og_image.startswith('https://')):
tags_list.append({
'type': 'meta',
'props': {
'property': 'og:image',
'content': og_image
},
'custom': True,
'disabled': False
})
# OG Type
tags_list.append({
'type': 'meta',
'props': {
'property': 'og:type',
'content': 'article'
},
'custom': True,
'disabled': False
})
# OG URL (canonical or provided URL)
og_url = open_graph.get('url') or seo_metadata.get('canonical_url')
if og_url:
tags_list.append({
'type': 'meta',
'props': {
'property': 'og:url',
'content': str(og_url)
},
'custom': True,
'disabled': False
})
# Twitter Card tags
twitter_card = seo_metadata.get('twitter_card', {})
if isinstance(twitter_card, dict):
twitter_title = twitter_card.get('title') or seo_title
if twitter_title:
tags_list.append({
'type': 'meta',
'props': {
'name': 'twitter:title',
'content': str(twitter_title)
},
'custom': True,
'disabled': False
})
twitter_description = twitter_card.get('description') or meta_description
if twitter_description:
tags_list.append({
'type': 'meta',
'props': {
'name': 'twitter:description',
'content': str(twitter_description)
},
'custom': True,
'disabled': False
})
twitter_image = twitter_card.get('image')
if twitter_image and isinstance(twitter_image, str) and (twitter_image.startswith('http://') or twitter_image.startswith('https://')):
tags_list.append({
'type': 'meta',
'props': {
'name': 'twitter:image',
'content': twitter_image
},
'custom': True,
'disabled': False
})
twitter_card_type = twitter_card.get('card', 'summary_large_image')
tags_list.append({
'type': 'meta',
'props': {
'name': 'twitter:card',
'content': str(twitter_card_type)
},
'custom': True,
'disabled': False
})
# Canonical URL as link tag
canonical_url = seo_metadata.get('canonical_url')
if canonical_url:
tags_list.append({
'type': 'link',
'props': {
'rel': 'canonical',
'href': str(canonical_url)
},
'custom': True,
'disabled': False
})
# Validate all tags have required fields before adding
validated_tags = []
for tag in tags_list:
if not isinstance(tag, dict):
logger.warning(f"Skipping invalid tag (not a dict): {type(tag)}")
continue
# Ensure required fields exist
if 'type' not in tag:
logger.warning("Skipping tag missing 'type' field")
continue
# Ensure 'custom' and 'disabled' fields exist
if 'custom' not in tag:
tag['custom'] = True
if 'disabled' not in tag:
tag['disabled'] = False
# Validate tag structure based on type
tag_type = tag.get('type')
if tag_type == 'title':
if 'children' not in tag or not tag['children']:
logger.warning("Skipping title tag with missing/invalid 'children' field")
continue
elif tag_type == 'meta':
if 'props' not in tag or not isinstance(tag['props'], dict):
logger.warning("Skipping meta tag with missing/invalid 'props' field")
continue
if 'name' not in tag['props'] and 'property' not in tag['props']:
logger.warning("Skipping meta tag with missing 'name' or 'property' in props")
continue
# Ensure 'content' exists and is not empty
if 'content' not in tag['props'] or not str(tag['props'].get('content', '')).strip():
logger.warning(f"Skipping meta tag with missing/empty 'content': {tag.get('props', {})}")
continue
elif tag_type == 'link':
if 'props' not in tag or not isinstance(tag['props'], dict):
logger.warning("Skipping link tag with missing/invalid 'props' field")
continue
# Ensure 'href' exists and is not empty for link tags
if 'href' not in tag['props'] or not str(tag['props'].get('href', '')).strip():
logger.warning(f"Skipping link tag with missing/empty 'href': {tag.get('props', {})}")
continue
validated_tags.append(tag)
seo_data['tags'] = validated_tags
# Final validation: ensure seoData structure is complete
if not isinstance(seo_data['settings'], dict):
logger.error("seoData.settings is not a dict, creating default")
seo_data['settings'] = {'keywords': []}
if not isinstance(seo_data['settings'].get('keywords'), list):
logger.error("seoData.settings.keywords is not a list, creating empty list")
seo_data['settings']['keywords'] = []
if not isinstance(seo_data['tags'], list):
logger.error("seoData.tags is not a list, creating empty list")
seo_data['tags'] = []
# CRITICAL: Per Wix API patterns, omit empty structures instead of including them as {}
# If keywords is empty, omit settings entirely
if not seo_data['settings'].get('keywords'):
logger.debug("No keywords found, omitting settings from seoData")
seo_data.pop('settings', None)
logger.debug(f"Built SEO data: {len(validated_tags)} tags, {len(keywords_list)} keywords")
# Only return seoData if we have at least keywords or tags
if keywords_list or validated_tags:
return seo_data
return None

View File

@@ -9,6 +9,7 @@ import json
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response
@@ -129,11 +130,16 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens)
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
# This prevents underestimating total token usage
# CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse
# This ensures we block requests that would exceed limits even if response is longer than expected
input_tokens = int(len(prompt.split()) * 1.3)
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
# Worst-case estimate: assume maximum possible output tokens (max_tokens if specified)
# This prevents abuse where actual response tokens exceed the estimate
if max_tokens:
estimated_output_tokens = max_tokens # Use maximum allowed output tokens
else:
# If max_tokens not specified, use conservative estimate (input * 1.5)
estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens
# Check limits using sync method from pricing service (strict enforcement)
@@ -146,7 +152,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
raise RuntimeError(f"Subscription limit exceeded: {message}")
# Raise HTTPException(429) with usage info so frontend can display subscription modal
error_detail = {
'error': message,
'message': message,
'provider': actual_provider_name or provider_enum.value,
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
# Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
@@ -159,6 +172,9 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
finally:
db.close()
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
raise
except RuntimeError:
# Re-raise subscription limit errors
raise
@@ -244,7 +260,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens # Already calculated above
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
@@ -259,45 +276,186 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
# Using raw SQL UPDATE ensures the change is persisted
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens)
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
# Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
summary.total_calls = old_total_calls + 1
summary.total_tokens = old_total_tokens + tokens_total
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
new_total_calls = old_total_calls + 1
new_total_tokens = old_total_tokens + tokens_total
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
@@ -310,12 +468,52 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
@@ -407,7 +605,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
tokens_input = estimated_tokens
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
@@ -418,6 +617,49 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
@@ -430,41 +672,68 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Get "before" state for unified log (from raw SQL query)
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers
# Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
new_tokens = current_tokens_before + tokens_total
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Update totals
# Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
# Get plan details for unified log (limits already retrieved above)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")

View File

@@ -0,0 +1,725 @@
"""
Limit Validation Module
Handles subscription limit checking and validation logic.
Extracted from pricing_service.py for better modularity.
"""
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
from datetime import datetime, timedelta
from sqlalchemy import text
from loguru import logger
from models.subscription_models import (
UserSubscription, UsageSummary, SubscriptionPlan,
APIProvider, SubscriptionTier
)
if TYPE_CHECKING:
from .pricing_service import PricingService
class LimitValidator:
"""Validates subscription limits for API usage."""
def __init__(self, pricing_service: 'PricingService'):
"""
Initialize limit validator with reference to PricingService.
Args:
pricing_service: Instance of PricingService to access helper methods and cache
"""
self.pricing_service = pricing_service
self.db = pricing_service.db
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.
Args:
user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
Returns:
(can_proceed, error_message, usage_info)
"""
try:
# Use actual_provider_name if provided, otherwise use enum value
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
display_provider_name = actual_provider_name or provider.value
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self.pricing_service._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self.pricing_service._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
try:
# Force expire subscription and plan objects to avoid stale cache
if subscription and subscription.plan_id:
plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
if plan_obj:
self.db.expire(plan_obj)
logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
limits = self.pricing_service.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
# Log token limits for debugging
token_limits = limits.get('limits', {})
logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self.pricing_service._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
try:
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Expire all objects to force fresh read from DB (critical after renewal)
self.db.expire_all()
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
from sqlalchemy import text
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
except Exception as e:
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
# STRICT: Fail closed - deny requests if subscription system fails
return False, f"Subscription check error: {str(e)}", {}
def check_comprehensive_limits(
self,
user_id: str,
operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
"""
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call.
Args:
user_id: User ID
operations: List of operations to validate, each with:
- 'provider': APIProvider enum
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded
"""
try:
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
# Explicitly expire any cached objects and refresh from DB to ensure fresh data
self.db.expire_all()
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# CRITICAL: Explicitly refresh from database to get latest values (clears SQLAlchemy cache)
if usage:
self.db.refresh(usage)
# Log what we actually read from database
if usage:
logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
logger.info(f" └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
else:
logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
return False, f"Failed to create usage summary: {str(create_error)}", {}
# Get user limits
limits_dict = self.pricing_service.get_user_limits(user_id)
if not limits_dict:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
limits_dict = self.pricing_service._plan_to_limits_dict(free_plan)
else:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
# Track cumulative usage across all operations
total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
total_llm_tokens = {}
total_images = usage.stability_calls or 0
# Log current usage summary
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
logger.info(f" └─ Image Calls: {total_images}")
# Validate each operation
for op_idx, operation in enumerate(operations):
provider = operation.get('provider')
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
tokens_requested = operation.get('tokens_requested', 0)
actual_provider_name = operation.get('actual_provider_name')
operation_type = operation.get('operation_type', 'unknown')
display_provider_name = actual_provider_name or provider_name
logger.error(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
logger.error(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
logger.error(f" ├─ Operation Index: {op_idx}")
logger.error(f" └─ Estimated Tokens Requested: {tokens_requested}")
# Check if this is an LLM provider
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
# Check unified AI text generation limit for LLM providers
if is_llm_provider:
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
if ai_text_gen_limit == 0:
# Fallback to provider-specific limit
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# Check token limits for this provider
# CRITICAL: Always query fresh from DB for each operation to avoid SQLAlchemy cache issues
# This ensures we get the latest values after subscription renewal, even for cumulative tracking
provider_tokens_key = f"{provider_name}_tokens"
# Try to get fresh value from DB with comprehensive error handling
base_current_tokens = 0
query_succeeded = False
try:
# Validate column name is safe (only allow known provider token columns)
valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens']
if provider_tokens_key not in valid_token_columns:
logger.error(f" └─ Invalid provider tokens key: {provider_tokens_key}")
query_succeeded = True # Treat as success with 0 value
else:
# Method 1: Try raw SQL query to completely bypass ORM cache
try:
logger.debug(f" └─ Attempting raw SQL query for {provider_tokens_key}")
sql_query = text(f"""
SELECT {provider_tokens_key}
FROM usage_summaries
WHERE user_id = :user_id
AND billing_period = :period
LIMIT 1
""")
logger.debug(f" └─ SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}")
result = self.db.execute(sql_query, {
'user_id': user_id,
'period': current_period
}).first()
if result:
base_current_tokens = result[0] if result[0] is not None else 0
logger.error(f"[Pre-flight Check] ✅ Raw SQL query returned result: {result[0]} -> {base_current_tokens}")
else:
base_current_tokens = 0
logger.error(f"[Pre-flight Check] ⚠️ Raw SQL query returned None (no rows found)")
query_succeeded = True
logger.error(f"[Pre-flight Check] ✅ Raw SQL query succeeded for {provider_tokens_key}: {base_current_tokens}")
except Exception as sql_error:
logger.error(f" └─ Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True)
query_succeeded = False # Will try ORM fallback
# Method 2: Fallback to fresh ORM query if raw SQL fails
if not query_succeeded:
try:
# Expire all cached objects and do fresh query
self.db.expire_all()
fresh_usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if fresh_usage:
# Explicitly refresh to get latest from DB
self.db.refresh(fresh_usage)
base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0
else:
base_current_tokens = 0
query_succeeded = True
logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}")
except Exception as orm_error:
logger.error(f" └─ ORM query also failed: {orm_error}", exc_info=True)
query_succeeded = False
except Exception as e:
logger.error(f" └─ Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True)
base_current_tokens = 0 # Fail safe - assume 0 if we can't query
if not query_succeeded:
logger.warning(f" └─ Both query methods failed, using 0 as fallback")
# CRITICAL LOG: Always log what we got from DB - this helps debug renewal issues
# Use ERROR level to ensure it shows even if INFO is filtered
logger.error(f"[Pre-flight Check] 🔍 Fresh DB Query for {display_provider_name}:")
logger.error(f" ├─ Column: {provider_tokens_key}")
logger.error(f" ├─ Billing Period: {current_period}")
logger.error(f" ├─ User ID: {user_id}")
logger.error(f" ├─ Method: {'Raw SQL' if query_succeeded and base_current_tokens >= 0 else 'ORM' if query_succeeded else 'Failed - using 0'}")
logger.error(f" └─ Value from DB: {base_current_tokens}")
# Add any projected tokens from previous operations in this validation run
# Note: total_llm_tokens tracks ONLY projected tokens from this run, not base DB value
projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)
# Current tokens = base from DB + projected from previous operations in this run
current_provider_tokens = base_current_tokens + projected_from_previous
# Use ERROR level to ensure visibility
logger.error(f"[Pre-flight Check] 📊 Token Calculation for {display_provider_name}:")
logger.error(f" ├─ Base from DB (fresh query): {base_current_tokens}")
logger.error(f" ├─ Projected from previous ops in this run: {projected_from_previous}")
logger.error(f" └─ Total current tokens (base + projected): {current_provider_tokens}")
# Also check the initial usage object to see if it's being used incorrectly
if usage and hasattr(usage, provider_tokens_key):
initial_usage_value = getattr(usage, provider_tokens_key, 0) or 0
logger.error(f" ⚠️ Initial usage object value: {initial_usage_value} (this should NOT be used for fresh query)")
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
if projected_tokens > token_limit:
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
error_info = {
'current_tokens': current_provider_tokens,
'base_tokens_from_db': base_current_tokens,
'projected_from_previous_ops': projected_from_previous,
'requested_tokens': tokens_requested,
'limit': token_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
# Make error message clearer: show actual DB usage vs projected
if projected_from_previous > 0:
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Base usage: {base_current_tokens}/{token_limit}, "
f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
f"This operation would add: {tokens_requested}, "
f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
)
else:
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Current: {current_provider_tokens}/{token_limit}, "
f"Requested: {tokens_requested}, "
f"Would exceed by: {projected_tokens - token_limit} tokens "
f"({usage_percentage:.1f}% of limit)"
)
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
return False, error_msg, {
'error_type': 'token_limit',
'usage_info': error_info
}
else:
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
# Update cumulative counts for next operation
total_llm_calls = projected_total_llm_calls
# Update cumulative projected tokens from this validation run
# This represents only projected tokens from previous operations in this run
# Base DB value is always queried fresh, so we only track the projection delta
old_projected = total_llm_tokens.get(provider_tokens_key, 0)
if tokens_requested > 0:
# Add this operation's tokens to cumulative projected tokens
total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
logger.error(f"[Pre-flight Check] 📝 Updated cumulative projected tokens for {display_provider_name}:")
logger.error(f" ├─ Previous projected: {projected_from_previous}")
logger.error(f" ├─ This operation requested: {tokens_requested}")
logger.error(f" ├─ New cumulative projected: {total_llm_tokens[provider_tokens_key]}")
logger.error(f" └─ Old value in dict was: {old_projected}")
else:
# No tokens requested, keep existing projected tokens (or 0 if first operation)
total_llm_tokens[provider_tokens_key] = projected_from_previous
logger.error(f"[Pre-flight Check] 📝 No tokens requested, keeping projected at: {projected_from_previous}")
# Check image generation limits
elif provider == APIProvider.STABILITY:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
'provider': 'stability',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
'error_type': 'image_limit',
'usage_info': error_info
}
total_images = projected_images
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
call_limit = limits.get(provider_calls_key, 0) or 0
if call_limit > 0:
projected_calls = current_provider_calls + 1
if projected_calls > call_limit:
error_info = {
'current_calls': current_provider_calls,
'limit': call_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
return True, None, None
except Exception as e:
error_type = type(e).__name__
error_message = str(e)
logger.error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {error_message}", exc_info=True)
logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
return False, f"Failed to validate limits: {error_type}: {error_message}", {}

View File

@@ -44,15 +44,17 @@ def validate_research_operations(
llm_provider_name = "gemini"
# Estimate tokens for each operation in research workflow
# Google Grounding call: ~2000 tokens (input + output)
# Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results)
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
# Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation
# to prevent wasteful API calls if the workflow would exceed limits.
operations_to_validate = [
{
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
'tokens_requested': 2000,
'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate
'actual_provider_name': 'gemini',
'operation_type': 'google_grounding'
},
@@ -126,6 +128,120 @@ def validate_research_operations(
)
def validate_exa_research_operations(
pricing_service: PricingService,
user_id: str,
gpt_provider: str = "google"
) -> None:
"""
Validate all operations for an Exa research workflow before making ANY API calls.
This prevents wasteful external API calls (like Exa search) if subsequent
LLM calls would fail due to token or call limits.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
gpt_provider: GPT provider from env var (defaults to "google")
Returns:
None
If validation fails, raises HTTPException with 429 status
"""
try:
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
gpt_provider_lower = gpt_provider.lower()
if gpt_provider_lower == "huggingface":
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
llm_provider_name = "huggingface"
else:
llm_provider_enum = APIProvider.GEMINI
llm_provider_name = "gemini"
# Estimate tokens for each operation in Exa research workflow
# Exa Search call: 1 Exa API call (not token-based)
# Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON)
# Content angle generator: ~1000 tokens (input: research results, output: list of angles)
# Note: These are conservative estimates for pre-flight validation
operations_to_validate = [
{
'provider': APIProvider.EXA, # Exa API call
'tokens_requested': 0,
'actual_provider_name': 'exa',
'operation_type': 'exa_neural_search'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'keyword_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'competitor_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'content_angle_generation'
}
]
logger.info(f"[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
logger.info(f" ├─ User: {user_id}")
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
operation_type = usage_info.get('operation_type', 'unknown')
logger.error(f"[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
logger.error(f" ├─ User: {user_id}")
logger.error(f" ├─ Blocked at: {operation_type}")
logger.error(f" ├─ Provider: {provider}")
logger.error(f" └─ Reason: {message}")
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
logger.info(f" ├─ User: {user_id}")
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate operations: {str(e)}",
'message': f"Failed to validate operations: {str(e)}"
}
)
def validate_image_generation_operations(
pricing_service: PricingService,
user_id: str

View File

@@ -258,6 +258,12 @@ class PricingService:
"model_name": "stable-diffusion",
"cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation"
},
{
"provider": APIProvider.EXA,
"model_name": "exa-search",
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
"description": "Exa Neural Search API"
}
]
@@ -296,6 +302,7 @@ class PricingService:
"metaphor_calls_limit": 10,
"firecrawl_calls_limit": 10,
"stability_calls_limit": 5,
"exa_calls_limit": 100,
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"],
@@ -316,10 +323,11 @@ class PricingService:
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 5,
"gemini_tokens_limit": 2000,
"openai_tokens_limit": 2000,
"anthropic_tokens_limit": 2000,
"mistral_tokens_limit": 2000,
"exa_calls_limit": 500,
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
@@ -338,6 +346,7 @@ class PricingService:
"metaphor_calls_limit": 500,
"firecrawl_calls_limit": 500,
"stability_calls_limit": 200,
"exa_calls_limit": 2000,
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
@@ -360,6 +369,7 @@ class PricingService:
"metaphor_calls_limit": 0,
"firecrawl_calls_limit": 0,
"stability_calls_limit": 0,
"exa_calls_limit": 0, # Unlimited
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
@@ -423,11 +433,14 @@ class PricingService:
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get usage limits for a user based on their subscription."""
# CRITICAL: Expire all objects first to ensure fresh data after renewal
self.db.expire_all()
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if not subscription:
# Return free tier limits
free_plan = self.db.query(SubscriptionPlan).filter(
@@ -439,7 +452,23 @@ class PricingService:
# Ensure current period before returning limits
self._ensure_subscription_current(subscription)
return self._plan_to_limits_dict(subscription.plan)
# CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship
self.db.refresh(subscription)
# Re-query plan directly to ensure fresh data (bypass relationship cache)
plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.id == subscription.plan_id
).first()
if not plan:
logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
return None
# Refresh plan to ensure fresh limits
self.db.refresh(plan)
return self._plan_to_limits_dict(plan)
def _ensure_ai_text_gen_column_detection(self) -> None:
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
@@ -508,290 +537,20 @@ class PricingService:
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.
Delegates to LimitValidator for actual validation logic.
Args:
user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
"""
try:
# Use actual_provider_name if provided, otherwise use enum value
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
display_provider_name = actual_provider_name or provider.value
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
try:
limits = self.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
try:
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
except Exception as e:
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
# STRICT: Fail closed - deny requests if subscription system fails
return False, f"Subscription check error: {str(e)}", {}
Returns:
(can_proceed, error_message, usage_info)
"""
from .limit_validation import LimitValidator
validator = LimitValidator(self)
return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider."""
@@ -827,6 +586,16 @@ class PricingService:
if not pricing:
return None
# Return pricing info as dict
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'description': pricing.description
}
def check_comprehensive_limits(
self,
user_id: str,
@@ -835,6 +604,7 @@ class PricingService:
"""
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
Delegates to LimitValidator for actual validation logic.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call.
@@ -850,202 +620,9 @@ class PricingService:
(can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded
"""
try:
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
return False, f"Failed to create usage summary: {str(create_error)}", {}
# Get user limits
limits_dict = self.get_user_limits(user_id)
if not limits_dict:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
limits_dict = self._plan_to_limits_dict(free_plan)
else:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
# Track cumulative usage across all operations
total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
total_llm_tokens = {}
total_images = usage.stability_calls or 0
# Log current usage summary
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
logger.info(f" └─ Image Calls: {total_images}")
# Validate each operation
for op_idx, operation in enumerate(operations):
provider = operation.get('provider')
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
tokens_requested = operation.get('tokens_requested', 0)
actual_provider_name = operation.get('actual_provider_name')
operation_type = operation.get('operation_type', 'unknown')
display_provider_name = actual_provider_name or provider_name
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
# Check if this is an LLM provider
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
# Check unified AI text generation limit for LLM providers
if is_llm_provider:
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
if ai_text_gen_limit == 0:
# Fallback to provider-specific limit
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# Check token limits for this provider
# Use cumulative projected tokens from previous operations, or current from DB if first operation
provider_tokens_key = f"{provider_name}_tokens"
if provider_tokens_key in total_llm_tokens:
# Use cumulative projected tokens from previous operations
current_provider_tokens = total_llm_tokens[provider_tokens_key]
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
else:
# First operation for this provider - get current from database
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
total_llm_tokens[provider_tokens_key] = current_provider_tokens
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
if projected_tokens > token_limit:
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
error_info = {
'current_tokens': current_provider_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Current: {current_provider_tokens}/{token_limit}, "
f"Requested: {tokens_requested}, "
f"Would exceed by: {projected_tokens - token_limit} tokens "
f"({usage_percentage:.1f}% of limit)"
)
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
return False, error_msg, {
'error_type': 'token_limit',
'usage_info': error_info
}
else:
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
# Update cumulative counts for next operation
total_llm_calls = projected_total_llm_calls
total_llm_tokens[provider_tokens_key] += tokens_requested
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
# Check image generation limits
elif provider == APIProvider.STABILITY:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
'provider': 'stability',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
'error_type': 'image_limit',
'usage_info': error_info
}
total_images = projected_images
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
call_limit = limits.get(provider_calls_key, 0) or 0
if call_limit > 0:
projected_calls = current_provider_calls + 1
if projected_calls > call_limit:
error_info = {
'current_calls': current_provider_calls,
'limit': call_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
return True, None, None
except Exception as e:
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
return False, f"Failed to validate limits: {str(e)}", {}
from .limit_validation import LimitValidator
validator = LimitValidator(self)
return validator.check_comprehensive_limits(user_id, operations)
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
"""Get pricing configuration for a specific provider and model."""

View File

@@ -0,0 +1,39 @@
from typing import Set
from sqlalchemy.orm import Session
_checked_subscription_plan_columns: bool = False
def ensure_subscription_plan_columns(db: Session) -> None:
"""Ensure required columns exist on subscription_plans for runtime safety.
This is a defensive guard for environments where migrations have not yet
been applied. If columns are missing (e.g., exa_calls_limit), we add them
with a safe default so ORM queries do not fail.
"""
global _checked_subscription_plan_columns
if _checked_subscription_plan_columns:
return
try:
# Discover existing columns
result = db.execute("PRAGMA table_info(subscription_plans)")
cols: Set[str] = {row[1] for row in result}
# Columns we may reference in models but might be missing in older DBs
required_columns = {
"exa_calls_limit": "INTEGER DEFAULT 0",
}
for col_name, ddl in required_columns.items():
if col_name not in cols:
db.execute(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}")
db.commit()
except Exception:
# Do not block app if pragma/alter fails; let normal errors surface
db.rollback()
finally:
_checked_subscription_plan_columns = True

View File

@@ -7,6 +7,7 @@ Handles authentication, permission checking, and blog publishing to Wix websites
import os
import json
import requests
import uuid
from typing import Dict, Any, Optional, List
from loguru import logger
from datetime import datetime, timedelta
@@ -19,6 +20,9 @@ from services.integrations.wix.media import WixMediaService
from services.integrations.wix.utils import extract_meta_from_token, normalize_token_string, extract_member_id_from_access_token as utils_extract_member
from services.integrations.wix.content import convert_content_to_ricos as ricos_builder
from services.integrations.wix.auth import WixAuthService
from services.integrations.wix.seo import build_seo_data
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
from services.integrations.wix.blog_publisher import create_blog_post as publish_blog_post
class WixService:
"""Service for interacting with Wix APIs"""
@@ -237,13 +241,35 @@ class WixService:
logger.error(f"Failed to import image to Wix: {e}")
raise
def convert_content_to_ricos(self, content: str, images: List[str] = None) -> Dict[str, Any]:
def convert_content_to_ricos(self, content: str, images: List[str] = None,
use_wix_api: bool = False, access_token: str = None) -> Dict[str, Any]:
"""
Convert markdown content to Ricos JSON format.
Args:
content: Markdown content to convert
images: Optional list of image URLs
use_wix_api: If True, use Wix's official Ricos Documents API (requires access_token)
access_token: Wix access token (required if use_wix_api=True)
Returns:
Ricos JSON document
"""
if use_wix_api and access_token:
try:
return convert_via_wix_api(content, access_token, self.base_url)
except Exception as e:
logger.warning(f"Failed to convert via Wix API, falling back to custom parser: {e}")
# Fall back to custom parser
# Use custom parser (current implementation)
return ricos_builder(content, images)
def create_blog_post(self, access_token: str, title: str, content: str,
cover_image_url: str = None, category_ids: List[str] = None,
tag_ids: List[str] = None, publish: bool = True,
member_id: str = None) -> Dict[str, Any]:
member_id: str = None, seo_metadata: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Create and optionally publish a blog post on Wix
@@ -256,101 +282,33 @@ class WixService:
tag_ids: Optional list of tag IDs
publish: Whether to publish immediately or save as draft
member_id: Required for third-party apps - the member ID of the post author
seo_metadata: Optional SEO metadata dict with fields like:
- seo_title: SEO optimized title
- meta_description: Meta description
- focus_keyword: Main keyword
- blog_tags: List of tag strings (for keywords)
- open_graph: Open Graph data
- canonical_url: Canonical URL
Returns:
Created blog post information
"""
if not member_id:
raise ValueError("memberId is required for third-party apps creating blog posts")
headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
# Build valid Ricos rich content (minimum: one paragraph with text)
ricos_content = self.convert_content_to_ricos(content or "This is a post from ALwrity.", None)
# Minimal payload per Wix docs: title, memberId, and richContent
blog_data = {
'draftPost': {
'title': title,
'memberId': member_id, # Required for third-party apps
'richContent': ricos_content,
'excerpt': (content or '').strip()[:200]
},
'publish': publish,
'fieldsets': ['URL'] # Simplified fieldsets
}
# Add cover image if provided
if cover_image_url:
try:
media_id = self.import_image_to_wix(access_token, cover_image_url, f'Cover: {title}')
blog_data['draftPost']['media'] = {
'wixMedia': {
'image': {'id': media_id}
},
'displayed': True,
'custom': True
}
except Exception as e:
logger.warning(f"Failed to import cover image: {e}")
# Add categories if provided
if category_ids:
blog_data['draftPost']['categoryIds'] = category_ids
# Add tags if provided
if tag_ids:
blog_data['draftPost']['tagIds'] = tag_ids
try:
# Check what permissions we have in the token
logger.info("DEBUG: Checking token permissions...")
try:
import jwt
# Extract token string manually since _normalize_access_token doesn't exist
token_str = str(access_token)
if token_str and token_str.startswith('OauthNG.JWS.'):
jwt_part = token_str[12:]
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
logger.info(f"DEBUG: Full token payload: {payload}")
# Check for permissions in various possible locations
data_payload = payload.get('data', {})
if isinstance(data_payload, str):
try:
data_payload = json.loads(data_payload)
except:
pass
instance_data = data_payload.get('instance', {})
permissions = instance_data.get('permissions', '')
scopes = instance_data.get('scopes', [])
meta_site_id = instance_data.get('metaSiteId')
if isinstance(meta_site_id, str) and meta_site_id:
headers['wix-site-id'] = meta_site_id
logger.info(f"DEBUG: Added wix-site-id header: {meta_site_id}")
logger.info(f"DEBUG: Token permissions: {permissions}")
logger.info(f"DEBUG: Token scopes: {scopes}")
else:
logger.info("DEBUG: Could not decode token for permission check")
except Exception as perm_e:
logger.warning(f"DEBUG: Failed to check permissions: {perm_e}")
logger.info(f"DEBUG: Sending simplified blog data: {json.dumps(blog_data, indent=2)}")
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
result = self.blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
logger.info(f"DEBUG: Create draft result: {result}")
return result
except requests.RequestException as e:
logger.error(f"Failed to create blog post: {e}")
if hasattr(e, 'response') and e.response is not None:
logger.error(f"Response body: {e.response.text}")
raise
return publish_blog_post(
blog_service=self.blog_service,
access_token=access_token,
title=title,
content=content,
member_id=member_id,
cover_image_url=cover_image_url,
category_ids=category_ids,
tag_ids=tag_ids,
publish=publish,
seo_metadata=seo_metadata,
import_image_func=self.import_image_to_wix,
lookup_categories_func=self.lookup_or_create_categories,
lookup_tags_func=self.lookup_or_create_tags,
base_url=self.base_url
)
def get_blog_categories(self, access_token: str) -> List[Dict[str, Any]]:
"""
@@ -383,6 +341,138 @@ class WixService:
except requests.RequestException as e:
logger.error(f"Failed to get blog tags: {e}")
raise
def lookup_or_create_categories(self, access_token: str, category_names: List[str],
extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
"""
Lookup existing categories by name or create new ones, return their IDs.
Args:
access_token: Valid access token
category_names: List of category name strings
extra_headers: Optional extra headers (e.g., wix-site-id)
Returns:
List of category UUIDs
"""
if not category_names:
return []
try:
# Get existing categories
existing_categories = self.blog_service.list_categories(access_token, extra_headers)
# Create name -> ID mapping (case-insensitive)
category_map = {}
for cat in existing_categories:
cat_label = cat.get('label', '').strip()
cat_id = cat.get('id')
if cat_label and cat_id:
category_map[cat_label.lower()] = cat_id
category_ids = []
for category_name in category_names:
category_name_clean = str(category_name).strip()
if not category_name_clean:
continue
# Lookup existing category (case-insensitive)
category_id = category_map.get(category_name_clean.lower())
if not category_id:
# Create new category
try:
logger.info(f"Creating new category: {category_name_clean}")
result = self.blog_service.create_category(
access_token,
label=category_name_clean,
extra_headers=extra_headers
)
new_category = result.get('category', {})
category_id = new_category.get('id')
if category_id:
category_ids.append(category_id)
# Update map to avoid duplicate creates
category_map[category_name_clean.lower()] = category_id
logger.info(f"Created category '{category_name_clean}' with ID: {category_id}")
except Exception as create_error:
logger.warning(f"Failed to create category '{category_name_clean}': {create_error}")
# Continue with other categories
else:
category_ids.append(category_id)
logger.info(f"Found existing category '{category_name_clean}' with ID: {category_id}")
return category_ids
except requests.RequestException as e:
logger.error(f"Failed to lookup/create categories: {e}")
return []
def lookup_or_create_tags(self, access_token: str, tag_names: List[str],
extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
"""
Lookup existing tags by name or create new ones, return their IDs.
Args:
access_token: Valid access token
tag_names: List of tag name strings
extra_headers: Optional extra headers (e.g., wix-site-id)
Returns:
List of tag UUIDs
"""
if not tag_names:
return []
try:
# Get existing tags
existing_tags = self.blog_service.list_tags(access_token, extra_headers)
# Create name -> ID mapping (case-insensitive)
tag_map = {}
for tag in existing_tags:
tag_label = tag.get('label', '').strip()
tag_id = tag.get('id')
if tag_label and tag_id:
tag_map[tag_label.lower()] = tag_id
tag_ids = []
for tag_name in tag_names:
tag_name_clean = str(tag_name).strip()
if not tag_name_clean:
continue
# Lookup existing tag (case-insensitive)
tag_id = tag_map.get(tag_name_clean.lower())
if not tag_id:
# Create new tag
try:
logger.info(f"Creating new tag: {tag_name_clean}")
result = self.blog_service.create_tag(
access_token,
label=tag_name_clean,
extra_headers=extra_headers
)
new_tag = result.get('tag', {})
tag_id = new_tag.get('id')
if tag_id:
tag_ids.append(tag_id)
# Update map to avoid duplicate creates
tag_map[tag_name_clean.lower()] = tag_id
logger.info(f"Created tag '{tag_name_clean}' with ID: {tag_id}")
except Exception as create_error:
logger.warning(f"Failed to create tag '{tag_name_clean}': {create_error}")
# Continue with other tags
else:
tag_ids.append(tag_id)
logger.info(f"Found existing tag '{tag_name_clean}' with ID: {tag_id}")
return tag_ids
except requests.RequestException as e:
logger.error(f"Failed to lookup/create tags: {e}")
return []
def publish_draft_post(self, access_token: str, draft_post_id: str) -> Dict[str, Any]:
"""