Research component integration, CopilotKit implementation, SEO CopilotKit implementation, Wix SEO metadata complete, Wix SEO metadata review

This commit is contained in:
ajaysi
2025-11-03 16:01:44 +05:30
parent de4328175d
commit e69107b07c
94 changed files with 9748 additions and 1565 deletions

View File

@@ -185,10 +185,20 @@ async def get_research_status(task_id: str) -> Dict[str, Any]:
# Outline Endpoints # Outline Endpoints
@router.post("/outline/start") @router.post("/outline/start")
async def start_outline_generation(request: BlogOutlineRequest) -> Dict[str, Any]: async def start_outline_generation(
request: BlogOutlineRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Start an outline generation operation and return a task ID for polling.""" """Start an outline generation operation and return a task ID for polling."""
try: try:
task_id = task_manager.start_outline_task(request) # Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
task_id = task_manager.start_outline_task(request, user_id)
return {"task_id": task_id, "status": "started"} return {"task_id": task_id, "status": "started"}
except Exception as e: except Exception as e:
logger.error(f"Failed to start outline generation: {e}") logger.error(f"Failed to start outline generation: {e}")
@@ -272,12 +282,22 @@ async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
@router.post("/content/start") @router.post("/content/start")
async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]: async def start_content_generation(
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Start full content generation and return a task id for polling. """Start full content generation and return a task id for polling.
Accepts a payload compatible with MediumBlogGenerateRequest to minimize duplication. Accepts a payload compatible with MediumBlogGenerateRequest to minimize duplication.
""" """
try: try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
# Map dict to MediumBlogGenerateRequest for reuse # Map dict to MediumBlogGenerateRequest for reuse
from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo
sections = [MediumSectionOutline(**s) for s in request.get("sections", [])] sections = [MediumSectionOutline(**s) for s in request.get("sections", [])]
@@ -293,7 +313,7 @@ async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]:
globalTargetWords=request.get("globalTargetWords", 1000), globalTargetWords=request.get("globalTargetWords", 1000),
researchKeywords=request.get("researchKeywords") or request.get("keywords"), researchKeywords=request.get("researchKeywords") or request.get("keywords"),
) )
task_id = task_manager.start_content_generation_task(req) task_id = task_manager.start_content_generation_task(req, user_id)
return {"task_id": task_id, "status": "started"} return {"task_id": task_id, "status": "started"}
except Exception as e: except Exception as e:
logger.error(f"Failed to start content generation: {e}") logger.error(f"Failed to start content generation: {e}")
@@ -307,6 +327,51 @@ async def content_generation_status(task_id: str) -> Dict[str, Any]:
status = await task_manager.get_task_status(task_id) status = await task_manager.get_task_status(task_id)
if status is None: if status is None:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
error_status = status.get('error_status', 429)
if not isinstance(error_data, dict):
logger.warning(f"Content generation task {task_id} error_data not dict: {error_data}")
error_data = {'error': str(error_data)}
# Determine provider and usage info
stored_error_message = status.get('error', error_data.get('error'))
provider = error_data.get('provider', 'unknown')
usage_info = error_data.get('usage_info')
if not usage_info:
usage_info = {
'provider': provider,
'message': stored_error_message,
'error_type': error_data.get('error_type', 'unknown')
}
# Include any known fields from error_data
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
if key in error_data:
usage_info[key] = error_data[key]
# Build error message for detail
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
# Log the subscription error with all context
logger.warning(f"Content generation task {task_id} failed with subscription error {error_status}: {error_msg}")
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
from fastapi.responses import JSONResponse
return JSONResponse(
status_code=error_status,
content={
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
'message': error_msg,
'provider': provider,
'usage_info': usage_info
}
)
return status return status
except HTTPException: except HTTPException:
raise raise
@@ -499,14 +564,24 @@ async def get_outline_cache_entries(limit: int = 20):
# --------------------------- # ---------------------------
@router.post("/generate/medium/start") @router.post("/generate/medium/start")
async def start_medium_generation(request: MediumBlogGenerateRequest): async def start_medium_generation(
request: MediumBlogGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Start medium-length blog generation (≤1000 words) and return a task id.""" """Start medium-length blog generation (≤1000 words) and return a task id."""
try: try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id'))
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in authentication token")
# Simple server-side guard # Simple server-side guard
if (request.globalTargetWords or 1000) > 1000: if (request.globalTargetWords or 1000) > 1000:
raise HTTPException(status_code=400, detail="Global target words exceed 1000; use per-section generation") raise HTTPException(status_code=400, detail="Global target words exceed 1000; use per-section generation")
task_id = task_manager.start_medium_generation_task(request) task_id = task_manager.start_medium_generation_task(request, user_id)
return {"task_id": task_id, "status": "started"} return {"task_id": task_id, "status": "started"}
except HTTPException: except HTTPException:
raise raise
@@ -522,6 +597,51 @@ async def medium_generation_status(task_id: str):
status = await task_manager.get_task_status(task_id) status = await task_manager.get_task_status(task_id)
if status is None: if status is None:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
error_status = status.get('error_status', 429)
if not isinstance(error_data, dict):
logger.warning(f"Medium generation task {task_id} error_data not dict: {error_data}")
error_data = {'error': str(error_data)}
# Determine provider and usage info
stored_error_message = status.get('error', error_data.get('error'))
provider = error_data.get('provider', 'unknown')
usage_info = error_data.get('usage_info')
if not usage_info:
usage_info = {
'provider': provider,
'message': stored_error_message,
'error_type': error_data.get('error_type', 'unknown')
}
# Include any known fields from error_data
for key in ['current_tokens', 'requested_tokens', 'limit', 'current_calls']:
if key in error_data:
usage_info[key] = error_data[key]
# Build error message for detail
error_msg = error_data.get('message', stored_error_message or 'Subscription limit exceeded')
# Log the subscription error with all context
logger.warning(f"Medium generation task {task_id} failed with subscription error {error_status}: {error_msg}")
logger.warning(f" Provider: {provider}, Usage Info: {usage_info}")
# Use JSONResponse to ensure detail is returned as-is, not wrapped in an array
from fastapi.responses import JSONResponse
return JSONResponse(
status_code=error_status,
content={
'error': error_data.get('error', stored_error_message or 'Subscription limit exceeded'),
'message': error_msg,
'provider': provider,
'usage_info': usage_info
}
)
return status return status
except HTTPException: except HTTPException:
raise raise

View File

@@ -5,7 +5,7 @@ Provides API endpoint for analyzing blog content SEO with parallel processing
and CopilotKit integration for real-time progress updates. and CopilotKit integration for real-time progress updates.
""" """
from fastapi import APIRouter, HTTPException, BackgroundTasks from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel from pydantic import BaseModel
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from loguru import logger from loguru import logger
@@ -13,6 +13,7 @@ from datetime import datetime
from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer from services.blog_writer.seo.blog_content_seo_analyzer import BlogContentSEOAnalyzer
from services.blog_writer.core.blog_writer_service import BlogWriterService from services.blog_writer.core.blog_writer_service import BlogWriterService
from middleware.auth_middleware import get_current_user
router = APIRouter(prefix="/api/blog-writer/seo", tags=["Blog SEO Analysis"]) router = APIRouter(prefix="/api/blog-writer/seo", tags=["Blog SEO Analysis"])
@@ -56,7 +57,10 @@ blog_writer_service = BlogWriterService()
@router.post("/analyze", response_model=SEOAnalysisResponse) @router.post("/analyze", response_model=SEOAnalysisResponse)
async def analyze_blog_seo(request: SEOAnalysisRequest): async def analyze_blog_seo(
request: SEOAnalysisRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
""" """
Analyze blog content for SEO optimization Analyze blog content for SEO optimization
@@ -69,6 +73,7 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
Args: Args:
request: SEOAnalysisRequest containing blog content and research data request: SEOAnalysisRequest containing blog content and research data
current_user: Authenticated user from middleware
Returns: Returns:
SEOAnalysisResponse with comprehensive analysis results SEOAnalysisResponse with comprehensive analysis results
@@ -76,6 +81,14 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
try: try:
logger.info(f"Starting SEO analysis for blog content") logger.info(f"Starting SEO analysis for blog content")
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validate request # Validate request
if not request.blog_content or not request.blog_content.strip(): if not request.blog_content or not request.blog_content.strip():
raise HTTPException(status_code=400, detail="Blog content is required") raise HTTPException(status_code=400, detail="Blog content is required")
@@ -91,7 +104,8 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
analysis_results = await seo_analyzer.analyze_blog_content( analysis_results = await seo_analyzer.analyze_blog_content(
blog_content=request.blog_content, blog_content=request.blog_content,
research_data=request.research_data, research_data=request.research_data,
blog_title=request.blog_title blog_title=request.blog_title,
user_id=user_id
) )
# Check for errors # Check for errors
@@ -131,7 +145,10 @@ async def analyze_blog_seo(request: SEOAnalysisRequest):
@router.post("/analyze-with-progress") @router.post("/analyze-with-progress")
async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest): async def analyze_blog_seo_with_progress(
request: SEOAnalysisRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
""" """
Analyze blog content for SEO with real-time progress updates Analyze blog content for SEO with real-time progress updates
@@ -140,6 +157,7 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
Args: Args:
request: SEOAnalysisRequest containing blog content and research data request: SEOAnalysisRequest containing blog content and research data
current_user: Authenticated user from middleware
Returns: Returns:
Generator yielding progress updates and final results Generator yielding progress updates and final results
@@ -147,6 +165,14 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
try: try:
logger.info(f"Starting SEO analysis with progress for blog content") logger.info(f"Starting SEO analysis with progress for blog content")
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validate request # Validate request
if not request.blog_content or not request.blog_content.strip(): if not request.blog_content or not request.blog_content.strip():
raise HTTPException(status_code=400, detail="Blog content is required") raise HTTPException(status_code=400, detail="Blog content is required")
@@ -209,7 +235,9 @@ async def analyze_blog_seo_with_progress(request: SEOAnalysisRequest):
# Perform actual analysis # Perform actual analysis
analysis_results = await seo_analyzer.analyze_blog_content( analysis_results = await seo_analyzer.analyze_blog_content(
blog_content=request.blog_content, blog_content=request.blog_content,
research_data=request.research_data research_data=request.research_data,
blog_title=request.blog_title,
user_id=user_id
) )
# Final result # Final result

View File

@@ -88,8 +88,12 @@ class TaskManager:
response["error"] = task["error"] response["error"] = task["error"]
if "error_status" in task: if "error_status" in task:
response["error_status"] = task["error_status"] response["error_status"] = task["error_status"]
logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_status={task['error_status']} in response")
if "error_data" in task: if "error_data" in task:
response["error_data"] = task["error_data"] response["error_data"] = task["error_data"]
logger.info(f"[TaskManager] get_task_status for {task_id}: Including error_data with keys: {list(task['error_data'].keys()) if isinstance(task['error_data'], dict) else 'not-dict'}")
else:
logger.warning(f"[TaskManager] get_task_status for {task_id}: Task failed but no error_data found. Task keys: {list(task.keys())}")
return response return response
@@ -127,29 +131,33 @@ class TaskManager:
asyncio.create_task(self._run_research_task(task_id, request, user_id)) asyncio.create_task(self._run_research_task(task_id, request, user_id))
return task_id return task_id
def start_outline_task(self, request: BlogOutlineRequest) -> str: def start_outline_task(self, request: BlogOutlineRequest, user_id: str) -> str:
"""Start an outline generation operation and return a task ID.""" """Start an outline generation operation and return a task ID."""
task_id = self.create_task("outline") task_id = self.create_task("outline")
# Start the outline generation operation in the background # Start the outline generation operation in the background
asyncio.create_task(self._run_outline_generation_task(task_id, request)) asyncio.create_task(self._run_outline_generation_task(task_id, request, user_id))
return task_id return task_id
def start_medium_generation_task(self, request: MediumBlogGenerateRequest) -> str: def start_medium_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str:
"""Start a medium (≤1000 words) full-blog generation task.""" """Start a medium (≤1000 words) full-blog generation task."""
task_id = self.create_task("medium_generation") task_id = self.create_task("medium_generation")
asyncio.create_task(self._run_medium_generation_task(task_id, request)) asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id))
return task_id return task_id
def start_content_generation_task(self, request: MediumBlogGenerateRequest) -> str: def start_content_generation_task(self, request: MediumBlogGenerateRequest, user_id: str) -> str:
"""Start content generation (full blog via sections) with provider parity. """Start content generation (full blog via sections) with provider parity.
Internally reuses medium generator pipeline for now but tracked under Internally reuses medium generator pipeline for now but tracked under
distinct task_type 'content_generation' and same polling contract. distinct task_type 'content_generation' and same polling contract.
Args:
request: Content generation request
user_id: User ID (required for subscription checks and usage tracking)
""" """
task_id = self.create_task("content_generation") task_id = self.create_task("content_generation")
asyncio.create_task(self._run_medium_generation_task(task_id, request)) asyncio.create_task(self._run_medium_generation_task(task_id, request, user_id))
return task_id return task_id
async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str): async def _run_research_task(self, task_id: str, request: BlogResearchRequest, user_id: str):
@@ -205,7 +213,7 @@ class TaskManager:
self.task_storage[task_id]["status"] = "failed" self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = "Research completed with unknown status" self.task_storage[task_id]["error"] = "Research completed with unknown status"
async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest): async def _run_outline_generation_task(self, task_id: str, request: BlogOutlineRequest, user_id: str):
"""Background task to run outline generation and update status with progress messages.""" """Background task to run outline generation and update status with progress messages."""
try: try:
# Update status to running # Update status to running
@@ -215,21 +223,31 @@ class TaskManager:
# Send initial progress message # Send initial progress message
await self.update_progress(task_id, "🧩 Starting outline generation...") await self.update_progress(task_id, "🧩 Starting outline generation...")
# Run the actual outline generation with progress updates # Run the actual outline generation with progress updates (pass user_id for subscription checks)
result = await self.service.generate_outline_with_progress(request, task_id) result = await self.service.generate_outline_with_progress(request, task_id, user_id)
# Update status to completed # Update status to completed
await self.update_progress(task_id, f"✅ Outline generated successfully! Created {len(result.outline)} sections with {len(result.title_options)} title options.") await self.update_progress(task_id, f"✅ Outline generated successfully! Created {len(result.outline)} sections with {len(result.title_options)} title options.")
self.task_storage[task_id]["status"] = "completed" self.task_storage[task_id]["status"] = "completed"
self.task_storage[task_id]["result"] = result.dict() self.task_storage[task_id]["result"] = result.dict()
except HTTPException as http_error:
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
error_detail = http_error.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = http_error.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
except Exception as e: except Exception as e:
await self.update_progress(task_id, f"❌ Outline generation failed: {str(e)}") await self.update_progress(task_id, f"❌ Outline generation failed: {str(e)}")
# Update status to failed # Update status to failed
self.task_storage[task_id]["status"] = "failed" self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e) self.task_storage[task_id]["error"] = str(e)
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest): async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
"""Background task to generate a medium blog using a single structured JSON call.""" """Background task to generate a medium blog using a single structured JSON call."""
try: try:
self.task_storage[task_id]["status"] = "running" self.task_storage[task_id]["status"] = "running"
@@ -245,6 +263,7 @@ class TaskManager:
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress( result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
request, request,
task_id, task_id,
user_id
) )
if not result or not getattr(result, "sections", None): if not result or not getattr(result, "sections", None):
@@ -263,10 +282,38 @@ class TaskManager:
self.task_storage[task_id]["result"] = result.dict() self.task_storage[task_id]["result"] = result.dict()
await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.") await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.")
except Exception as e: except HTTPException as http_error:
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}") # Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
logger.info(f"[TaskManager] Caught HTTPException in medium generation task {task_id}: status={http_error.status_code}, detail={http_error.detail}")
error_detail = http_error.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed" self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e) self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = http_error.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
logger.info(f"[TaskManager] Stored error_status={http_error.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}")
except Exception as e:
# Check if this is an HTTPException that got wrapped (can happen in async tasks)
# HTTPException has status_code and detail attributes
logger.info(f"[TaskManager] Caught Exception in medium generation task {task_id}: type={type(e).__name__}, has_status_code={hasattr(e, 'status_code')}, has_detail={hasattr(e, 'detail')}")
if hasattr(e, 'status_code') and hasattr(e, 'detail'):
# This is an HTTPException that was caught as generic Exception
logger.info(f"[TaskManager] Detected HTTPException in Exception handler: status={e.status_code}, detail={e.detail}")
error_detail = e.detail
error_message = error_detail.get('message', str(error_detail)) if isinstance(error_detail, dict) else str(error_detail)
await self.update_progress(task_id, f"{error_message}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = error_message
# Store HTTP error details for frontend modal
self.task_storage[task_id]["error_status"] = e.status_code
self.task_storage[task_id]["error_data"] = error_detail if isinstance(error_detail, dict) else {"error": str(error_detail)}
logger.info(f"[TaskManager] Stored error_status={e.status_code} and error_data keys: {list(error_detail.keys()) if isinstance(error_detail, dict) else 'not-dict'}")
else:
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e)
# Global task manager instance # Global task manager instance

View File

@@ -12,6 +12,7 @@ from functools import lru_cache
from services.database import get_db from services.database import get_db
from services.subscription import UsageTrackingService, PricingService from services.subscription import UsageTrackingService, PricingService
from services.subscription.schema_utils import ensure_subscription_plan_columns
from middleware.auth_middleware import get_current_user from middleware.auth_middleware import get_current_user
from models.subscription_models import ( from models.subscription_models import (
APIProvider, SubscriptionPlan, UserSubscription, UsageSummary, APIProvider, SubscriptionPlan, UserSubscription, UsageSummary,
@@ -79,6 +80,8 @@ async def get_subscription_plans(
"""Get all available subscription plans.""" """Get all available subscription plans."""
try: try:
# Ensure required columns exist (handles environments without migrations applied yet)
ensure_subscription_plan_columns(db)
plans = db.query(SubscriptionPlan).filter( plans = db.query(SubscriptionPlan).filter(
SubscriptionPlan.is_active == True SubscriptionPlan.is_active == True
).order_by(SubscriptionPlan.price_monthly).all() ).order_by(SubscriptionPlan.price_monthly).all()
@@ -137,6 +140,7 @@ async def get_user_subscription(
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
try: try:
ensure_subscription_plan_columns(db)
subscription = db.query(UserSubscription).filter( subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id, UserSubscription.user_id == user_id,
UserSubscription.is_active == True UserSubscription.is_active == True
@@ -234,6 +238,7 @@ async def get_subscription_status(
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
try: try:
ensure_subscription_plan_columns(db)
subscription = db.query(UserSubscription).filter( subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id, UserSubscription.user_id == user_id,
UserSubscription.is_active == True UserSubscription.is_active == True
@@ -346,6 +351,7 @@ async def subscribe_to_plan(
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
try: try:
ensure_subscription_plan_columns(db)
plan_id = subscription_data.get('plan_id') plan_id = subscription_data.get('plan_id')
billing_cycle = subscription_data.get('billing_cycle', 'monthly') billing_cycle = subscription_data.get('billing_cycle', 'monthly')
@@ -427,11 +433,16 @@ async def subscribe_to_plan(
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)") logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
# Clear subscription limits cache to force refresh on next check # Clear subscription limits cache to force refresh on next check
# IMPORTANT: Do this BEFORE resetting usage to ensure cache is cleared first
try: try:
from services.subscription import PricingService from services.subscription import PricingService
# Clear cache for this specific user (class-level cache shared across all instances) # Clear cache for this specific user (class-level cache shared across all instances)
cleared_count = PricingService.clear_user_cache(user_id) cleared_count = PricingService.clear_user_cache(user_id)
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}") logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
# Also expire all SQLAlchemy objects to force fresh reads
db.expire_all()
logger.info(f" 🔄 Expired all SQLAlchemy objects to force fresh reads")
except Exception as cache_err: except Exception as cache_err:
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}") logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
@@ -441,12 +452,22 @@ async def subscribe_to_plan(
usage_service = UsageTrackingService(db) usage_service = UsageTrackingService(db)
reset_result = await usage_service.reset_current_billing_period(user_id) reset_result = await usage_service.reset_current_billing_period(user_id)
# Re-query usage summary from DB after reset to get fresh data # Force commit to ensure reset is persisted
db.commit()
# Expire all SQLAlchemy objects to force fresh reads
db.expire_all()
# Re-query usage summary from DB after reset to get fresh data (fresh query)
usage_after = db.query(UsageSummary).filter( usage_after = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id, UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period UsageSummary.billing_period == current_period
).first() ).first()
# Refresh the usage object if found to ensure we have latest data
if usage_after:
db.refresh(usage_after)
if reset_result.get('reset'): if reset_result.get('reset'):
logger.info(f" ✅ Usage counters RESET successfully") logger.info(f" ✅ Usage counters RESET successfully")
if usage_after: if usage_after:
@@ -635,6 +656,7 @@ async def get_dashboard_data(
"""Get comprehensive dashboard data for usage monitoring.""" """Get comprehensive dashboard data for usage monitoring."""
try: try:
ensure_subscription_plan_columns(db)
# Serve from short TTL cache to avoid hammering DB on bursts # Serve from short TTL cache to avoid hammering DB on bursts
import time import time
now = time.time() now = time.time()

View File

@@ -535,15 +535,33 @@ async def test_publish_real(payload: Dict[str, Any]) -> Dict[str, Any]:
if not member_id: if not member_id:
raise HTTPException(status_code=400, detail="Unable to resolve member_id from token") raise HTTPException(status_code=400, detail="Unable to resolve member_id from token")
# Extract SEO metadata if provided
seo_metadata = payload.get("seo_metadata")
# Extract category/tag IDs or names
# Can be either:
# - IDs: List of UUID strings
# - Names: List of name strings (will be looked up/created)
category_ids = payload.get("category_ids") or payload.get("category_names")
tag_ids = payload.get("tag_ids") or payload.get("tag_names")
# If SEO metadata has categories/tags but they weren't explicitly provided, use them
if seo_metadata:
if not category_ids and seo_metadata.get("blog_categories"):
category_ids = seo_metadata.get("blog_categories")
if not tag_ids and seo_metadata.get("blog_tags"):
tag_ids = seo_metadata.get("blog_tags")
result = wix_service.create_blog_post( result = wix_service.create_blog_post(
access_token=access_token, access_token=access_token,
title=payload.get("title") or "Untitled", title=payload.get("title") or "Untitled",
content=payload.get("content") or "", content=payload.get("content") or "",
cover_image_url=payload.get("cover_image_url"), cover_image_url=payload.get("cover_image_url"),
category_ids=payload.get("category_ids") or None, category_ids=category_ids,
tag_ids=payload.get("tag_ids") or None, tag_ids=tag_ids,
publish=bool(payload.get("publish", True)), publish=bool(payload.get("publish", True)),
member_id=member_id, member_id=member_id,
seo_metadata=seo_metadata,
) )
return { return {

View File

@@ -11,6 +11,7 @@ import asyncio
from datetime import datetime from datetime import datetime
from services.subscription import monitoring_middleware from services.subscription import monitoring_middleware
# Import modular utilities # Import modular utilities
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager, OnboardingManager from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager, OnboardingManager

View File

@@ -0,0 +1,17 @@
-- Migration: add Exa usage tracking.
-- Adds a per-plan call limit column to subscription_plans, adds call-count and
-- cost columns to usage_summaries, then sets limits on existing plan rows by tier.
-- NOTE(review): the ALTER statements have no existence guard, so this script is
-- presumably meant to be run exactly once per database -- confirm.

-- Add EXA to subscription plans
ALTER TABLE subscription_plans
ADD COLUMN exa_calls_limit INT DEFAULT 0;
-- Add EXA to usage summaries
-- (one ALTER per column: number of Exa calls, and their accumulated cost)
ALTER TABLE usage_summaries
ADD COLUMN exa_calls INT DEFAULT 0;
ALTER TABLE usage_summaries
ADD COLUMN exa_cost FLOAT DEFAULT 0.0;
-- Update default limits for existing plans
-- NOTE(review): assumes tier is stored as the lowercase strings used below -- TODO confirm
-- against how the tier column is actually persisted.
UPDATE subscription_plans SET exa_calls_limit = 100 WHERE tier = 'free';
UPDATE subscription_plans SET exa_calls_limit = 500 WHERE tier = 'basic';
UPDATE subscription_plans SET exa_calls_limit = 2000 WHERE tier = 'pro';
-- NOTE(review): enterprise gets 0, the same as the column default; presumably 0 is
-- treated as "no limit" by the limit-checking code -- confirm, otherwise enterprise is blocked.
UPDATE subscription_plans SET exa_calls_limit = 0 WHERE tier = 'enterprise';

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from enum import Enum
class PersonaInfo(BaseModel): class PersonaInfo(BaseModel):
@@ -50,6 +51,51 @@ class GroundingMetadata(BaseModel):
web_search_queries: List[str] = [] web_search_queries: List[str] = []
class ResearchMode(str, Enum):
    """Depth level at which a research run is carried out.

    Mixes in ``str`` so every member compares equal to its plain string value.
    """

    BASIC = "basic"
    COMPREHENSIVE = "comprehensive"
    TARGETED = "targeted"
class SourceType(str, Enum):
    """Kinds of sources that research may draw on.

    Mixes in ``str`` so every member compares equal to its plain string value.
    """

    WEB = "web"
    ACADEMIC = "academic"
    NEWS = "news"
    INDUSTRY = "industry"
    EXPERT = "expert"
class DateRange(str, Enum):
    """Recency windows usable as a date filter on research.

    Mixes in ``str`` so every member compares equal to its plain string value.
    """

    LAST_WEEK = "last_week"
    LAST_MONTH = "last_month"
    LAST_3_MONTHS = "last_3_months"
    LAST_6_MONTHS = "last_6_months"
    LAST_YEAR = "last_year"
    ALL_TIME = "all_time"
class ResearchProvider(str, Enum):
    """Backends that can perform the research search.

    Mixes in ``str`` so every member compares equal to its plain string value.
    """

    # Gemini native grounding
    GOOGLE = "google"
    # Exa neural search
    EXA = "exa"
class ResearchConfig(BaseModel):
    """Settings that control how a research run is executed.

    Every field carries a default, so the model can be constructed with no
    arguments.
    """

    # Research depth and the backend that performs the search.
    mode: ResearchMode = ResearchMode.BASIC
    provider: ResearchProvider = ResearchProvider.GOOGLE

    # Optional narrowing of results; unset by default (None / empty list).
    date_range: Optional[DateRange] = None
    source_types: List[SourceType] = []
    max_sources: int = 10

    # Content toggles; all enabled by default.
    include_statistics: bool = True
    include_expert_quotes: bool = True
    include_competitors: bool = True
    include_trends: bool = True
class BlogResearchRequest(BaseModel): class BlogResearchRequest(BaseModel):
keywords: List[str] keywords: List[str]
topic: Optional[str] = None topic: Optional[str] = None
@@ -58,6 +104,8 @@ class BlogResearchRequest(BaseModel):
tone: Optional[str] = None tone: Optional[str] = None
word_count_target: Optional[int] = 1500 word_count_target: Optional[int] = 1500
persona: Optional[PersonaInfo] = None persona: Optional[PersonaInfo] = None
research_mode: Optional[ResearchMode] = ResearchMode.BASIC
config: Optional[ResearchConfig] = None
class BlogResearchResponse(BaseModel): class BlogResearchResponse(BaseModel):

View File

@@ -34,6 +34,7 @@ class APIProvider(enum.Enum):
METAPHOR = "metaphor" METAPHOR = "metaphor"
FIRECRAWL = "firecrawl" FIRECRAWL = "firecrawl"
STABILITY = "stability" STABILITY = "stability"
EXA = "exa"
class BillingCycle(enum.Enum): class BillingCycle(enum.Enum):
MONTHLY = "monthly" MONTHLY = "monthly"
@@ -66,6 +67,7 @@ class SubscriptionPlan(Base):
metaphor_calls_limit = Column(Integer, default=0) metaphor_calls_limit = Column(Integer, default=0)
firecrawl_calls_limit = Column(Integer, default=0) firecrawl_calls_limit = Column(Integer, default=0)
stability_calls_limit = Column(Integer, default=0) # Image generation stability_calls_limit = Column(Integer, default=0) # Image generation
exa_calls_limit = Column(Integer, default=0) # Exa neural search
# Token Limits (for LLM providers) # Token Limits (for LLM providers)
gemini_tokens_limit = Column(Integer, default=0) gemini_tokens_limit = Column(Integer, default=0)
@@ -182,6 +184,7 @@ class UsageSummary(Base):
metaphor_calls = Column(Integer, default=0) metaphor_calls = Column(Integer, default=0)
firecrawl_calls = Column(Integer, default=0) firecrawl_calls = Column(Integer, default=0)
stability_calls = Column(Integer, default=0) stability_calls = Column(Integer, default=0)
exa_calls = Column(Integer, default=0)
# Token Usage # Token Usage
gemini_tokens = Column(Integer, default=0) gemini_tokens = Column(Integer, default=0)
@@ -199,6 +202,7 @@ class UsageSummary(Base):
metaphor_cost = Column(Float, default=0.0) metaphor_cost = Column(Float, default=0.0)
firecrawl_cost = Column(Float, default=0.0) firecrawl_cost = Column(Float, default=0.0)
stability_cost = Column(Float, default=0.0) stability_cost = Column(Float, default=0.0)
exa_cost = Column(Float, default=0.0)
# Totals # Totals
total_calls = Column(Integer, default=0) total_calls = Column(Integer, default=0)

View File

@@ -38,6 +38,7 @@ aiohttp>=3.9.0
# Data processing # Data processing
pandas>=2.0.0 pandas>=2.0.0
numpy>=1.24.0 numpy>=1.24.0
markdown>=3.5.0
# SEO Analysis dependencies # SEO Analysis dependencies
advertools>=0.14.0 advertools>=0.14.0

View File

@@ -3,7 +3,7 @@ Script to update Basic plan subscription limits for testing rate limits and rene
Updates: Updates:
- LLM Calls (all providers): 10 calls (was 500-1000) - LLM Calls (all providers): 10 calls (was 500-1000)
- LLM Tokens (all providers): 2000 tokens (was 200k-1M) - LLM Tokens (all providers): 20000 tokens (increased from 5000 to support research workflow)
- Images: 5 images (was 50) - Images: 5 images (was 50)
This script updates the SubscriptionPlan table, which automatically applies to all users This script updates the SubscriptionPlan table, which automatically applies to all users
@@ -69,11 +69,11 @@ def update_basic_plan_limits():
basic_plan.anthropic_calls_limit = 200 basic_plan.anthropic_calls_limit = 200
basic_plan.mistral_calls_limit = 500 basic_plan.mistral_calls_limit = 500
# Update all LLM provider token limits to 2000 # Update all LLM provider token limits to 20000 (increased from 5000 for better stability)
basic_plan.gemini_tokens_limit = 2000 basic_plan.gemini_tokens_limit = 20000
basic_plan.openai_tokens_limit = 2000 basic_plan.openai_tokens_limit = 20000
basic_plan.anthropic_tokens_limit = 2000 basic_plan.anthropic_tokens_limit = 20000
basic_plan.mistral_tokens_limit = 2000 basic_plan.mistral_tokens_limit = 20000
# Update image generation limit to 5 # Update image generation limit to 5
basic_plan.stability_calls_limit = 5 basic_plan.stability_calls_limit = 5
@@ -83,7 +83,7 @@ def update_basic_plan_limits():
logger.info("\n📝 New Basic plan limits:") logger.info("\n📝 New Basic plan limits:")
logger.info(f" LLM Calls (all providers): 10") logger.info(f" LLM Calls (all providers): 10")
logger.info(f" LLM Tokens (all providers): 2000") logger.info(f" LLM Tokens (all providers): 20000 (increased from 5000)")
logger.info(f" Images: 5") logger.info(f" Images: 5")
# Count and get affected users # Count and get affected users
@@ -118,7 +118,7 @@ def update_basic_plan_limits():
# New limits - use unified AI text generation limit if available # New limits - use unified AI text generation limit if available
new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit new_call_limit = getattr(basic_plan, 'ai_text_generation_calls_limit', None) or basic_plan.gemini_calls_limit
new_token_limit = basic_plan.gemini_tokens_limit # 2000 new_token_limit = basic_plan.gemini_tokens_limit # 5000 (increased from 2000)
new_image_limit = basic_plan.stability_calls_limit # 5 new_image_limit = basic_plan.stability_calls_limit # 5
for sub in user_subscriptions: for sub in user_subscriptions:
@@ -253,7 +253,7 @@ if __name__ == "__main__":
logger.info("="*60) logger.info("="*60)
logger.info("This will update Basic plan limits for testing rate limits:") logger.info("This will update Basic plan limits for testing rate limits:")
logger.info(" - LLM Calls: 10 (all providers)") logger.info(" - LLM Calls: 10 (all providers)")
logger.info(" - LLM Tokens: 2000 (all providers)") logger.info(" - LLM Tokens: 20000 (all providers, increased from 5000)")
logger.info(" - Images: 5") logger.info(" - Images: 5")
logger.info("="*60) logger.info("="*60)

View File

@@ -8,6 +8,7 @@ import time
import json import json
from typing import Dict, Any, List from typing import Dict, Any, List
from loguru import logger from loguru import logger
from fastapi import HTTPException
from models.blog_models import ( from models.blog_models import (
MediumBlogGenerateRequest, MediumBlogGenerateRequest,
@@ -25,8 +26,20 @@ class MediumBlogGenerator:
def __init__(self): def __init__(self):
self.cache = persistent_content_cache self.cache = persistent_content_cache
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult: async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call.""" """Use Gemini structured JSON to generate a medium-length blog in one call.
Args:
req: Medium blog generation request
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
import time import time
start = time.time() start = time.time()
@@ -156,7 +169,7 @@ class MediumBlogGenerator:
- Use language that resonates with {audience} - Use language that resonates with {audience}
- Maintain consistent voice that reflects this persona's expertise - Maintain consistent voice that reflects this persona's expertise
""" """
prompt = ( prompt = (
f"Write blog content for the following sections. Each section should be {req.globalTargetWords or 1000} words total, distributed across all sections.\n\n" f"Write blog content for the following sections. Each section should be {req.globalTargetWords or 1000} words total, distributed across all sections.\n\n"
f"Blog Title: {req.title}\n\n" f"Blog Title: {req.title}\n\n"
@@ -176,11 +189,20 @@ class MediumBlogGenerator:
f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}" f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}"
) )
ai_resp = llm_text_gen( try:
prompt=prompt, ai_resp = llm_text_gen(
json_struct=schema, prompt=prompt,
system_prompt=system, json_struct=schema,
) system_prompt=system,
user_id=user_id
)
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) to preserve error details
raise
except Exception as llm_error:
# Wrap other errors
logger.error(f"AI generation failed: {llm_error}")
raise Exception(f"AI generation failed: {str(llm_error)}")
# Check for errors in AI response # Check for errors in AI response
if not ai_resp or ai_resp.get("error"): if not ai_resp or ai_resp.get("error"):

View File

@@ -105,13 +105,20 @@ class BlogWriterService:
return await self.research_service.research_with_progress(request, task_id, user_id) return await self.research_service.research_with_progress(request, task_id, user_id)
# Outline Methods # Outline Methods
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
"""Generate AI-powered outline from research data.""" """Generate AI-powered outline from research data.
return await self.outline_service.generate_outline(request)
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
return await self.outline_service.generate_outline(request, user_id)
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
"""Generate outline with real-time progress updates.""" """Generate outline with real-time progress updates."""
return await self.outline_service.generate_outline_with_progress(request, task_id) return await self.outline_service.generate_outline_with_progress(request, task_id, user_id)
async def refine_outline(self, request: BlogOutlineRefineRequest) -> BlogOutlineResponse: async def refine_outline(self, request: BlogOutlineRefineRequest) -> BlogOutlineResponse:
"""Refine outline with HITL operations.""" """Refine outline with HITL operations."""
@@ -334,9 +341,17 @@ class BlogWriterService:
# TODO: Move to content module # TODO: Move to content module
return BlogPublishResponse(success=True, platform=request.platform, url="https://example.com/post") return BlogPublishResponse(success=True, platform=request.platform, url="https://example.com/post")
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult: async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
"""Use Gemini structured JSON to generate a medium-length blog in one call.""" """Use Gemini structured JSON to generate a medium-length blog in one call.
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id)
Args:
req: Medium blog generation request
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
"""
if not user_id:
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id, user_id)
async def analyze_flow_basic(self, request: Dict[str, Any]) -> Dict[str, Any]: async def analyze_flow_basic(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze flow metrics for entire blog using single AI call (cost-effective).""" """Analyze flow metrics for entire blog using single AI call (cost-effective)."""

View File

@@ -42,10 +42,20 @@ class OutlineGenerator:
self.response_processor = ResponseProcessor() self.response_processor = ResponseProcessor()
self.parallel_processor = ParallelProcessor(self.source_mapper, self.grounding_engine) self.parallel_processor = ParallelProcessor(self.source_mapper, self.grounding_engine)
async def generate(self, request: BlogOutlineRequest) -> BlogOutlineResponse: async def generate(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
""" """
Generate AI-powered outline using research results Generate AI-powered outline using research results.
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
# Extract research insights # Extract research insights
research = request.research research = request.research
primary_keywords = research.keyword_analysis.get('primary', []) primary_keywords = research.keyword_analysis.get('primary', [])
@@ -68,15 +78,15 @@ class OutlineGenerator:
# Define schema with proper property ordering (critical for Gemini API) # Define schema with proper property ordering (critical for Gemini API)
outline_schema = self.prompt_builder.get_outline_schema() outline_schema = self.prompt_builder.get_outline_schema()
# Generate outline using structured JSON response with retry logic # Generate outline using structured JSON response with retry logic (user_id required)
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema) outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id)
# Convert to BlogOutlineSection objects # Convert to BlogOutlineSection objects
outline_sections = self.response_processor.convert_to_sections(outline_data, sources) outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
# Run parallel processing for speed optimization # Run parallel processing for speed optimization (user_id required)
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing_async( mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing_async(
outline_sections, research outline_sections, research, user_id
) )
# Enhance sections with grounding insights # Enhance sections with grounding insights
@@ -85,9 +95,9 @@ class OutlineGenerator:
mapped_sections, research.grounding_metadata, grounding_insights mapped_sections, research.grounding_metadata, grounding_insights
) )
# Optimize outline for better flow, SEO, and engagement # Optimize outline for better flow, SEO, and engagement (user_id required)
logger.info("Optimizing outline for better flow and engagement...") logger.info("Optimizing outline for better flow and engagement...")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization") optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
# Rebalance word counts for optimal distribution # Rebalance word counts for optimal distribution
target_words = request.word_count or 1500 target_words = request.word_count or 1500
@@ -118,10 +128,21 @@ class OutlineGenerator:
research_coverage=research_coverage research_coverage=research_coverage
) )
async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
""" """
Outline generation method with progress updates for real-time feedback. Outline generation method with progress updates for real-time feedback.
Args:
request: Outline generation request with research data
task_id: Task ID for progress updates
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
from api.blog_writer.task_manager import task_manager from api.blog_writer.task_manager import task_manager
# Extract research insights # Extract research insights
@@ -150,17 +171,17 @@ class OutlineGenerator:
await task_manager.update_progress(task_id, "🔄 Making AI request to generate structured outline...") await task_manager.update_progress(task_id, "🔄 Making AI request to generate structured outline...")
# Generate outline using structured JSON response with retry logic # Generate outline using structured JSON response with retry logic (user_id required for subscription checks)
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, task_id) outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id, task_id)
await task_manager.update_progress(task_id, "📝 Processing outline structure and validating sections...") await task_manager.update_progress(task_id, "📝 Processing outline structure and validating sections...")
# Convert to BlogOutlineSection objects # Convert to BlogOutlineSection objects
outline_sections = self.response_processor.convert_to_sections(outline_data, sources) outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
# Run parallel processing for speed optimization # Run parallel processing for speed optimization (user_id required for subscription checks)
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing( mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing(
outline_sections, research, task_id outline_sections, research, user_id, task_id
) )
# Enhance sections with grounding insights (depends on both previous tasks) # Enhance sections with grounding insights (depends on both previous tasks)
@@ -169,9 +190,9 @@ class OutlineGenerator:
mapped_sections, research.grounding_metadata, grounding_insights mapped_sections, research.grounding_metadata, grounding_insights
) )
# Optimize outline for better flow, SEO, and engagement # Optimize outline for better flow, SEO, and engagement (user_id required for subscription checks)
await task_manager.update_progress(task_id, "🎯 Optimizing outline for better flow and engagement...") await task_manager.update_progress(task_id, "🎯 Optimizing outline for better flow and engagement...")
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization") optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
# Rebalance word counts for optimal distribution # Rebalance word counts for optimal distribution
await task_manager.update_progress(task_id, "⚖️ Rebalancing word count distribution...") await task_manager.update_progress(task_id, "⚖️ Rebalancing word count distribution...")

View File

@@ -13,8 +13,23 @@ from models.blog_models import BlogOutlineSection
class OutlineOptimizer: class OutlineOptimizer:
"""Optimizes outlines for better flow, SEO, and engagement.""" """Optimizes outlines for better flow, SEO, and engagement."""
async def optimize(self, outline: List[BlogOutlineSection], focus: str = "general optimization") -> List[BlogOutlineSection]: async def optimize(self, outline: List[BlogOutlineSection], focus: str, user_id: str) -> List[BlogOutlineSection]:
"""Optimize entire outline for better flow, SEO, and engagement.""" """Optimize entire outline for better flow, SEO, and engagement.
Args:
outline: List of outline sections to optimize
focus: Optimization focus (e.g., "general optimization")
user_id: User ID (required for subscription checks and usage tracking)
Returns:
List of optimized outline sections
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline optimization (subscription checks and usage tracking)")
outline_text = "\n".join([f"{i+1}. {s.heading}" for i, s in enumerate(outline)]) outline_text = "\n".join([f"{i+1}. {s.heading}" for i, s in enumerate(outline)])
optimization_prompt = f"""Optimize this blog outline for better flow, engagement, and SEO: optimization_prompt = f"""Optimize this blog outline for better flow, engagement, and SEO:
@@ -67,7 +82,8 @@ Return JSON format:
optimized_data = llm_text_gen( optimized_data = llm_text_gen(
prompt=optimization_prompt, prompt=optimization_prompt,
json_struct=optimization_schema, json_struct=optimization_schema,
system_prompt=None system_prompt=None,
user_id=user_id
) )
# Handle the new schema format with "outline" wrapper # Handle the new schema format with "outline" wrapper

View File

@@ -29,11 +29,21 @@ class OutlineService:
self.outline_optimizer = OutlineOptimizer() self.outline_optimizer = OutlineOptimizer()
self.section_enhancer = SectionEnhancer() self.section_enhancer = SectionEnhancer()
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse: async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
""" """
Stage 2: Content Planning with AI-generated outline using research results Stage 2: Content Planning with AI-generated outline using research results.
Uses Gemini with research data to create comprehensive, SEO-optimized outline Uses Gemini with research data to create comprehensive, SEO-optimized outline.
Args:
request: Outline generation request with research data
user_id: User ID (required for subscription checks and usage tracking)
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
# Extract cache parameters - use original user keywords for consistent caching # Extract cache parameters - use original user keywords for consistent caching
keywords = request.research.original_keywords or request.research.keyword_analysis.get('primary', []) keywords = request.research.original_keywords or request.research.keyword_analysis.get('primary', [])
industry = getattr(request.persona, 'industry', 'general') if request.persona else 'general' industry = getattr(request.persona, 'industry', 'general') if request.persona else 'general'
@@ -56,9 +66,9 @@ class OutlineService:
logger.info(f"Using cached outline for keywords: {keywords}") logger.info(f"Using cached outline for keywords: {keywords}")
return BlogOutlineResponse(**cached_result) return BlogOutlineResponse(**cached_result)
# Generate new outline if not cached # Generate new outline if not cached (user_id required)
logger.info(f"Generating new outline for keywords: {keywords}") logger.info(f"Generating new outline for keywords: {keywords}")
result = await self.outline_generator.generate(request) result = await self.outline_generator.generate(request, user_id)
# Cache the result # Cache the result
persistent_outline_cache.cache_outline( persistent_outline_cache.cache_outline(
@@ -73,7 +83,7 @@ class OutlineService:
return result return result
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse: async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
""" """
Outline generation method with progress updates for real-time feedback. Outline generation method with progress updates for real-time feedback.
""" """
@@ -104,7 +114,7 @@ class OutlineService:
# Generate new outline if not cached # Generate new outline if not cached
logger.info(f"Generating new outline for keywords: {keywords} (with progress updates)") logger.info(f"Generating new outline for keywords: {keywords} (with progress updates)")
result = await self.outline_generator.generate_with_progress(request, task_id) result = await self.outline_generator.generate_with_progress(request, task_id, user_id)
# Cache the result # Cache the result
persistent_outline_cache.cache_outline( persistent_outline_cache.cache_outline(

View File

@@ -17,18 +17,25 @@ class ParallelProcessor:
self.source_mapper = source_mapper self.source_mapper = source_mapper
self.grounding_engine = grounding_engine self.grounding_engine = grounding_engine
async def run_parallel_processing(self, outline_sections, research, task_id: str = None) -> Tuple[Any, Any]: async def run_parallel_processing(self, outline_sections, research, user_id: str, task_id: str = None) -> Tuple[Any, Any]:
""" """
Run source mapping and grounding insights extraction in parallel. Run source mapping and grounding insights extraction in parallel.
Args: Args:
outline_sections: List of outline sections to process outline_sections: List of outline sections to process
research: Research data object research: Research data object
user_id: User ID (required for subscription checks and usage tracking)
task_id: Optional task ID for progress updates task_id: Optional task ID for progress updates
Returns: Returns:
Tuple of (mapped_sections, grounding_insights) Tuple of (mapped_sections, grounding_insights)
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
if task_id: if task_id:
from api.blog_writer.task_manager import task_manager from api.blog_writer.task_manager import task_manager
await task_manager.update_progress(task_id, "⚡ Running parallel processing for maximum speed...") await task_manager.update_progress(task_id, "⚡ Running parallel processing for maximum speed...")
@@ -37,7 +44,7 @@ class ParallelProcessor:
# Run these tasks in parallel to save time # Run these tasks in parallel to save time
source_mapping_task = asyncio.create_task( source_mapping_task = asyncio.create_task(
self._run_source_mapping(outline_sections, research, task_id) self._run_source_mapping(outline_sections, research, task_id, user_id)
) )
grounding_insights_task = asyncio.create_task( grounding_insights_task = asyncio.create_task(
@@ -52,22 +59,29 @@ class ParallelProcessor:
return mapped_sections, grounding_insights return mapped_sections, grounding_insights
async def run_parallel_processing_async(self, outline_sections, research) -> Tuple[Any, Any]: async def run_parallel_processing_async(self, outline_sections, research, user_id: str) -> Tuple[Any, Any]:
""" """
Run parallel processing without progress updates (for non-progress methods). Run parallel processing without progress updates (for non-progress methods).
Args: Args:
outline_sections: List of outline sections to process outline_sections: List of outline sections to process
research: Research data object research: Research data object
user_id: User ID (required for subscription checks and usage tracking)
Returns: Returns:
Tuple of (mapped_sections, grounding_insights) Tuple of (mapped_sections, grounding_insights)
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
logger.info("Running parallel processing for maximum speed...") logger.info("Running parallel processing for maximum speed...")
# Run these tasks in parallel to save time # Run these tasks in parallel to save time
source_mapping_task = asyncio.create_task( source_mapping_task = asyncio.create_task(
self._run_source_mapping_async(outline_sections, research) self._run_source_mapping_async(outline_sections, research, user_id)
) )
grounding_insights_task = asyncio.create_task( grounding_insights_task = asyncio.create_task(
@@ -82,12 +96,12 @@ class ParallelProcessor:
return mapped_sections, grounding_insights return mapped_sections, grounding_insights
async def _run_source_mapping(self, outline_sections, research, task_id): async def _run_source_mapping(self, outline_sections, research, task_id, user_id: str):
"""Run source mapping in parallel.""" """Run source mapping in parallel."""
if task_id: if task_id:
from api.blog_writer.task_manager import task_manager from api.blog_writer.task_manager import task_manager
await task_manager.update_progress(task_id, "🔗 Applying intelligent source-to-section mapping...") await task_manager.update_progress(task_id, "🔗 Applying intelligent source-to-section mapping...")
return self.source_mapper.map_sources_to_sections(outline_sections, research) return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
async def _run_grounding_insights_extraction(self, research, task_id): async def _run_grounding_insights_extraction(self, research, task_id):
"""Run grounding insights extraction in parallel.""" """Run grounding insights extraction in parallel."""
@@ -96,10 +110,10 @@ class ParallelProcessor:
await task_manager.update_progress(task_id, "🧠 Extracting grounding metadata insights...") await task_manager.update_progress(task_id, "🧠 Extracting grounding metadata insights...")
return self.grounding_engine.extract_contextual_insights(research.grounding_metadata) return self.grounding_engine.extract_contextual_insights(research.grounding_metadata)
async def _run_source_mapping_async(self, outline_sections, research): async def _run_source_mapping_async(self, outline_sections, research, user_id: str):
"""Run source mapping in parallel (async version without progress updates).""" """Run source mapping in parallel (async version without progress updates)."""
logger.info("Applying intelligent source-to-section mapping...") logger.info("Applying intelligent source-to-section mapping...")
return self.source_mapper.map_sources_to_sections(outline_sections, research) return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
async def _run_grounding_insights_extraction_async(self, research): async def _run_grounding_insights_extraction_async(self, research):
"""Run grounding insights extraction in parallel (async version without progress updates).""" """Run grounding insights extraction in parallel (async version without progress updates)."""

View File

@@ -18,8 +18,21 @@ class ResponseProcessor:
"""Initialize the response processor.""" """Initialize the response processor."""
pass pass
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], task_id: str = None) -> Dict[str, Any]: async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], user_id: str, task_id: str = None) -> Dict[str, Any]:
"""Generate outline with retry logic for API failures.""" """Generate outline with retry logic for API failures.
Args:
prompt: The prompt for outline generation
schema: JSON schema for structured response
user_id: User ID (required for subscription checks and usage tracking)
task_id: Optional task ID for progress updates
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
from services.llm_providers.main_text_generation import llm_text_gen from services.llm_providers.main_text_generation import llm_text_gen
from api.blog_writer.task_manager import task_manager from api.blog_writer.task_manager import task_manager
@@ -34,7 +47,8 @@ class ResponseProcessor:
outline_data = llm_text_gen( outline_data = llm_text_gen(
prompt=prompt, prompt=prompt,
json_struct=schema, json_struct=schema,
system_prompt=None system_prompt=None,
user_id=user_id
) )
# Log response for debugging # Log response for debugging

View File

@@ -12,8 +12,23 @@ from models.blog_models import BlogOutlineSection
class SectionEnhancer: class SectionEnhancer:
"""Enhances individual outline sections using AI.""" """Enhances individual outline sections using AI."""
async def enhance(self, section: BlogOutlineSection, focus: str = "general improvement") -> BlogOutlineSection: async def enhance(self, section: BlogOutlineSection, focus: str, user_id: str) -> BlogOutlineSection:
"""Enhance a section using AI with research context.""" """Enhance a section using AI with research context.
Args:
section: Outline section to enhance
focus: Enhancement focus (e.g., "general improvement")
user_id: User ID (required for subscription checks and usage tracking)
Returns:
Enhanced outline section
Raises:
ValueError: If user_id is not provided
"""
if not user_id:
raise ValueError("user_id is required for section enhancement (subscription checks and usage tracking)")
enhancement_prompt = f""" enhancement_prompt = f"""
Enhance the following blog section to make it more engaging, comprehensive, and valuable: Enhance the following blog section to make it more engaging, comprehensive, and valuable:
@@ -61,7 +76,8 @@ class SectionEnhancer:
enhanced_data = llm_text_gen( enhanced_data = llm_text_gen(
prompt=enhancement_prompt, prompt=enhancement_prompt,
json_struct=enhancement_schema, json_struct=enhancement_schema,
system_prompt=None system_prompt=None,
user_id=user_id
) )
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data: if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:

View File

@@ -52,7 +52,8 @@ class SourceToSectionMapper:
def map_sources_to_sections( def map_sources_to_sections(
self, self,
sections: List[BlogOutlineSection], sections: List[BlogOutlineSection],
research_data: BlogResearchResponse research_data: BlogResearchResponse,
user_id: str
) -> List[BlogOutlineSection]: ) -> List[BlogOutlineSection]:
""" """
Map research sources to outline sections using intelligent algorithms. Map research sources to outline sections using intelligent algorithms.
@@ -60,10 +61,17 @@ class SourceToSectionMapper:
Args: Args:
sections: List of outline sections to map sources to sections: List of outline sections to map sources to
research_data: Research data containing sources and metadata research_data: Research data containing sources and metadata
user_id: User ID (required for subscription checks and usage tracking)
Returns: Returns:
List of outline sections with intelligently mapped sources List of outline sections with intelligently mapped sources
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for source mapping (subscription checks and usage tracking)")
if not sections or not research_data.sources: if not sections or not research_data.sources:
logger.warning("No sections or sources to map") logger.warning("No sections or sources to map")
return sections return sections
@@ -73,8 +81,8 @@ class SourceToSectionMapper:
# Step 1: Algorithmic mapping # Step 1: Algorithmic mapping
mapping_results = self._algorithmic_source_mapping(sections, research_data) mapping_results = self._algorithmic_source_mapping(sections, research_data)
# Step 2: AI validation and improvement (single prompt) # Step 2: AI validation and improvement (single prompt, user_id required for subscription checks)
validated_mapping = self._ai_validate_mapping(mapping_results, research_data) validated_mapping = self._ai_validate_mapping(mapping_results, research_data, user_id)
# Step 3: Apply validated mapping to sections # Step 3: Apply validated mapping to sections
mapped_sections = self._apply_mapping_to_sections(sections, validated_mapping) mapped_sections = self._apply_mapping_to_sections(sections, validated_mapping)
@@ -261,7 +269,8 @@ class SourceToSectionMapper:
def _ai_validate_mapping( def _ai_validate_mapping(
self, self,
mapping_results: Dict[str, List[Tuple[ResearchSource, float]]], mapping_results: Dict[str, List[Tuple[ResearchSource, float]]],
research_data: BlogResearchResponse research_data: BlogResearchResponse,
user_id: str
) -> Dict[str, List[Tuple[ResearchSource, float]]]: ) -> Dict[str, List[Tuple[ResearchSource, float]]]:
""" """
Use AI to validate and improve the algorithmic mapping results. Use AI to validate and improve the algorithmic mapping results.
@@ -269,18 +278,25 @@ class SourceToSectionMapper:
Args: Args:
mapping_results: Algorithmic mapping results mapping_results: Algorithmic mapping results
research_data: Research data for context research_data: Research data for context
user_id: User ID (required for subscription checks and usage tracking)
Returns: Returns:
AI-validated and improved mapping results AI-validated and improved mapping results
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for AI validation (subscription checks and usage tracking)")
try: try:
logger.info("Starting AI validation of source-to-section mapping...") logger.info("Starting AI validation of source-to-section mapping...")
# Build AI validation prompt # Build AI validation prompt
validation_prompt = self._build_validation_prompt(mapping_results, research_data) validation_prompt = self._build_validation_prompt(mapping_results, research_data)
# Get AI validation response # Get AI validation response (user_id required for subscription checks)
validation_response = self._get_ai_validation_response(validation_prompt) validation_response = self._get_ai_validation_response(validation_prompt, user_id)
# Parse and apply AI validation results # Parse and apply AI validation results
validated_mapping = self._parse_validation_response(validation_response, mapping_results, research_data) validated_mapping = self._parse_validation_response(validation_response, mapping_results, research_data)
@@ -548,23 +564,31 @@ Analyze the mapping and provide your recommendations.
return prompt return prompt
def _get_ai_validation_response(self, prompt: str) -> str: def _get_ai_validation_response(self, prompt: str, user_id: str) -> str:
""" """
Get AI validation response using LLM provider. Get AI validation response using LLM provider.
Args: Args:
prompt: Validation prompt prompt: Validation prompt
user_id: User ID (required for subscription checks and usage tracking)
Returns: Returns:
AI validation response AI validation response
Raises:
ValueError: If user_id is not provided
""" """
if not user_id:
raise ValueError("user_id is required for AI validation response (subscription checks and usage tracking)")
try: try:
from services.llm_providers.main_text_generation import llm_text_gen from services.llm_providers.main_text_generation import llm_text_gen
response = llm_text_gen( response = llm_text_gen(
prompt=prompt, prompt=prompt,
json_struct=None, json_struct=None,
system_prompt=None system_prompt=None,
user_id=user_id
) )
return response return response

View File

@@ -13,11 +13,17 @@ from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter from .data_filter import ResearchDataFilter
from .base_provider import ResearchProvider as BaseResearchProvider
from .google_provider import GoogleResearchProvider
from .exa_provider import ExaResearchProvider
__all__ = [ __all__ = [
'ResearchService', 'ResearchService',
'KeywordAnalyzer', 'KeywordAnalyzer',
'CompetitorAnalyzer', 'CompetitorAnalyzer',
'ContentAngleGenerator', 'ContentAngleGenerator',
'ResearchDataFilter' 'ResearchDataFilter',
'BaseResearchProvider',
'GoogleResearchProvider',
'ExaResearchProvider',
] ]

View File

@@ -0,0 +1,37 @@
"""
Base Research Provider Interface
Abstract base class for research provider implementations.
Ensures consistency across different research providers (Google, Exa, etc.)
"""
from abc import ABC, abstractmethod
from typing import Dict, Any
class ResearchProvider(ABC):
    """Contract shared by every research backend.

    A concrete provider must implement the search coroutine together with
    the two hooks used for subscription tracking and pre-flight validation.
    """

    @abstractmethod
    async def search(
        self,
        prompt: str,
        topic: str,
        industry: str,
        target_audience: str,
        config: Any,  # ResearchConfig
        user_id: str
    ) -> Dict[str, Any]:
        """Run the research query and return the provider's raw result dict."""

    @abstractmethod
    def get_provider_enum(self):
        """Return the APIProvider enum member used for subscription tracking."""

    @abstractmethod
    def estimate_tokens(self) -> int:
        """Return an estimated token count for pre-flight validation."""

View File

@@ -0,0 +1,188 @@
"""
Exa Research Provider
Neural search implementation using Exa API for high-quality, citation-rich research.
"""
from exa_py import Exa
import os
from loguru import logger
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider
class ExaResearchProvider(BaseProvider):
    """Exa neural search provider.

    Wraps the ``exa_py`` client behind the common research-provider interface
    and converts Exa results into plain dicts (sources, aggregated content,
    search metadata and cost).
    """

    def __init__(self):
        """Create the Exa client from the ``EXA_API_KEY`` environment variable.

        Raises:
            RuntimeError: If ``EXA_API_KEY`` is not set. Callers match on the
                exact message text, so it must not be reworded.
        """
        self.api_key = os.getenv("EXA_API_KEY")
        if not self.api_key:
            raise RuntimeError("EXA_API_KEY not configured")
        self.exa = Exa(self.api_key)
        logger.info("✅ Exa Research Provider initialized")

    @staticmethod
    def _first_attr(obj, *names, default=None):
        """Return the value of the first attribute in ``names`` present on ``obj``.

        Used to read response fields under either their snake_case name (what
        the exa_py SDK is expected to expose) or the camelCase name the raw API
        uses, so the code works with both shapes.
        """
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        return default

    async def search(self, prompt, topic, industry, target_audience, config, user_id):
        """Execute Exa neural search and return standardized results.

        Args:
            prompt: Accepted for interface compatibility; not used by this provider.
            topic: Research topic; part of the query and of the summary request.
            industry: Industry context appended to the query.
            target_audience: Audience context appended to the query.
            config: Object providing ``source_types`` and ``max_sources``.
            user_id: Accepted for interface compatibility; not used here.

        Returns:
            Dict with keys ``sources``, ``content``, ``search_type``,
            ``provider``, ``search_queries`` and ``cost`` (``{'total': float}``).
        """
        # Build Exa query
        query = f"{topic} {industry} {target_audience}"

        # Map source types to Exa categories
        category = self._map_source_type_to_category(config.source_types)

        logger.info(f"[Exa Research] Executing search: {query}")

        # Execute Exa search
        # NOTE(review): this call is made synchronously inside a coroutine; if
        # the client performs blocking I/O it will stall the event loop.
        results = self.exa.search_and_contents(
            query,
            type="auto",
            category=category,
            num_results=min(config.max_sources, 25),  # capped at 25 results
            contents={
                'text': {'max_characters': 1000},
                'summary': {'query': f"Key insights about {topic}"},
                'highlights': {
                    'num_sentences': 2,
                    'highlights_per_url': 3
                }
            }
        )

        # Transform to standardized format
        sources = self._transform_sources(results.results)
        content = self._aggregate_content(results.results)
        search_type = self._first_attr(
            results, 'resolved_search_type', 'resolvedSearchType'
        ) or 'neural'

        # Get cost if available; otherwise keep the default
        cost = 0.005  # Default Exa cost for 1-25 results
        cost_info = self._first_attr(results, 'cost_dollars', 'costDollars')
        if isinstance(cost_info, dict):
            reported_total = cost_info.get('total')
        else:
            reported_total = getattr(cost_info, 'total', None)
        if reported_total is not None:
            # A None total must not replace the default: it would null out the
            # SQL sums in track_exa_usage.
            cost = reported_total

        logger.info(f"[Exa Research] Search completed: {len(sources)} sources, type: {search_type}")

        return {
            'sources': sources,
            'content': content,
            'search_type': search_type,
            'provider': 'exa',
            'search_queries': [query],
            'cost': {'total': cost}
        }

    def get_provider_enum(self):
        """Return EXA provider enum for subscription tracking."""
        return APIProvider.EXA

    def estimate_tokens(self) -> int:
        """Estimate token usage for Exa (not token-based)."""
        return 0  # Exa is per-search, not token-based

    def _map_source_type_to_category(self, source_types):
        """Map source types to an Exa ``category`` value.

        Each item's ``.value`` is looked up in the mapping; the first item with
        a known value decides the category. Returns ``None`` when
        ``source_types`` is empty or nothing matches.
        """
        if not source_types:
            return None

        category_map = {
            'research paper': 'research paper',
            'news': 'news',
            'web': 'personal site',
            'industry': 'company',
            'expert': 'linkedin profile'
        }

        for st in source_types:
            if st.value in category_map:
                return category_map[st.value]

        return None

    def _transform_sources(self, results):
        """Transform Exa results into a list of source dicts.

        Missing attributes fall back to an empty string, ``None`` or an empty
        list, so partially populated results do not raise.
        """
        sources = []
        for idx, result in enumerate(results):
            url = getattr(result, 'url', '')
            sources.append({
                'title': getattr(result, 'title', ''),
                'url': url,
                'excerpt': self._get_excerpt(result),
                'credibility_score': 0.85,  # Exa results are high quality
                # Read snake_case first, camelCase as fallback.
                'published_at': self._first_attr(result, 'published_date', 'publishedDate'),
                'index': idx,
                'source_type': self._determine_source_type(url),
                'content': getattr(result, 'text', ''),
                'highlights': getattr(result, 'highlights', []),
                'summary': getattr(result, 'summary', '')
            })

        return sources

    def _get_excerpt(self, result):
        """Return up to 500 chars of text, else the summary, else ''."""
        if hasattr(result, 'text') and result.text:
            return result.text[:500]
        elif hasattr(result, 'summary') and result.summary:
            return result.summary
        return ''

    def _determine_source_type(self, url):
        """Classify a URL as 'academic', 'news', 'expert' or 'web' by substring match."""
        if not url:
            return 'web'

        url_lower = url.lower()
        if 'arxiv.org' in url_lower or 'research' in url_lower:
            return 'academic'
        elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com']):
            return 'news'
        elif 'linkedin.com' in url_lower:
            return 'expert'
        else:
            return 'web'

    def _aggregate_content(self, results):
        """Join per-source summaries (or the first 1000 chars of text) into one string.

        Sources with neither a summary nor text are skipped; numbering follows
        the original result position, so gaps are possible.
        """
        content_parts = []

        for idx, result in enumerate(results):
            if hasattr(result, 'summary') and result.summary:
                content_parts.append(f"Source {idx + 1}: {result.summary}")
            elif hasattr(result, 'text') and result.text:
                content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")

        return "\n\n".join(content_parts)

    def track_exa_usage(self, user_id: str, cost: float):
        """Record one Exa call and its cost for the user's current billing period.

        Best-effort: failures are logged and rolled back, never raised. A
        warning is logged when the UPDATE matches no row, because in that case
        nothing was recorded.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from sqlalchemy import text

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            current_period = pricing_service.get_current_billing_period(user_id)

            # Update exa_calls and exa_cost via SQL UPDATE
            update_query = text("""
                UPDATE usage_summaries
                SET exa_calls = COALESCE(exa_calls, 0) + 1,
                    exa_cost = COALESCE(exa_cost, 0) + :cost,
                    total_calls = total_calls + 1,
                    total_cost = total_cost + :cost
                WHERE user_id = :user_id AND billing_period = :period
            """)
            outcome = db.execute(update_query, {
                'cost': cost,
                'user_id': user_id,
                'period': current_period
            })
            db.commit()

            # rowcount == 0 means no usage_summaries row exists for this
            # user/period, so the usage was not actually recorded.
            if getattr(outcome, 'rowcount', 1) == 0:
                logger.warning(
                    f"[Exa] No usage summary row for user={user_id}, period={current_period}; usage not recorded"
                )
            else:
                logger.info(f"[Exa] Tracked usage: user={user_id}, cost=${cost}")
        except Exception as e:
            logger.error(f"[Exa] Failed to track usage: {e}")
            db.rollback()
        finally:
            db.close()

View File

@@ -0,0 +1,40 @@
"""
Google Research Provider
Wrapper for Gemini native Google Search grounding to match base provider interface.
"""
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider
from loguru import logger
class GoogleResearchProvider(BaseProvider):
    """Adapter exposing Gemini grounded generation through the research-provider interface."""

    def __init__(self):
        self.gemini = GeminiGroundedProvider()

    async def search(self, prompt, topic, industry, target_audience, config, user_id):
        """Send the research prompt to Gemini grounding and return its result unchanged.

        Only ``prompt`` and ``user_id`` are forwarded (``topic`` is used for
        logging); the remaining arguments exist to satisfy the shared
        provider signature.
        """
        logger.info(f"[Google Research] Executing search for topic: {topic}")

        grounded_call_kwargs = {
            "prompt": prompt,
            "content_type": "research",
            "max_tokens": 2000,
            "user_id": user_id,
            "validate_subsequent_operations": True,
        }
        return await self.gemini.generate_grounded_content(**grounded_call_kwargs)

    def get_provider_enum(self):
        """Return the GEMINI provider enum used for subscription tracking."""
        return APIProvider.GEMINI

    def estimate_tokens(self) -> int:
        """Return a conservative fixed token estimate for a grounded request."""
        return 1200  # Conservative estimate

View File

@@ -16,6 +16,9 @@ from models.blog_models import (
GroundingChunk, GroundingChunk,
GroundingSupport, GroundingSupport,
Citation, Citation,
ResearchConfig,
ResearchMode,
ResearchProvider,
) )
from services.blog_writer.logger_config import blog_writer_logger, log_function_call from services.blog_writer.logger_config import blog_writer_logger, log_function_call
from fastapi import HTTPException from fastapi import HTTPException
@@ -24,6 +27,7 @@ from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter from .data_filter import ResearchDataFilter
from .research_strategies import get_strategy_for_mode
class ResearchService: class ResearchService:
@@ -44,7 +48,6 @@ class ResearchService:
Includes intelligent caching for exact keyword matches. Includes intelligent caching for exact keyword matches.
""" """
try: try:
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from services.cache.research_cache import research_cache from services.cache.research_cache import research_cache
topic = request.topic or ", ".join(request.keywords) topic = request.topic or ", ".join(request.keywords)
@@ -79,62 +82,104 @@ class ResearchService:
# Cache miss - proceed with API call # Cache miss - proceed with API call
logger.info(f"Cache miss - making API call for keywords: {request.keywords}") logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research") blog_writer_logger.log_operation_start("research_api_call", api_name="research", operation="research")
gemini = GeminiGroundedProvider()
# Single comprehensive research prompt - Gemini handles Google Search automatically # Determine research mode and get appropriate strategy
research_prompt = f""" research_mode = request.research_mode or ResearchMode.BASIC
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including: config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
strategy = get_strategy_for_mode(research_mode)
1. Current trends and insights (2024-2025)
2. Key statistics and data points with sources
3. Industry expert opinions and quotes
4. Recent developments and news
5. Market analysis and forecasts
6. Best practices and case studies
7. Keyword analysis: primary, secondary, and long-tail opportunities
8. Competitor analysis: top players and content gaps
9. Content angle suggestions: 5 compelling angles for blog posts
Focus on factual, up-to-date information from credible sources.
Include specific data points, percentages, and recent developments.
Structure your response with clear sections for each analysis area.
"""
# Single Gemini call with native Google Search grounding - no fallbacks logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
import time
api_start_time = time.time()
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
api_duration_ms = (time.time() - api_start_time) * 1000
# Log API call performance # Build research prompt based on strategy
blog_writer_logger.log_api_call( research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
"gemini_grounded",
"generate_grounded_content",
api_duration_ms,
token_usage=gemini_result.get("token_usage", {}),
content_length=len(gemini_result.get("content", ""))
)
# Extract sources from grounding metadata # Route to appropriate provider
sources = self._extract_sources_from_grounding(gemini_result) if config.provider == ResearchProvider.EXA:
# Exa research workflow
from .exa_provider import ExaResearchProvider
from services.subscription.preflight_validator import validate_exa_research_operations
from services.database import get_db
from services.subscription import PricingService
import os
import time
# Pre-flight validation
db_val = next(get_db())
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
finally:
db_val.close()
# Execute Exa search
api_start_time = time.time()
try:
exa_provider = ExaResearchProvider()
raw_result = await exa_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
api_duration_ms = (time.time() - api_start_time) * 1000
# Track usage
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
exa_provider.track_exa_usage(user_id, cost)
# Log API call performance
blog_writer_logger.log_api_call(
"exa_search",
"search_and_contents",
api_duration_ms,
token_usage={},
content_length=len(raw_result.get('content', ''))
)
# Extract content for downstream analysis
content = raw_result.get('content', '')
sources = raw_result.get('sources', [])
search_widget = "" # Exa doesn't provide search widgets
search_queries = raw_result.get('search_queries', [])
grounding_metadata = None # Exa doesn't provide grounding metadata
except RuntimeError as e:
if "EXA_API_KEY not configured" in str(e):
logger.warning("Exa not configured, falling back to Google")
config.provider = ResearchProvider.GOOGLE
# Continue to Google flow below
raw_result = None
else:
raise
if config.provider != ResearchProvider.EXA:
# Google research (existing flow) or fallback from Exa
from .google_provider import GoogleResearchProvider
import time
api_start_time = time.time()
google_provider = GoogleResearchProvider()
gemini_result = await google_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
api_duration_ms = (time.time() - api_start_time) * 1000
# Log API call performance
blog_writer_logger.log_api_call(
"gemini_grounded",
"generate_grounded_content",
api_duration_ms,
token_usage=gemini_result.get("token_usage", {}),
content_length=len(gemini_result.get("content", ""))
)
# Extract sources and content
sources = self._extract_sources_from_grounding(gemini_result)
content = gemini_result.get("content", "")
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Extract grounding metadata for detailed UI display # Continue with common analysis (same for both providers)
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Extract search widget and queries for UI display
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
@@ -261,7 +306,6 @@ class ResearchService:
Research method with progress updates for real-time feedback. Research method with progress updates for real-time feedback.
""" """
try: try:
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from services.cache.research_cache import research_cache from services.cache.research_cache import research_cache
from services.cache.persistent_research_cache import persistent_research_cache from services.cache.persistent_research_cache import persistent_research_cache
from api.blog_writer.task_manager import task_manager from api.blog_writer.task_manager import task_manager
@@ -293,66 +337,100 @@ class ResearchService:
logger.info(f"Returning cached research result for keywords: {request.keywords}") logger.info(f"Returning cached research result for keywords: {request.keywords}")
return BlogResearchResponse(**cached_result) return BlogResearchResponse(**cached_result)
# User ID validation (validation logic is now in Google Grounding provider) # User ID validation
if not user_id: if not user_id:
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation") await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.") raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
# Cache miss - proceed with API call # Determine research mode and get appropriate strategy
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...") research_mode = request.research_mode or ResearchMode.BASIC
logger.info(f"Cache miss - making API call for keywords: {request.keywords}") config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
gemini = GeminiGroundedProvider() strategy = get_strategy_for_mode(research_mode)
# Single comprehensive research prompt - Gemini handles Google Search automatically
research_prompt = f"""
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including:
1. Current trends and insights (2024-2025)
2. Key statistics and data points with sources
3. Industry expert opinions and quotes
4. Recent developments and news
5. Market analysis and forecasts
6. Best practices and case studies
7. Keyword analysis: primary, secondary, and long-tail opportunities
8. Competitor analysis: top players and content gaps
9. Content angle suggestions: 5 compelling angles for blog posts
Focus on factual, up-to-date information from credible sources.
Include specific data points, percentages, and recent developments.
Structure your response with clear sections for each analysis area.
"""
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...") logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
# Single Gemini call with native Google Search grounding - no fallbacks
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
try:
gemini_result = await gemini.generate_grounded_content(
prompt=research_prompt,
content_type="research",
max_tokens=2000,
user_id=user_id,
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
)
except HTTPException as http_error:
# Re-raise HTTPException so it can be properly handled by task manager
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise # Re-raise HTTPException to preserve status code and error details
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...") # Build research prompt based on strategy
# Extract sources from grounding metadata research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
sources = self._extract_sources_from_grounding(gemini_result)
# Extract grounding metadata for detailed UI display # Route to appropriate provider
grounding_metadata = self._extract_grounding_metadata(gemini_result) if config.provider == ResearchProvider.EXA:
# Exa research workflow
# Extract search widget and queries for UI display from .exa_provider import ExaResearchProvider
search_widget = gemini_result.get("search_widget", "") or "" from services.subscription.preflight_validator import validate_exa_research_operations
search_queries = gemini_result.get("search_queries", []) or [] from services.database import get_db
from services.subscription import PricingService
import os
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
# Pre-flight validation
db_val = next(get_db())
try:
pricing_service = PricingService(db_val)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
except HTTPException as http_error:
logger.error(f"Subscription limit exceeded for Exa research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise
finally:
db_val.close()
# Execute Exa search
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
try:
exa_provider = ExaResearchProvider()
raw_result = await exa_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
# Track usage
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
exa_provider.track_exa_usage(user_id, cost)
# Extract content for downstream analysis
content = raw_result.get('content', '')
sources = raw_result.get('sources', [])
search_widget = "" # Exa doesn't provide search widgets
search_queries = raw_result.get('search_queries', [])
grounding_metadata = None # Exa doesn't provide grounding metadata
except RuntimeError as e:
if "EXA_API_KEY not configured" in str(e):
logger.warning("Exa not configured, falling back to Google")
await task_manager.update_progress(task_id, "⚠️ Exa not configured, falling back to Google Search")
config.provider = ResearchProvider.GOOGLE
# Continue to Google flow below
else:
raise
if config.provider != ResearchProvider.EXA:
# Google research (existing flow)
from .google_provider import GoogleResearchProvider
await task_manager.update_progress(task_id, "🌐 Connecting to Google Search grounding...")
google_provider = GoogleResearchProvider()
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
try:
gemini_result = await google_provider.search(
research_prompt, topic, industry, target_audience, config, user_id
)
except HTTPException as http_error:
logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}")
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
raise
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
# Extract sources and content
sources = self._extract_sources_from_grounding(gemini_result)
content = gemini_result.get("content", "")
search_widget = gemini_result.get("search_widget", "") or ""
search_queries = gemini_result.get("search_queries", []) or []
grounding_metadata = self._extract_grounding_metadata(gemini_result)
# Continue with common analysis (same for both providers)
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...") await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
# Parse the comprehensive response for different analysis components
content = gemini_result.get("content", "")
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id) keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id) competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id) suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)

View File

@@ -0,0 +1,234 @@
"""
Research Strategy Pattern Implementation
Different strategies for executing research based on depth and focus.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any
from loguru import logger
from models.blog_models import BlogResearchRequest, ResearchMode, ResearchConfig
from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
class ResearchStrategy(ABC):
    """Abstract base shared by every research strategy.

    Each instance owns one keyword analyzer, one competitor analyzer and one
    content-angle generator. Concrete subclasses supply the prompt text and
    report which research mode they implement.
    """

    def __init__(self):
        # Analyzer helpers are created per strategy instance.
        self.keyword_analyzer = KeywordAnalyzer()
        self.competitor_analyzer = CompetitorAnalyzer()
        self.content_angle_generator = ContentAngleGenerator()

    @abstractmethod
    def build_research_prompt(
        self,
        topic: str,
        industry: str,
        target_audience: str,
        config: ResearchConfig
    ) -> str:
        """Return the research prompt text for this strategy.

        Args:
            topic: Subject being researched.
            industry: Industry the blog belongs to.
            target_audience: Readers the blog is written for.
            config: Research configuration; subclasses decide which of its
                fields they consult.
        """

    @abstractmethod
    def get_mode(self) -> ResearchMode:
        """Return the ResearchMode handled by this strategy."""
class BasicResearchStrategy(ResearchStrategy):
    """Keyword-focused strategy that asks for a short, fixed-format analysis."""

    def get_mode(self) -> ResearchMode:
        """Identify this strategy as the BASIC research mode."""
        return ResearchMode.BASIC

    def build_research_prompt(
        self,
        topic: str,
        industry: str,
        target_audience: str,
        config: ResearchConfig
    ) -> str:
        """Build the basic prompt: trends, statistics, keywords and angles.

        ``config`` is accepted for interface compatibility but is not read;
        the basic prompt depends only on topic, industry and audience.
        """
        return f"""You are a professional blog content strategist researching for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"
Provide analysis in this EXACT format:
## CURRENT TRENDS (2024-2025)
- [Trend 1 with specific data and source URL]
- [Trend 2 with specific data and source URL]
- [Trend 3 with specific data and source URL]
## KEY STATISTICS
- [Statistic 1: specific number/percentage with source URL]
- [Statistic 2: specific number/percentage with source URL]
- [Statistic 3: specific number/percentage with source URL]
- [Statistic 4: specific number/percentage with source URL]
- [Statistic 5: specific number/percentage with source URL]
## PRIMARY KEYWORDS
1. "{topic}" (main keyword)
2. [Variation 1]
3. [Variation 2]
## SECONDARY KEYWORDS
[5 related keywords for blog content]
## CONTENT ANGLES (Top 5)
1. [Angle 1: specific unique approach]
2. [Angle 2: specific unique approach]
3. [Angle 3: specific unique approach]
4. [Angle 4: specific unique approach]
5. [Angle 5: specific unique approach]
REQUIREMENTS:
- Cite EVERY claim with authoritative source URLs
- Use 2024-2025 data when available
- Include specific numbers, dates, examples
- Focus on actionable blog insights for {target_audience}""".strip()
class ComprehensiveResearchStrategy(ResearchStrategy):
    """Full-analysis strategy covering every research component."""

    def get_mode(self) -> ResearchMode:
        """Identify this strategy as the COMPREHENSIVE research mode."""
        return ResearchMode.COMPREHENSIVE

    def build_research_prompt(
        self,
        topic: str,
        industry: str,
        target_audience: str,
        config: ResearchConfig
    ) -> str:
        """Build the comprehensive prompt, optionally narrowed by config.

        When ``config.date_range`` is set, a "Date Focus" line is appended to
        the topic line (underscores in the enum value become spaces). When
        ``config.source_types`` is non-empty, a "Priority Sources" line listing
        the enum values is appended as well.
        """
        # Optional filter lines default to empty so the f-string stays uniform.
        date_filter = ""
        if config.date_range:
            readable_range = config.date_range.value.replace('_', ' ')
            date_filter = f"\nDate Focus: {readable_range}"

        source_filter = ""
        if config.source_types:
            source_names = ', '.join([source.value for source in config.source_types])
            source_filter = f"\nPriority Sources: {source_names}"

        return f"""You are a senior blog content strategist conducting comprehensive research for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"{date_filter}{source_filter}
Provide COMPLETE analysis in this EXACT format:
## TRENDS AND INSIGHTS (2024-2025)
[5-7 trends with specific data, numbers, and source URLs]
## KEY STATISTICS
[7-10 statistics with exact numbers, percentages, dates, and source URLs]
## EXPERT OPINIONS
[4-5 expert quotes with full attribution and source URLs]
## RECENT DEVELOPMENTS
[5-7 recent news/developments with dates and source URLs]
## MARKET ANALYSIS
[3-5 market insights with data points and source URLs]
## BEST PRACTICES & CASE STUDIES
[3-5 examples with specific outcomes/metrics and source URLs]
## KEYWORD ANALYSIS
Primary Keywords: [3 main variations]
Secondary Keywords: [7-10 related keywords]
Long-Tail Opportunities: [5-7 specific search phrases]
## COMPETITOR ANALYSIS
Top Competitors: [5 competitors with brief descriptions]
Content Gaps: [5 topics competitors are missing]
Competitive Advantages: [5 unique angles we can own]
## CONTENT ANGLES (Exactly 5)
1. [Unique angle with reasoning and target benefit]
2. [Unique angle with reasoning and target benefit]
3. [Unique angle with reasoning and target benefit]
4. [Unique angle with reasoning and target benefit]
5. [Unique angle with reasoning and target benefit]
VERIFICATION REQUIREMENTS:
- Minimum 2 authoritative sources per major claim
- Prioritize: Industry publications > Research papers > News > Blogs
- 2024-2025 data strongly preferred
- All numbers must include context (timeframe, sample size, methodology)
- Every recommendation must be actionable for {target_audience}""".strip()
class TargetedResearchStrategy(ResearchStrategy):
    """Strategy that only requests the sections enabled in the config."""

    def get_mode(self) -> ResearchMode:
        """Identify this strategy as the TARGETED research mode."""
        return ResearchMode.TARGETED

    def build_research_prompt(
        self,
        topic: str,
        industry: str,
        target_audience: str,
        config: ResearchConfig
    ) -> str:
        """Build a prompt made of the sections selected by ``config``.

        The ``include_trends``, ``include_statistics``,
        ``include_expert_quotes`` and ``include_competitors`` flags each toggle
        one section. Keyword analysis and content angles are always appended.
        Sections are joined with a blank line between them.
        """
        # (flag, section text) pairs, in the order they appear in the prompt.
        optional_sections = (
            (config.include_trends, """## CURRENT TRENDS
[3-5 trends with data and source URLs]"""),
            (config.include_statistics, """## KEY STATISTICS
[5-7 statistics with numbers and source URLs]"""),
            (config.include_expert_quotes, """## EXPERT OPINIONS
[3-4 expert quotes with attribution and source URLs]"""),
            (config.include_competitors, """## COMPETITOR ANALYSIS
Top Competitors: [3-5]
Content Gaps: [3-5]"""),
        )
        chosen = [text for enabled, text in optional_sections if enabled]

        # Keywords and angles are requested regardless of the config flags.
        chosen.append("""## KEYWORD ANALYSIS
Primary: [2-3 variations]
Secondary: [5-7 keywords]
Long-Tail: [3-5 phrases]""")
        chosen.append("""## CONTENT ANGLES (3-5)
[Unique blog angles with reasoning]""")

        sections_str = "\n\n".join(chosen)
        return f"""You are a blog content strategist conducting targeted research for a {industry} blog targeting {target_audience}.
Research Topic: "{topic}"
Provide focused analysis in this EXACT format:
{sections_str}
REQUIREMENTS:
- Cite all claims with authoritative source URLs
- Include specific numbers, dates, examples
- Focus on actionable insights for {target_audience}
- Use 2024-2025 data when available""".strip()
def get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy:
    """Instantiate the strategy registered for ``mode``.

    Modes without a registered strategy fall back to BasicResearchStrategy.
    A new strategy instance is created on every call.
    """
    registry = {
        ResearchMode.BASIC: BasicResearchStrategy,
        ResearchMode.COMPREHENSIVE: ComprehensiveResearchStrategy,
        ResearchMode.TARGETED: TargetedResearchStrategy,
    }
    return registry.get(mode, BasicResearchStrategy)()

View File

@@ -2,4 +2,14 @@
Wix integration modular services package. Wix integration modular services package.
""" """
from services.integrations.wix.seo import build_seo_data
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
from services.integrations.wix.blog_publisher import create_blog_post
__all__ = [
'build_seo_data',
'markdown_to_html',
'convert_via_wix_api',
'create_blog_post',
]

View File

@@ -20,6 +20,40 @@ class WixBlogService:
return h return h
def create_draft_post(self, access_token: str, payload: Dict[str, Any], extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    """POST a draft-post payload to the Wix Blog API and return the parsed JSON reply.

    Before sending, a structural summary of the payload is logged for
    debugging: top-level keys, draftPost/richContent keys, the first
    LIST_ITEM found among the first ten nodes, and a JSON preview.

    Args:
        access_token: Token passed to ``self.headers`` to build request headers.
        payload: Request body; a ``draftPost`` entry is inspected for logging if present.
        extra_headers: Optional additional headers forwarded to ``self.headers``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: via ``raise_for_status()`` on a non-success status.
    """
    import json

    # Diagnostics must never prevent the request from being sent. The payload
    # inspection below is type-guarded, and any unexpected failure while
    # logging is recorded and skipped rather than propagated.
    try:
        logger.warning(f"📤 Sending to Wix Blog API:")
        logger.warning(f" Endpoint: {self.base_url}/blog/v3/draft-posts")
        logger.warning(f" Payload top-level keys: {list(payload.keys())}")
        dp = payload.get('draftPost') if 'draftPost' in payload else None
        if isinstance(dp, dict):
            logger.warning(f" draftPost keys: {list(dp.keys())}")
            if 'richContent' in dp:
                rc = dp['richContent']
                logger.warning(f" richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
                if isinstance(rc, dict) and 'nodes' in rc:
                    nodes = rc['nodes']
                    logger.warning(f" richContent.nodes count: {len(nodes) if isinstance(nodes, list) else 'N/A'}")
                    # Only a real list can be sliced; anything else was already reported as 'N/A'.
                    candidates = nodes[:10] if isinstance(nodes, list) else []
                    # Inspect first LIST_ITEM node if any
                    for i, node in enumerate(candidates):
                        if isinstance(node, dict) and node.get('type') == 'LIST_ITEM':
                            logger.warning(f" Found LIST_ITEM at index {i}:")
                            logger.warning(f" Keys: {list(node.keys())}")
                            logger.warning(f" Has listItemData: {'listItemData' in node}")
                            if 'listItemData' in node:
                                logger.warning(f" listItemData type: {type(node['listItemData'])}, value: {node['listItemData']}")
                            if 'nodes' in node:
                                nested = node['nodes']
                                logger.warning(f" Nested nodes count: {len(nested) if isinstance(nested, list) else 'N/A'}")
                                nested_sample = nested[:3] if isinstance(nested, list) else []
                                for j, n_node in enumerate(nested_sample):
                                    if isinstance(n_node, dict):
                                        logger.warning(f" Nested node {j}: type={n_node.get('type')}, keys={list(n_node.keys())}")
                                        if n_node.get('type') == 'PARAGRAPH' and 'paragraphData' in n_node:
                                            logger.warning(f" paragraphData type: {type(n_node['paragraphData'])}, value: {n_node['paragraphData']}")
                            break  # Only inspect first LIST_ITEM
        logger.warning(f" Full Payload JSON (first 8000 chars):\n{json.dumps(payload, indent=2, ensure_ascii=False)[:8000]}...")
    except Exception as log_error:
        logger.debug(f"Skipping draft post payload diagnostics: {log_error}")

    response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=self.headers(access_token, extra_headers), json=payload)
    response.raise_for_status()
    return response.json()

View File

@@ -0,0 +1,716 @@
"""
Blog Post Publisher for Wix
Handles blog post creation, validation, and publishing to Wix.
"""
import json
import uuid
import requests
import jwt
from typing import Dict, Any, Optional, List
from loguru import logger
from services.integrations.wix.blog import WixBlogService
from services.integrations.wix.content import convert_content_to_ricos
from services.integrations.wix.ricos_converter import convert_via_wix_api
from services.integrations.wix.seo import build_seo_data
# Per-node data fields whose None-valued entries are stripped before sending.
_RICOS_DATA_FIELDS = (
    'headingData', 'paragraphData', 'blockquoteData', 'bulletedListData',
    'orderedListData', 'textData', 'linkData', 'imageData',
)

# Node types that must carry a 'nodes' list of children.
_RICOS_CONTAINER_TYPES = (
    'HEADING', 'PARAGRAPH', 'BLOCKQUOTE', 'LIST_ITEM', 'LINK',
    'BULLETED_LIST', 'ORDERED_LIST',
)


def _validate_ricos_node(node: Any, path: str = "root") -> None:
    """Normalize one Ricos node in place, then recurse into its children.

    Non-dict nodes are logged and left untouched. ``path`` is only used to
    make log messages point at the offending node.
    """
    if not isinstance(node, dict):
        logger.error(f"{path}: Node is not a dict: {type(node)}")
        return

    # Ensure type and id exist
    if 'type' not in node:
        logger.error(f"{path}: Missing 'type' field - REQUIRED")
        node['type'] = 'PARAGRAPH'  # Default fallback
    # Same rule as the document-level id: anything that is not a string
    # (missing, None, number) is replaced, because a None id would be
    # rejected later by the no-None payload check.
    if not isinstance(node.get('id'), str):
        node['id'] = str(uuid.uuid4())
        logger.debug(f"{path}: Added missing 'id'")

    node_type = node.get('type')

    # Data fields such as paragraphData or bulletedListData are left out when
    # absent rather than added as {}; only fields with required properties are
    # defaulted below.

    # listItemData is removed from LIST_ITEM nodes whenever present.
    if node_type == 'LIST_ITEM' and 'listItemData' in node:
        logger.debug(f"{path}: Removing incorrect listItemData field from LIST_ITEM")
        del node['listItemData']

    # HEADING nodes require headingData with a level.
    if node_type == 'HEADING':
        heading_data = node.get('headingData')
        if not isinstance(heading_data, dict):
            logger.warning(f"{path} (HEADING): Missing headingData, adding default level 1")
            node['headingData'] = {'level': 1}
        elif heading_data.get('level') is None:
            # A None level counts as missing: the None-stripping pass below
            # would otherwise delete the key and leave the heading level-less.
            logger.warning(f"{path} (HEADING): Missing level in headingData, adding default")
            heading_data['level'] = 1

    # TEXT nodes must have textData
    if node_type == 'TEXT' and not isinstance(node.get('textData'), dict):
        logger.error(f"{path} (TEXT): Missing/invalid textData - node will be problematic")
        node['textData'] = {'text': '', 'decorations': []}

    # LINK and IMAGE nodes cannot be given a sensible default; report only.
    if node_type == 'LINK' and not isinstance(node.get('linkData'), dict):
        logger.error(f"{path} (LINK): Missing/invalid linkData - node will be problematic")
    if node_type == 'IMAGE' and not isinstance(node.get('imageData'), dict):
        logger.error(f"{path} (IMAGE): Missing/invalid imageData - node will be problematic")

    # Remove None values from any data fields that exist (Wix API rejects None)
    for data_field in _RICOS_DATA_FIELDS:
        data_value = node.get(data_field)
        if isinstance(data_value, dict):
            keys_to_remove = [k for k, v in data_value.items() if v is None]
            if keys_to_remove:
                logger.debug(f"{path} ({node_type}): Removing None values from {data_field}: {keys_to_remove}")
                for key in keys_to_remove:
                    del data_value[key]

    # Container nodes always get a list-valued 'nodes', then each child is validated.
    if node_type in _RICOS_CONTAINER_TYPES:
        if 'nodes' not in node:
            logger.warning(f"{path} ({node_type}): Missing 'nodes' field, adding empty array")
            node['nodes'] = []
        elif not isinstance(node['nodes'], list):
            logger.error(f"{path} ({node_type}): Invalid 'nodes' field (not a list), fixing")
            node['nodes'] = []
        for nested_idx, nested_node in enumerate(node['nodes']):
            _validate_ricos_node(nested_node, f"{path}.nodes[{nested_idx}]")


def validate_ricos_content(ricos_content: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and normalize a Ricos document in place.

    Fills in the document ``type``, ``id``, ``nodes``, ``metadata`` and
    ``documentStyle`` when missing or malformed, and recursively normalizes
    every node (ids, heading levels, text data, child lists, None-valued
    data entries).

    Args:
        ricos_content: Ricos document dict; it is mutated.

    Returns:
        The same dict object, validated and normalized.

    Raises:
        ValueError: if ``ricos_content`` is empty or not a dict, or if its
            ``nodes`` value is present but not a list.
    """
    if not ricos_content or not isinstance(ricos_content, dict):
        logger.error("Invalid Ricos content - not a dict")
        raise ValueError("Failed to convert content to valid Ricos format")

    if 'type' not in ricos_content:
        ricos_content['type'] = 'DOCUMENT'
        logger.debug("Added missing richContent type 'DOCUMENT'")
    elif ricos_content.get('type') != 'DOCUMENT':
        logger.warning(f"richContent type expected 'DOCUMENT', got {ricos_content.get('type')}, correcting")
        ricos_content['type'] = 'DOCUMENT'

    if not isinstance(ricos_content.get('id'), str):
        ricos_content['id'] = str(uuid.uuid4())
        logger.debug("Added missing richContent id")

    if 'nodes' not in ricos_content:
        logger.warning("Ricos document missing 'nodes' field, adding empty nodes array")
        ricos_content['nodes'] = []
    # A present-but-wrong 'nodes' is a conversion failure, not something to paper over.
    if not isinstance(ricos_content['nodes'], list):
        raise ValueError(f"richContent.nodes must be a list, got {type(ricos_content['nodes'])}")
    logger.debug(f"Ricos document structure: nodes={len(ricos_content['nodes'])}")

    # Validate all top-level nodes recursively
    for idx, node in enumerate(ricos_content['nodes']):
        _validate_ricos_node(node, f"nodes[{idx}]")

    # Ensure metadata exists and carries a version and an id.
    if not isinstance(ricos_content.get('metadata'), dict):
        ricos_content['metadata'] = {'version': 1, 'id': str(uuid.uuid4())}
        logger.debug("Added default metadata to richContent")
    else:
        ricos_content['metadata'].setdefault('version', 1)
        ricos_content['metadata'].setdefault('id', str(uuid.uuid4()))

    # Ensure documentStyle exists and is a dict.
    if not isinstance(ricos_content.get('documentStyle'), dict):
        ricos_content['documentStyle'] = {
            'paragraph': {
                'decorations': [],
                'nodeStyle': {},
                'lineHeight': '1.5'
            }
        }
        logger.debug("Added default documentStyle to richContent")

    logger.debug(f"✅ Validated richContent: {len(ricos_content['nodes'])} nodes, has_metadata={bool(ricos_content.get('metadata'))}, has_documentStyle={bool(ricos_content.get('documentStyle'))}")
    return ricos_content
def validate_payload_no_none(obj, path=""):
    """Walk ``obj`` depth-first and raise on the first None encountered.

    Dicts and lists are descended into; every other value is accepted as a
    leaf. ``path`` is the location prefix used in the error message: dict keys
    are joined with dots and list positions are written as ``[index]``.

    Raises:
        ValueError: naming the path of the first None value found.
    """
    if obj is None:
        raise ValueError(f"Found None value at path: {path}")

    if isinstance(obj, dict):
        # An empty prefix means the key itself is the whole path.
        children = (
            (value, f"{path}.{key}" if path else key)
            for key, value in obj.items()
        )
    elif isinstance(obj, list):
        children = (
            (item, f"{path}[{position}]" if path else f"[{position}]")
            for position, item in enumerate(obj)
        )
    else:
        return

    for child, child_path in children:
        validate_payload_no_none(child, child_path)
def create_blog_post(
blog_service: WixBlogService,
access_token: str,
title: str,
content: str,
member_id: str,
cover_image_url: str = None,
category_ids: List[str] = None,
tag_ids: List[str] = None,
publish: bool = True,
seo_metadata: Dict[str, Any] = None,
import_image_func = None,
lookup_categories_func = None,
lookup_tags_func = None,
base_url: str = 'https://www.wixapis.com'
) -> Dict[str, Any]:
"""
Create and optionally publish a blog post on Wix
Args:
blog_service: WixBlogService instance
access_token: Valid access token
title: Blog post title
content: Blog post content (markdown)
member_id: Required for third-party apps - the member ID of the post author
cover_image_url: Optional cover image URL
category_ids: Optional list of category IDs or names
tag_ids: Optional list of tag IDs or names
publish: Whether to publish immediately or save as draft
seo_metadata: Optional SEO metadata dict
import_image_func: Function to import images (optional)
lookup_categories_func: Function to lookup/create categories (optional)
lookup_tags_func: Function to lookup/create tags (optional)
base_url: Wix API base URL
Returns:
Created blog post information
"""
if not member_id:
raise ValueError("memberId is required for third-party apps creating blog posts")
headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
# Build valid Ricos rich content
# Ensure content is not empty
if not content or not content.strip():
content = "This is a post from ALwrity."
logger.warning("⚠️ Content was empty, using default text")
# Try Wix API first (more reliable), fall back to custom parser
ricos_content = None
try:
logger.warning("🔄 Attempting to convert markdown to Ricos via Wix API...")
ricos_content = convert_via_wix_api(content, access_token, base_url)
logger.warning(f"✅ Wix API conversion successful. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
except Exception as e:
logger.warning(f"⚠️ Wix Ricos API conversion failed: {e}. Falling back to custom parser...")
# Fall back to custom parser
ricos_content = convert_content_to_ricos(content, None)
logger.warning(f"✅ Custom parser conversion complete. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
# Validate Ricos content
ricos_content = validate_ricos_content(ricos_content)
# Minimal payload per Wix docs: title, memberId, and richContent
# CRITICAL: Only include fields that have valid values (no None, no empty strings for required fields)
blog_data = {
'draftPost': {
'title': str(title).strip() if title else "Untitled",
'memberId': str(member_id).strip(), # Required for third-party apps (validated above)
'richContent': ricos_content, # Must be a valid Ricos document object
},
'publish': bool(publish),
'fieldsets': ['URL'] # Simplified fieldsets
}
# Add excerpt only if content exists and is not empty (avoid None or empty strings)
excerpt = (content or '').strip()[:200] if content else None
if excerpt and len(excerpt) > 0:
blog_data['draftPost']['excerpt'] = str(excerpt)
# Add cover image if provided
if cover_image_url and import_image_func:
try:
media_id = import_image_func(access_token, cover_image_url, f'Cover: {title}')
# Ensure media_id is a string and not None
if media_id and isinstance(media_id, str):
blog_data['draftPost']['media'] = {
'wixMedia': {
'image': {'id': str(media_id).strip()}
},
'displayed': True,
'custom': True
}
else:
logger.warning(f"Invalid media_id type or value: {type(media_id)}, skipping media")
except Exception as e:
logger.warning(f"Failed to import cover image: {e}")
# Handle categories - can be either IDs (list of strings) or names (for lookup)
category_ids_to_use = None
if category_ids:
# Check if these are IDs (UUIDs) or names
if isinstance(category_ids, list) and len(category_ids) > 0:
# Assume IDs if first item looks like UUID (has hyphens and is long)
first_item = str(category_ids[0])
if '-' in first_item and len(first_item) > 30:
category_ids_to_use = category_ids
elif lookup_categories_func:
# These are names, need to lookup/create
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
category_ids_to_use = lookup_categories_func(
access_token, category_ids, extra_headers if extra_headers else None
)
# Handle tags - can be either IDs (list of strings) or names (for lookup)
tag_ids_to_use = None
if tag_ids:
# Check if these are IDs (UUIDs) or names
if isinstance(tag_ids, list) and len(tag_ids) > 0:
# Assume IDs if first item looks like UUID (has hyphens and is long)
first_item = str(tag_ids[0])
if '-' in first_item and len(first_item) > 30:
tag_ids_to_use = tag_ids
elif lookup_tags_func:
# These are names, need to lookup/create
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
tag_ids_to_use = lookup_tags_func(
access_token, tag_ids, extra_headers if extra_headers else None
)
# Add categories if we have IDs (must be non-empty list of strings)
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
if category_ids_to_use and isinstance(category_ids_to_use, list) and len(category_ids_to_use) > 0:
# Filter out None, empty strings, and ensure all are valid UUID strings
valid_category_ids = [str(cid).strip() for cid in category_ids_to_use if cid and str(cid).strip()]
if valid_category_ids:
blog_data['draftPost']['categoryIds'] = valid_category_ids
logger.debug(f"Added {len(valid_category_ids)} category IDs")
else:
logger.warning("All category IDs were invalid, not including categoryIds in payload")
# Add tags if we have IDs (must be non-empty list of strings)
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
if tag_ids_to_use and isinstance(tag_ids_to_use, list) and len(tag_ids_to_use) > 0:
# Filter out None, empty strings, and ensure all are valid UUID strings
valid_tag_ids = [str(tid).strip() for tid in tag_ids_to_use if tid and str(tid).strip()]
if valid_tag_ids:
blog_data['draftPost']['tagIds'] = valid_tag_ids
logger.debug(f"Added {len(valid_tag_ids)} tag IDs")
else:
logger.warning("All tag IDs were invalid, not including tagIds in payload")
# Build SEO data from metadata if provided
# TESTING: Skip SEO data temporarily to confirm richContent fix
test_skip_seo = True
if test_skip_seo:
logger.warning("🧪 TESTING: Skipping SEO data to isolate richContent vs seoData issue")
seo_data = None
elif seo_metadata:
logger.warning(f"📊 Building SEO data from metadata. Keys: {list(seo_metadata.keys())}")
seo_data = build_seo_data(seo_metadata, title)
if seo_data:
# Log detailed SEO structure
logger.warning(f"📋 SEO data built: {len(seo_data.get('tags', []))} tags, {len(seo_data.get('settings', {}).get('keywords', []))} keywords")
# Log each SEO tag for debugging (key ones only to avoid too much output)
if seo_data.get('tags'):
for idx, tag in enumerate(seo_data['tags'][:3]): # First 3 tags only
tag_type = tag.get('type')
if tag_type == 'title':
logger.warning(f" SEO tag {idx+1}: type={tag_type}, children={str(tag.get('children', ''))[:50]}...")
else:
props = tag.get('props', {})
content_preview = str(props.get('content', props.get('href', props.get('name', ''))))[:50]
logger.warning(f" SEO tag {idx+1}: type={tag_type}, props={list(props.keys())}, content={content_preview}...")
if len(seo_data['tags']) > 3:
logger.warning(f" ... and {len(seo_data['tags']) - 3} more SEO tags")
blog_data['draftPost']['seoData'] = seo_data
logger.warning(f"✅ Added seoData to blog post with {len(seo_data.get('tags', []))} tags")
else:
logger.warning("⚠️ SEO data was empty after building - check build_seo_data function")
# Add SEO slug if provided (separate field from seoData)
if seo_metadata and seo_metadata.get('url_slug'):
blog_data['draftPost']['seoSlug'] = str(seo_metadata.get('url_slug')).strip()
logger.warning(f"✅ Added SEO slug: {blog_data['draftPost']['seoSlug']}")
if test_skip_seo:
logger.warning("⚠️ SEO data skipped for testing - will add back once richContent is confirmed working")
elif not seo_metadata:
logger.warning("⚠️ No SEO metadata provided to create_blog_post")
# Log the payload structure for debugging (without sensitive data)
logger.warning(f"📝 Creating blog post with title: '{title}'")
logger.warning(f"📋 Draft post fields: {list(blog_data['draftPost'].keys())}")
# Detailed SEO logging
if 'seoData' in blog_data['draftPost']:
seo_data_debug = blog_data['draftPost']['seoData']
logger.warning(f"📊 SEO data in payload: {len(seo_data_debug.get('tags', []))} tags, {len(seo_data_debug.get('settings', {}).get('keywords', []))} keywords")
# Log sample SEO tags (first 2 only to avoid too much output)
if seo_data_debug.get('tags'):
logger.warning("📋 SEO Tags sample:")
for i, tag in enumerate(seo_data_debug['tags'][:2]): # First 2 tags
logger.warning(f" Tag {i+1}: type={tag.get('type')}, custom={tag.get('custom')}, disabled={tag.get('disabled')}")
if len(seo_data_debug['tags']) > 2:
logger.warning(f" ... and {len(seo_data_debug['tags']) - 2} more tags")
if seo_data_debug.get('settings', {}).get('keywords'):
keywords_list = [k.get('term') for k in seo_data_debug['settings']['keywords'][:3]]
logger.warning(f"🔑 Keywords: {keywords_list}")
# Log FULL seoData structure for debugging
import json
try:
seo_json = json.dumps(seo_data_debug, indent=2, ensure_ascii=False)
logger.warning(f"📄 FULL seoData JSON:\n{seo_json[:2000]}...") # First 2000 chars
except Exception as e:
logger.error(f"Failed to serialize seoData: {e}")
else:
logger.warning("⚠️ No seoData in draft post payload!")
try:
# Add wix-site-id header if we can extract it from token
extra_headers = {}
try:
token_str = str(access_token)
if token_str and token_str.startswith('OauthNG.JWS.'):
jwt_part = token_str[12:]
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
data_payload = payload.get('data', {})
if isinstance(data_payload, str):
try:
data_payload = json.loads(data_payload)
except:
pass
instance_data = data_payload.get('instance', {})
meta_site_id = instance_data.get('metaSiteId')
if isinstance(meta_site_id, str) and meta_site_id:
extra_headers['wix-site-id'] = meta_site_id
headers['wix-site-id'] = meta_site_id
except Exception as e:
logger.debug(f"Could not extract site ID from token: {e}")
# Make the API call
logger.warning(f"🚀 Calling Wix API: POST /blog/v3/draft-posts")
logger.warning(f"📦 Payload: title='{blog_data['draftPost'].get('title')}', has_seoData={'seoData' in blog_data['draftPost']}, has_richContent={'richContent' in blog_data['draftPost']}")
# Validate payload structure before sending
draft_post = blog_data.get('draftPost', {})
if not isinstance(draft_post, dict):
raise ValueError("draftPost must be a dict object")
# Validate richContent structure
if 'richContent' in draft_post:
rc = draft_post['richContent']
if not isinstance(rc, dict):
raise ValueError(f"richContent must be a dict, got {type(rc)}")
if 'nodes' not in rc:
raise ValueError("richContent missing 'nodes' field")
if not isinstance(rc['nodes'], list):
raise ValueError(f"richContent.nodes must be a list, got {type(rc['nodes'])}")
logger.debug(f"✅ richContent validation passed: {len(rc.get('nodes', []))} nodes")
# Validate seoData structure if present
if 'seoData' in draft_post:
seo = draft_post['seoData']
if not isinstance(seo, dict):
raise ValueError(f"seoData must be a dict, got {type(seo)}")
if 'tags' in seo and not isinstance(seo['tags'], list):
raise ValueError(f"seoData.tags must be a list, got {type(seo.get('tags'))}")
if 'settings' in seo and not isinstance(seo['settings'], dict):
raise ValueError(f"seoData.settings must be a dict, got {type(seo.get('settings'))}")
logger.debug(f"✅ seoData validation passed: {len(seo.get('tags', []))} tags")
# Final validation: Ensure no None values in any nested objects
# Wix API rejects None values and expects proper types
try:
validate_payload_no_none(blog_data, "blog_data")
logger.debug("✅ Payload validation passed: No None values found")
except ValueError as e:
logger.error(f"❌ Payload validation failed: {e}")
raise
# Log full payload structure for debugging (sanitized)
logger.warning(f"📦 Full payload structure validation:")
logger.warning(f" - draftPost type: {type(draft_post)}")
logger.warning(f" - draftPost keys: {list(draft_post.keys())}")
logger.warning(f" - richContent type: {type(draft_post.get('richContent'))}")
if 'richContent' in draft_post:
rc = draft_post['richContent']
logger.warning(f" - richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
logger.warning(f" - richContent.nodes type: {type(rc.get('nodes'))}, count: {len(rc.get('nodes', []))}")
logger.warning(f" - richContent.metadata type: {type(rc.get('metadata'))}")
logger.warning(f" - richContent.documentStyle type: {type(rc.get('documentStyle'))}")
logger.warning(f" - seoData type: {type(draft_post.get('seoData'))}")
if 'seoData' in draft_post:
seo = draft_post['seoData']
logger.warning(f" - seoData keys: {list(seo.keys()) if isinstance(seo, dict) else 'N/A'}")
logger.warning(f" - seoData.tags type: {type(seo.get('tags'))}, count: {len(seo.get('tags', []))}")
logger.warning(f" - seoData.settings type: {type(seo.get('settings'))}")
if 'categoryIds' in draft_post:
logger.warning(f" - categoryIds type: {type(draft_post.get('categoryIds'))}, count: {len(draft_post.get('categoryIds', []))}")
if 'tagIds' in draft_post:
logger.warning(f" - tagIds type: {type(draft_post.get('tagIds'))}, count: {len(draft_post.get('tagIds', []))}")
# Log a sample of the payload JSON to see exact structure (first 2000 chars)
try:
import json
payload_json = json.dumps(blog_data, indent=2, ensure_ascii=False)
logger.warning(f"📄 Payload JSON preview (first 3000 chars):\n{payload_json[:3000]}...")
# Also log a deep structure inspection of richContent.nodes (first few nodes)
if 'richContent' in blog_data['draftPost']:
nodes = blog_data['draftPost']['richContent'].get('nodes', [])
if nodes:
logger.warning(f"🔍 Inspecting first 5 richContent.nodes:")
for i, node in enumerate(nodes[:5]):
logger.warning(f" Node {i+1}: type={node.get('type')}, keys={list(node.keys())}")
# Check for any None values in node
for key, value in node.items():
if value is None:
logger.error(f" ⚠️ Node {i+1}.{key} is None!")
elif isinstance(value, dict):
for k, v in value.items():
if v is None:
logger.error(f" ⚠️ Node {i+1}.{key}.{k} is None!")
# Deep check: if it's a list-type node, inspect list items
if node.get('type') in ['BULLETED_LIST', 'ORDERED_LIST']:
list_items = node.get('nodes', [])
if list_items:
logger.warning(f" List has {len(list_items)} items, checking first LIST_ITEM:")
first_item = list_items[0]
logger.warning(f" LIST_ITEM keys: {list(first_item.keys())}")
# Verify listItemData is NOT present (correct per Wix API spec)
if 'listItemData' in first_item:
logger.error(f" ❌ LIST_ITEM incorrectly has listItemData!")
else:
logger.debug(f" ✅ LIST_ITEM correctly has no listItemData")
# Check nested PARAGRAPH nodes
nested_nodes = first_item.get('nodes', [])
if nested_nodes:
logger.warning(f" LIST_ITEM has {len(nested_nodes)} nested nodes")
for n_idx, n_node in enumerate(nested_nodes[:2]):
logger.warning(f" Nested node {n_idx+1}: type={n_node.get('type')}, keys={list(n_node.keys())}")
except Exception as e:
logger.warning(f"Could not serialize payload for logging: {e}")
# Note: All node validation is done by validate_ricos_content() which runs earlier
# The recursive validation ensures all required data fields are present at any depth
# Final deep validation: Serialize and deserialize to catch any JSON-serialization issues
# This will raise an error if there are any objects that can't be serialized
try:
import json
test_json = json.dumps(blog_data, ensure_ascii=False)
test_parsed = json.loads(test_json)
logger.debug("✅ Payload JSON serialization test passed")
except (TypeError, ValueError) as e:
logger.error(f"❌ Payload JSON serialization failed: {e}")
raise ValueError(f"Payload contains non-serializable data: {e}")
# Final check: Ensure documentStyle and metadata are valid objects (not None, not empty strings)
rc = blog_data['draftPost']['richContent']
if 'documentStyle' in rc:
doc_style = rc['documentStyle']
if doc_style is None or doc_style == "":
logger.warning("⚠️ documentStyle is None or empty string, removing it")
del rc['documentStyle']
elif not isinstance(doc_style, dict):
logger.warning(f"⚠️ documentStyle is not a dict ({type(doc_style)}), removing it")
del rc['documentStyle']
if 'metadata' in rc:
metadata = rc['metadata']
if metadata is None or metadata == "":
logger.warning("⚠️ metadata is None or empty string, removing it")
del rc['metadata']
elif not isinstance(metadata, dict):
logger.warning(f"⚠️ metadata is not a dict ({type(metadata)}), removing it")
del rc['metadata']
# Check for any None values in critical nested structures
def check_none_in_dict(d, path=""):
"""Recursively check for None values that shouldn't be there"""
issues = []
if isinstance(d, dict):
for key, value in d.items():
current_path = f"{path}.{key}" if path else key
if value is None:
# Some fields can legitimately be None, but most shouldn't
if key not in ['decorations', 'nodeStyle', 'props']:
issues.append(current_path)
elif isinstance(value, dict):
issues.extend(check_none_in_dict(value, current_path))
elif isinstance(value, list):
for i, item in enumerate(value):
if item is None:
issues.append(f"{current_path}[{i}]")
elif isinstance(item, dict):
issues.extend(check_none_in_dict(item, f"{current_path}[{i}]"))
return issues
none_issues = check_none_in_dict(blog_data['draftPost']['richContent'])
if none_issues:
logger.error(f"❌ Found None values in richContent at: {none_issues[:10]}") # Limit to first 10
# Remove None values from critical paths
for issue_path in none_issues[:5]: # Fix first 5
parts = issue_path.split('.')
try:
obj = blog_data['draftPost']['richContent']
for part in parts[:-1]:
if '[' in part:
key, idx = part.split('[')
idx = int(idx.rstrip(']'))
obj = obj[key][idx]
else:
obj = obj[part]
final_key = parts[-1]
if '[' in final_key:
key, idx = final_key.split('[')
idx = int(idx.rstrip(']'))
obj[key][idx] = {}
else:
obj[final_key] = {}
logger.warning(f"Fixed None value at {issue_path}")
except:
pass
# Log the final payload structure one more time before sending
logger.warning(f"📤 Final payload ready - draftPost keys: {list(blog_data['draftPost'].keys())}")
logger.warning(f"📤 RichContent nodes count: {len(blog_data['draftPost']['richContent'].get('nodes', []))}")
logger.warning(f"📤 RichContent has metadata: {bool(blog_data['draftPost']['richContent'].get('metadata'))}")
logger.warning(f"📤 RichContent has documentStyle: {bool(blog_data['draftPost']['richContent'].get('documentStyle'))}")
# Try sending WITHOUT SEO data first to isolate the issue
test_without_seo = False # Disabled - listItemData issue fixed
if test_without_seo and 'seoData' in blog_data['draftPost']:
logger.warning("🧪 TESTING WITHOUT SEO DATA to isolate issue...")
# Clone the payload without SEO data
test_payload_no_seo = {
'draftPost': {
'title': blog_data['draftPost']['title'],
'memberId': blog_data['draftPost']['memberId'],
'richContent': blog_data['draftPost']['richContent'],
'excerpt': blog_data['draftPost'].get('excerpt', '')
},
'publish': False,
'fieldsets': ['URL']
}
try:
logger.warning("🧪 Attempting without SEO data...")
test_result = blog_service.create_draft_post(access_token, test_payload_no_seo, extra_headers or None)
logger.warning(f"✅ WITHOUT SEO DATA SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
logger.error("⚠️⚠️⚠️ ISSUE IS WITH SEO DATA STRUCTURE!")
# If this succeeds, don't send the full payload, just return this result
return test_result
except Exception as e:
logger.warning(f"❌ WITHOUT SEO DATA ALSO FAILED: {e}")
logger.warning("⚠️ Issue is NOT with SEO data, continuing with full payload...")
# Try sending with minimal structure first to isolate the issue
# Create a test payload with just required fields
minimal_test = False # Set to True to test with minimal payload
if minimal_test:
logger.warning("🧪 TESTING WITH MINIMAL PAYLOAD (title + memberId + simple richContent)")
test_payload = {
'draftPost': {
'title': blog_data['draftPost']['title'],
'memberId': blog_data['draftPost']['memberId'],
'richContent': {
'nodes': [
{
'id': str(uuid.uuid4()),
'type': 'PARAGRAPH',
'nodes': [
{
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': 'Test paragraph',
'decorations': []
}
}
],
'paragraphData': {}
}
],
'metadata': {'version': 1, 'id': str(uuid.uuid4())},
'documentStyle': {}
}
},
'publish': False,
'fieldsets': ['URL']
}
logger.warning("🧪 Attempting minimal payload first...")
try:
test_result = blog_service.create_draft_post(access_token, test_payload, extra_headers or None)
logger.warning(f"✅ MINIMAL PAYLOAD SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
logger.warning("⚠️ Issue is with complex content, not basic structure")
except Exception as e:
logger.error(f"❌ MINIMAL PAYLOAD ALSO FAILED: {e}")
logger.error("⚠️ Issue is with basic structure or permissions")
result = blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
# Log response
draft_post = result.get('draftPost', {})
logger.warning(f"✅ Blog post created successfully! Post ID: {draft_post.get('id', 'N/A')}")
# Check if SEO data was preserved in response
if 'seoData' in draft_post:
seo_response = draft_post['seoData']
logger.warning(f"✅ SEO data confirmed in response: {len(seo_response.get('tags', []))} tags, {len(seo_response.get('settings', {}).get('keywords', []))} keywords")
else:
logger.warning("⚠️ No seoData in response - it may have been filtered out by Wix API")
logger.warning(f"📋 Response fields: {list(draft_post.keys())}")
return result
except requests.RequestException as e:
logger.error(f"Failed to create blog post: {e}")
if hasattr(e, 'response') and e.response is not None:
logger.error(f"Response body: {e.response.text}")
raise

View File

@@ -1,58 +1,460 @@
import re
import uuid
from typing import Any, Dict, List from typing import Any, Dict, List
def parse_markdown_inline(text: str) -> List[Dict[str, Any]]:
"""
Parse inline markdown formatting (bold, italic, links) into Ricos text nodes.
Returns a list of text nodes with decorations.
Handles: **bold**, *italic*, [links](url), `code`, and combinations.
"""
if not text:
return [{
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {'text': '', 'decorations': []}
}]
nodes = []
# Process text character by character to handle nested/adjacent formatting
# This is more robust than regex for complex cases
i = 0
current_text = ''
current_decorations = []
while i < len(text):
# Check for bold **text** (must come before single * check)
if i < len(text) - 1 and text[i:i+2] == '**':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
# Find closing **
end_bold = text.find('**', i + 2)
if end_bold != -1:
bold_text = text[i + 2:end_bold]
# Recursively parse the bold text for nested formatting
bold_nodes = parse_markdown_inline(bold_text)
# Add BOLD decoration to all text nodes within
for node in bold_nodes:
if node['type'] == 'TEXT':
node_decorations = node['textData'].get('decorations', []).copy()
if 'BOLD' not in node_decorations:
node_decorations.append('BOLD')
node['textData']['decorations'] = node_decorations
nodes.append(node)
i = end_bold + 2
continue
# Check for link [text](url)
elif text[i] == '[':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find matching ]
link_end = text.find(']', i)
if link_end != -1 and link_end < len(text) - 1 and text[link_end + 1] == '(':
link_text = text[i + 1:link_end]
url_start = link_end + 2
url_end = text.find(')', url_start)
if url_end != -1:
url = text[url_start:url_end]
# Create link node
link_node_id = str(uuid.uuid4())
text_node_id = str(uuid.uuid4())
link_text_nodes = parse_markdown_inline(link_text)
# Wrap link text in LINK node
nodes.append({
'id': link_node_id,
'type': 'LINK',
'nodes': link_text_nodes if link_text_nodes else [{
'id': text_node_id,
'type': 'TEXT',
'textData': {'text': link_text, 'decorations': []}
}],
'linkData': {
'link': {
'url': url,
'target': '_blank'
}
}
})
i = url_end + 1
continue
# Check for code `text`
elif text[i] == '`':
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find closing `
code_end = text.find('`', i + 1)
if code_end != -1:
code_text = text[i + 1:code_end]
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': code_text,
'decorations': ['CODE']
}
})
i = code_end + 1
continue
# Check for italic *text* (only if not part of **)
elif text[i] == '*' and (i == 0 or text[i-1] != '*') and (i == len(text) - 1 or text[i+1] != '*'):
# Save any accumulated text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
current_text = ''
current_decorations = []
# Find closing * (but not **)
italic_end = text.find('*', i + 1)
if italic_end != -1:
# Make sure it's not part of **
if italic_end == len(text) - 1 or text[italic_end + 1] != '*':
italic_text = text[i + 1:italic_end]
italic_nodes = parse_markdown_inline(italic_text)
# Add ITALIC decoration
for node in italic_nodes:
if node['type'] == 'TEXT':
node_decorations = node['textData'].get('decorations', []).copy()
if 'ITALIC' not in node_decorations:
node_decorations.append('ITALIC')
node['textData']['decorations'] = node_decorations
nodes.append(node)
i = italic_end + 1
continue
# Regular character
current_text += text[i]
i += 1
# Add any remaining text
if current_text:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': current_text,
'decorations': current_decorations.copy()
}
})
# If no nodes created, return single plain text node
if not nodes:
nodes.append({
'id': str(uuid.uuid4()),
'type': 'TEXT',
'textData': {
'text': text,
'decorations': []
}
})
return nodes
_ORDERED_ITEM_RE = re.compile(r'^\d+\.\s+')
_IMAGE_RE = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)')


def _ricos_paragraph_node(text: str, node_id: str = None) -> Dict[str, Any]:
    """Build a PARAGRAPH node whose children are the inline nodes parsed from `text`."""
    return {
        'id': node_id or str(uuid.uuid4()),
        'type': 'PARAGRAPH',
        'nodes': parse_markdown_inline(text),
        'paragraphData': {}
    }


def _ricos_list_node(list_type: str, data_key: str, items: List[str], node_id: str) -> Dict[str, Any]:
    """Build a BULLETED_LIST / ORDERED_LIST node from plain item strings."""
    # TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM, and
    # LIST_ITEM nodes carry no data field (see the Wix API notes in this module).
    return {
        'id': node_id,
        'type': list_type,
        'nodes': [
            {
                'id': str(uuid.uuid4()),
                'type': 'LIST_ITEM',
                'nodes': [_ricos_paragraph_node(item_text)]
            }
            for item_text in items
        ],
        data_key: {}
    }


def _is_bullet_line(stripped: str) -> bool:
    """True for '- item' / '* item' and the spaceless '-item' / '*item' forms, but not '--', '**' or '*italic* ...'."""
    if stripped.startswith(('- ', '* ')):
        return True
    if len(stripped) < 2 or stripped[0] not in '-*' or stripped[1] == stripped[0]:
        return False
    if stripped[0] == '*':
        closer = stripped.find('*', 1)
        # A lone closing '*' means the line opens with *emphasis*, not a bullet.
        if closer != -1 and not stripped.startswith('**', closer):
            return False
    return True


def _bullet_item_text(stripped: str) -> str:
    """Strip the list marker (with or without a following space) from a bullet line."""
    if stripped.startswith(('- ', '* ')):
        return stripped[2:].strip()
    return stripped[1:].strip()


def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str, Any]:
    """
    Convert markdown content into a Ricos JSON document.

    Block-level support: '#' headings (level capped at 6), '>' blockquotes,
    '-' / '*' bulleted lists, '1.' ordered lists, standalone '![alt](url)'
    images and paragraphs. Inline formatting inside each block is delegated to
    parse_markdown_inline. Every line is stripped before matching, so indented
    (nested) list items are flattened into the enclosing list.

    Args:
        content: Markdown text. Empty content is replaced by a default sentence.
        images: Not used by this function; kept for interface compatibility.

    Returns:
        Dict with 'type', 'id', 'nodes', 'metadata' and 'documentStyle' keys.
        'nodes' always contains at least one node.
    """
    if not content:
        content = "This is a post from ALwrity."

    nodes: List[Dict[str, Any]] = []
    lines = content.split('\n')
    total = len(lines)
    i = 0

    while i < total:
        line = lines[i].strip()
        if not line:
            i += 1
            continue

        node_id = str(uuid.uuid4())

        # Headings
        if line.startswith('#'):
            level = len(line) - len(line.lstrip('#'))
            heading_text = line.lstrip('# ').strip()
            nodes.append({
                'id': node_id,
                'type': 'HEADING',
                'nodes': parse_markdown_inline(heading_text),
                'headingData': {'level': min(level, 6)}
            })
            i += 1
            continue

        # Blockquotes: consecutive '>' lines are merged into one quote
        if line.startswith('>'):
            quote_lines = []
            while i < total and lines[i].strip().startswith('>'):
                quote_lines.append(lines[i].strip().lstrip('> ').strip())
                i += 1
            # TEXT nodes must be wrapped in a PARAGRAPH node within BLOCKQUOTE
            nodes.append({
                'id': node_id,
                'type': 'BLOCKQUOTE',
                'nodes': [_ricos_paragraph_node(' '.join(quote_lines))],
                'blockquoteData': {}
            })
            continue

        # Unordered lists ('- ' / '* ' markers, with or without the space)
        if _is_bullet_line(line):
            items = []
            while i < total and _is_bullet_line(lines[i].strip()):
                items.append(_bullet_item_text(lines[i].strip()))
                i += 1
            nodes.append(_ricos_list_node('BULLETED_LIST', 'bulletedListData', items, node_id))
            continue

        # Ordered lists
        if _ORDERED_ITEM_RE.match(line):
            items = []
            while i < total:
                candidate = lines[i].strip()
                marker = _ORDERED_ITEM_RE.match(candidate)
                if not marker:
                    break
                items.append(candidate[marker.end():])
                i += 1
            nodes.append(_ricos_list_node('ORDERED_LIST', 'orderedListData', items, node_id))
            continue

        # Images
        if line.startswith('!['):
            img_match = _IMAGE_RE.match(line)
            if img_match:
                nodes.append({
                    'id': node_id,
                    'type': 'IMAGE',
                    'nodes': [],
                    'imageData': {
                        'image': {
                            'src': {'url': img_match.group(2)},
                            'altText': img_match.group(1)
                        },
                        'containerData': {
                            'alignment': 'CENTER',
                            'width': {'size': 'CONTENT'}
                        }
                    }
                })
                i += 1
                continue
            # Malformed image syntax: fall through and keep the line as
            # paragraph text instead of silently dropping it.

        # Regular paragraph: collect consecutive non-empty lines until the
        # next special markdown element.
        para_lines = [line]
        i += 1
        while i < total:
            next_line = lines[i].strip()
            if not next_line:
                break
            if (next_line.startswith(('#', '- ', '* ', '>', '![')) or
                    _ORDERED_ITEM_RE.match(next_line)):
                break
            para_lines.append(next_line)
            i += 1
        nodes.append(_ricos_paragraph_node(' '.join(para_lines), node_id))

    # Ensure at least one node exists (e.g. whitespace-only content)
    if not nodes:
        nodes.append({
            'id': str(uuid.uuid4()),
            'type': 'PARAGRAPH',
            'nodes': [{
                'id': str(uuid.uuid4()),
                'type': 'TEXT',
                'textData': {
                    'text': content.strip()[:500] or "This is a post from ALwrity.",
                    'decorations': []
                }
            }],
            'paragraphData': {}
        })

    return {
        'type': 'DOCUMENT',
        'id': str(uuid.uuid4()),
        'nodes': nodes,
        'metadata': {'version': 1, 'id': str(uuid.uuid4())},
        'documentStyle': {
            'paragraph': {'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5'}
        }
    }

View File

@@ -7,6 +7,12 @@ class WixMediaService:
self.base_url = base_url self.base_url = base_url
def import_image(self, access_token: str, image_url: str, display_name: str) -> Dict[str, Any]: def import_image(self, access_token: str, image_url: str, display_name: str) -> Dict[str, Any]:
"""
Import external image to Wix Media Manager.
Official endpoint: https://www.wixapis.com/site-media/v1/files/import
Reference: https://dev.wix.com/docs/rest/assets/media/media-manager/files/import-file
"""
headers = { headers = {
'Authorization': f'Bearer {access_token}', 'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -16,7 +22,9 @@ class WixMediaService:
'mediaType': 'IMAGE', 'mediaType': 'IMAGE',
'displayName': display_name, 'displayName': display_name,
} }
response = requests.post(f"{self.base_url}/media/v1/files/import", headers=headers, json=payload) # Correct endpoint per Wix API documentation
endpoint = f"{self.base_url}/site-media/v1/files/import"
response = requests.post(endpoint, headers=headers, json=payload)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@@ -0,0 +1,277 @@
"""
Ricos Document Converter for Wix
Converts markdown content to Wix Ricos JSON format using either:
1. Wix's official Ricos Documents API (preferred)
2. Custom markdown parser (fallback)
"""
import json
import requests
import jwt
from typing import Dict, Any, Optional
from loguru import logger
def markdown_to_html(markdown_content: str) -> str:
    """
    Convert markdown content to HTML.

    Uses the `markdown` library (with the fenced_code and tables extensions)
    when it can be imported. Otherwise falls back to a small line-based
    converter that understands '#'..'######' headings, '- ' / '* ' and '1. '
    lists, '>' blockquotes, ``` fenced code blocks and inline bold / italic /
    link / code markup.

    Args:
        markdown_content: Markdown content to convert

    Returns:
        HTML string. The fallback converter returns a default paragraph for
        empty or whitespace-only input.
    """
    try:
        # Try using markdown library if available
        import markdown
        return markdown.markdown(markdown_content, extensions=['fenced_code', 'tables'])
    except ImportError:
        # Fallback: Simple regex-based conversion for basic markdown
        logger.warning("markdown library not available, using basic markdown-to-HTML conversion")

    import re
    from html import escape

    if not markdown_content or not markdown_content.strip():
        return "<p>This is a post from ALwrity.</p>"

    def render_inline(text: str) -> str:
        """Apply bold, italic, link and inline-code markup to one line of text."""
        text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', text)
        text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', text)
        text = re.sub(r'\[([^\]]+)\]\(([^\)]+)\)', r'<a href="\2">\1</a>', text)
        text = re.sub(r'`([^`]+)`', r'<code>\1</code>', text)
        return text

    def render_code_block(code_lines) -> str:
        # Joined outside any f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12. Code is escaped so
        # '<', '>' and '&' survive as text.
        return '<pre><code>' + escape('\n'.join(code_lines), quote=False) + '</code></pre>'

    result = []
    open_list = None   # 'ul', 'ol' or None when no list is open
    code_lines = None  # list of raw lines while inside a fenced block, else None

    for raw_line in markdown_content.split('\n'):
        line = raw_line.strip()

        # Handle code blocks first
        if line.startswith('```'):
            if code_lines is None:
                # Close an open list so <pre> is not emitted inside <ul>/<ol>
                if open_list:
                    result.append(f'</{open_list}>')
                    open_list = None
                code_lines = []
            else:
                result.append(render_code_block(code_lines))
                code_lines = None
            continue

        if code_lines is not None:
            # Keep the untrimmed line: indentation matters inside code
            code_lines.append(raw_line)
            continue

        is_bullet = line.startswith('- ') or line.startswith('* ')
        ordered = None if is_bullet else re.match(r'^\d+\.\s+(.*)', line)
        wanted_list = 'ul' if is_bullet else ('ol' if ordered else None)

        # Close the open list when this line ends it or switches list type
        if open_list and open_list != wanted_list:
            result.append(f'</{open_list}>')
            open_list = None

        if not line:
            continue

        heading = re.match(r'^(#{1,6})\s*(.*)$', line)
        if heading:
            level = len(heading.group(1))
            result.append(f'<h{level}>{render_inline(heading.group(2).strip())}</h{level}>')
        elif wanted_list:
            if open_list is None:
                result.append(f'<{wanted_list}>')
                open_list = wanted_list
            item_text = line[2:].strip() if is_bullet else ordered.group(1)
            result.append(f'<li>{render_inline(item_text)}</li>')
        elif line.startswith('>'):
            result.append(f'<blockquote><p>{render_inline(line[1:].strip())}</p></blockquote>')
        else:
            result.append(f'<p>{render_inline(line)}</p>')

    # An unterminated fence still carries content; emit it rather than lose it
    if code_lines is not None:
        result.append(render_code_block(code_lines))

    # Close any open lists
    if open_list:
        result.append(f'</{open_list}>')

    # Ensure we have at least one paragraph
    if not result:
        result.append('<p>This is a post from ALwrity.</p>')

    html_output = '\n'.join(result)
    logger.debug(f"Converted {len(markdown_content)} chars markdown to {len(html_output)} chars HTML")
    return html_output
def _meta_site_id_from_token(access_token: str) -> Optional[str]:
    """Best-effort read of data.instance.metaSiteId from an 'OauthNG.JWS.' token (signature is NOT verified)."""
    try:
        token_str = str(access_token)
        if not token_str.startswith('OauthNG.JWS.'):
            return None
        jwt_part = token_str[12:]  # drop the 'OauthNG.JWS.' prefix
        claims = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
        data_payload = claims.get('data', {})
        if isinstance(data_payload, str):
            # The 'data' claim may be a JSON-encoded string
            try:
                data_payload = json.loads(data_payload)
            except ValueError as e:
                logger.debug(f"Could not extract site ID from token: {e}")
                return None
        if not isinstance(data_payload, dict):
            return None
        instance_data = data_payload.get('instance', {})
        meta_site_id = instance_data.get('metaSiteId')
        if isinstance(meta_site_id, str) and meta_site_id:
            return meta_site_id
    except Exception as e:
        # The header is optional; never let token parsing block the conversion
        logger.debug(f"Could not extract site ID from token: {e}")
    return None


def convert_via_wix_api(markdown_content: str, access_token: str, base_url: str = 'https://www.wixapis.com') -> Dict[str, Any]:
    """
    Convert markdown to Ricos using Wix's official Ricos Documents API.

    The markdown is first converted to HTML with markdown_to_html and the HTML
    is posted to {base_url}/ricos/v1/ricos-document/convert/to-ricos.
    Reference: https://dev.wix.com/docs/api-reference/assets/rich-content/ricos-documents/convert-to-ricos-document

    Args:
        markdown_content: Markdown content to convert (will be converted to HTML)
        access_token: Wix access token, sent as a Bearer token. When it is an
            'OauthNG.JWS.' token carrying a metaSiteId, that id is also sent
            as the 'wix-site-id' header.
        base_url: Wix API base URL (default: https://www.wixapis.com)

    Returns:
        Ricos JSON document: the response's 'document' field, or the
        'ricosDocument' / 'ricos' field or the whole response body when
        'document' is absent.

    Raises:
        ValueError: If the markdown is empty/whitespace-only or the API
            response contains no document.
        requests.RequestException: If the HTTP request fails or returns an
            error status (details are logged before re-raising).
    """
    # Validate content is not empty
    markdown_stripped = markdown_content.strip() if markdown_content else ""
    if not markdown_stripped:
        logger.error("Markdown content is empty or whitespace-only")
        raise ValueError("Content cannot be empty for Wix Ricos API conversion")

    logger.debug(f"Converting markdown to HTML: input_length={len(markdown_stripped)} chars")

    # Convert markdown to HTML for better reliability with Wix API
    html_content = markdown_to_html(markdown_stripped)

    # Validate HTML content is not empty - CRITICAL for Wix API
    html_stripped = html_content.strip() if html_content else ""
    if not html_stripped:
        logger.error(f"HTML conversion produced empty content! Markdown length: {len(markdown_stripped)}")
        logger.error(f"Markdown sample: {markdown_stripped[:500]}...")
        logger.error(f"HTML result: '{html_content}' (type: {type(html_content)})")
        # Fallback: use a minimal valid HTML if conversion failed
        html_stripped = "<p>Content from ALwrity blog writer.</p>"
        logger.warning("Using fallback HTML due to empty conversion result")

    logger.debug(f"✅ Converted markdown to HTML: {len(html_stripped)} chars, preview: {html_stripped[:200]}...")

    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json'
    }

    # Add wix-site-id if available from token
    meta_site_id = _meta_site_id_from_token(access_token)
    if meta_site_id:
        headers['wix-site-id'] = meta_site_id

    # Call Wix Ricos Documents API: Convert to Ricos Document
    endpoint = f"{base_url}/ricos/v1/ricos-document/convert/to-ricos"

    # Payload structure per Wix API: html/markdown/plainText field at root, optional plugins
    payload = {
        'html': html_stripped,  # Direct field, not nested in options
        'plugins': []  # Optional: empty array uses default plugins
    }

    logger.warning(f"📤 Sending to Wix Ricos API: html_length={len(payload['html'])}, plugins_count={len(payload['plugins'])}")
    logger.debug(f"HTML preview (first 300 chars): {html_stripped[:300]}...")

    try:
        # Log the exact payload being sent (for debugging)
        logger.warning(f"📤 Wix Ricos API Request:")
        logger.warning(f" Endpoint: {endpoint}")
        logger.warning(f" Payload keys: {list(payload.keys())}")
        logger.warning(f" HTML length: {len(payload.get('html', ''))}")
        logger.warning(f" Plugins: {payload.get('plugins', [])}")
        logger.debug(f" Full payload (first 500 chars of HTML): {str(payload)[:500]}")

        response = requests.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()
        result = response.json()

        # Extract the ricos document from response
        # Response structure: { "document": { "nodes": [...], "metadata": {...}, "documentStyle": {...} } }
        ricos_document = result.get('document')
        if not ricos_document:
            # Fallback: try other possible response fields
            ricos_document = result.get('ricosDocument') or result.get('ricos') or result

        if not ricos_document:
            logger.error(f"Unexpected response structure from Wix API: {list(result.keys())}")
            logger.error(f"Response: {result}")
            raise ValueError("Wix API did not return a valid Ricos document")

        logger.warning(f"✅ Successfully converted HTML to Ricos via Wix API: {len(ricos_document.get('nodes', []))} nodes")
        return ricos_document
    except requests.RequestException as e:
        logger.error(f"❌ Wix Ricos API conversion failed: {e}")
        if hasattr(e, 'response') and e.response is not None:
            logger.error(f" Response status: {e.response.status_code}")
            logger.error(f" Response headers: {dict(e.response.headers)}")
            try:
                error_body = e.response.json()
                logger.error(f" Response JSON: {error_body}")
            except ValueError:
                # Body is not JSON; log it raw
                logger.error(f" Response text: {e.response.text}")
        logger.error(f" Request payload was: {json.dumps(payload, indent=2)[:1000]}...")  # First 1000 chars
        raise

View File

@@ -0,0 +1,300 @@
"""
SEO Data Builder for Wix Blog Posts
Builds Wix-compatible seoData objects from ALwrity SEO metadata.
"""
from typing import Dict, Any, Optional
from loguru import logger
def build_seo_data(seo_metadata: Dict[str, Any], default_title: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Build a Wix seoData object from our SEO metadata format.

    Args:
        seo_metadata: SEO metadata dict with fields like:
            - seo_title: SEO optimized title
            - meta_description: Meta description
            - focus_keyword: Main keyword (emitted with isMain=True)
            - blog_tags: List of tag strings (emitted as secondary keywords)
            - social_hashtags: List of hashtags (leading '#' removed, secondary keywords)
            - open_graph: Open Graph data dict (title, description, image, url)
            - twitter_card: Twitter card data dict (title, description, image, card)
            - canonical_url: Canonical URL
            A value that is not a dict (including None) is treated as empty metadata.
        default_title: Fallback title if seo_title is not provided.

    Returns:
        A dict with a 'tags' list and, only when at least one keyword exists,
        'settings': {'keywords': [...]}. Returns None when there are neither
        keywords nor valid tags.
    """
    if not isinstance(seo_metadata, dict):
        logger.warning(f"SEO metadata is not a dict ({type(seo_metadata).__name__}), treating as empty")
        seo_metadata = {}

    keywords_list = _collect_seo_keywords(seo_metadata)
    if not keywords_list:
        # Empty structures are omitted rather than sent as empty objects.
        logger.warning("No keywords found in SEO metadata, omitting settings from seoData")

    # Build tags array (meta tags, Open Graph, Twitter Card, canonical link)
    tags_list = []

    meta_description = seo_metadata.get('meta_description')
    if meta_description:
        tags_list.append(_seo_meta_tag('name', 'description', meta_description))

    # SEO title - 'title' type uses the 'children' field, not 'props.content'
    seo_title = seo_metadata.get('seo_title') or default_title
    if seo_title:
        tags_list.append({
            'type': 'title',
            'children': str(seo_title),
            'custom': True,
            'disabled': False
        })

    # Open Graph tags (skipped entirely when 'open_graph' is present but not a dict)
    open_graph = seo_metadata.get('open_graph', {})
    if isinstance(open_graph, dict):
        og_title = open_graph.get('title') or seo_title
        if og_title:
            tags_list.append(_seo_meta_tag('property', 'og:title', og_title))

        og_description = open_graph.get('description') or meta_description
        if og_description:
            tags_list.append(_seo_meta_tag('property', 'og:description', og_description))

        # Only http(s) URLs are emitted; base64/data images are skipped
        og_image = open_graph.get('image')
        if _is_http_url(og_image):
            tags_list.append(_seo_meta_tag('property', 'og:image', og_image))

        tags_list.append(_seo_meta_tag('property', 'og:type', 'article'))

        og_url = open_graph.get('url') or seo_metadata.get('canonical_url')
        if og_url:
            tags_list.append(_seo_meta_tag('property', 'og:url', og_url))

    # Twitter Card tags (skipped entirely when 'twitter_card' is present but not a dict)
    twitter_card = seo_metadata.get('twitter_card', {})
    if isinstance(twitter_card, dict):
        twitter_title = twitter_card.get('title') or seo_title
        if twitter_title:
            tags_list.append(_seo_meta_tag('name', 'twitter:title', twitter_title))

        twitter_description = twitter_card.get('description') or meta_description
        if twitter_description:
            tags_list.append(_seo_meta_tag('name', 'twitter:description', twitter_description))

        twitter_image = twitter_card.get('image')
        if _is_http_url(twitter_image):
            tags_list.append(_seo_meta_tag('name', 'twitter:image', twitter_image))

        # 'or' (not a .get default) so an explicit None/'' does not become the string 'None'
        twitter_card_type = twitter_card.get('card') or 'summary_large_image'
        tags_list.append(_seo_meta_tag('name', 'twitter:card', twitter_card_type))

    # Canonical URL as link tag
    canonical_url = seo_metadata.get('canonical_url')
    if canonical_url:
        tags_list.append({
            'type': 'link',
            'props': {
                'rel': 'canonical',
                'href': str(canonical_url)
            },
            'custom': True,
            'disabled': False
        })

    # Drop tags whose payload is blank (e.g. whitespace-only description or title)
    validated_tags = [tag for tag in tags_list if _is_valid_seo_tag(tag)]

    logger.debug(f"Built SEO data: {len(validated_tags)} tags, {len(keywords_list)} keywords")

    # Only return seoData if we have at least keywords or tags
    if not keywords_list and not validated_tags:
        return None

    seo_data: Dict[str, Any] = {}
    if keywords_list:
        seo_data['settings'] = {'keywords': keywords_list}
    seo_data['tags'] = validated_tags
    return seo_data


def _collect_seo_keywords(seo_metadata: Dict[str, Any]) -> list:
    """Collect {'term', 'isMain'} keyword dicts, de-duplicated case-insensitively (first wins)."""
    keywords = []
    seen = set()

    def add(raw_term: Any, is_main: bool) -> None:
        if raw_term is None:
            return
        term = str(raw_term).strip()
        key = term.casefold()
        if not term or key in seen:
            return
        seen.add(key)
        keywords.append({'term': term, 'isMain': is_main})

    # Main keyword first so that it wins over an identical tag/hashtag
    focus_keyword = seo_metadata.get('focus_keyword')
    if focus_keyword:
        add(focus_keyword, True)

    blog_tags = seo_metadata.get('blog_tags', [])
    if isinstance(blog_tags, list):
        for tag in blog_tags:
            add(tag, False)

    social_hashtags = seo_metadata.get('social_hashtags', [])
    if isinstance(social_hashtags, list):
        for hashtag in social_hashtags:
            if hashtag is not None:
                # Remove leading '#' if present
                add(str(hashtag).strip().lstrip('#'), False)

    return keywords


def _seo_meta_tag(key_attr: str, key: str, content: Any) -> Dict[str, Any]:
    """Build a custom, enabled 'meta' tag; key_attr is 'name' or 'property'."""
    return {
        'type': 'meta',
        'props': {
            key_attr: key,
            'content': str(content)
        },
        'custom': True,
        'disabled': False
    }


def _is_http_url(value: Any) -> bool:
    """Return True when value is a string starting with http:// or https://."""
    return isinstance(value, str) and value.startswith(('http://', 'https://'))


def _is_valid_seo_tag(tag: Any) -> bool:
    """Return True when a tag carries the non-blank fields its type requires; log and reject otherwise."""
    if not isinstance(tag, dict):
        logger.warning(f"Skipping invalid tag (not a dict): {type(tag)}")
        return False
    if 'type' not in tag:
        logger.warning("Skipping tag missing 'type' field")
        return False

    tag_type = tag['type']
    if tag_type == 'title':
        if not str(tag.get('children') or '').strip():
            logger.warning("Skipping title tag with missing/invalid 'children' field")
            return False
        return True

    if tag_type in ('meta', 'link'):
        props = tag.get('props')
        if not isinstance(props, dict):
            logger.warning(f"Skipping {tag_type} tag with missing/invalid 'props' field")
            return False
        if tag_type == 'meta':
            if 'name' not in props and 'property' not in props:
                logger.warning("Skipping meta tag with missing 'name' or 'property' in props")
                return False
            required = 'content'
        else:
            required = 'href'
        if not str(props.get(required) or '').strip():
            logger.warning(f"Skipping {tag_type} tag with missing/empty '{required}': {props}")
            return False

    return True

View File

@@ -9,6 +9,7 @@ import json
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response from .gemini_provider import gemini_text_response, gemini_structured_json_response
@@ -129,11 +130,16 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
pricing_service = PricingService(db) pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens) # Estimate tokens from prompt (input tokens)
# Note: We estimate output tokens conservatively (assume response is similar length to prompt) # CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse
# This prevents underestimating total token usage # This ensures we block requests that would exceed limits even if response is longer than expected
input_tokens = int(len(prompt.split()) * 1.3) input_tokens = int(len(prompt.split()) * 1.3)
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens) # Worst-case estimate: assume maximum possible output tokens (max_tokens if specified)
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8) # This prevents abuse where actual response tokens exceed the estimate
if max_tokens:
estimated_output_tokens = max_tokens # Use maximum allowed output tokens
else:
# If max_tokens not specified, use conservative estimate (input * 1.5)
estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens estimated_total_tokens = input_tokens + estimated_output_tokens
# Check limits using sync method from pricing service (strict enforcement) # Check limits using sync method from pricing service (strict enforcement)
@@ -146,7 +152,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if not can_proceed: if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}") logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
raise RuntimeError(f"Subscription limit exceeded: {message}") # Raise HTTPException(429) with usage info so frontend can display subscription modal
error_detail = {
'error': message,
'message': message,
'provider': actual_provider_name or provider_enum.value,
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
# Get current usage for limit checking only # Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
@@ -159,6 +172,9 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
finally: finally:
db.close() db.close()
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
raise
except RuntimeError: except RuntimeError:
# Re-raise subscription limit errors # Re-raise subscription limit errors
raise raise
@@ -244,7 +260,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
db_track = next(get_db()) db_track = next(get_db())
try: try:
# Estimate tokens from prompt and response # Estimate tokens from prompt and response
tokens_input = estimated_tokens # Already calculated above # Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output tokens_total = tokens_input + tokens_output
@@ -259,45 +276,186 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}") logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter( summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id, UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period UsageSummary.billing_period == current_period
).first() ).first()
if not summary: if not summary:
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}") logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary( summary = UsageSummary(
user_id=user_id, user_id=user_id,
billing_period=current_period billing_period=current_period
) )
db_track.add(summary) db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
# Get "before" state for unified log logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
provider_name = provider_enum.value
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation) # Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1 new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers # CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
# Using raw SQL UPDATE ensures the change is persisted
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
new_tokens = current_tokens_before + tokens_total
setattr(summary, f"{provider_name}_tokens", new_tokens) # SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}") # This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else: else:
current_tokens_before = 0 current_tokens_before = 0
new_tokens = 0 new_tokens = 0
# Update totals # Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0 old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0 old_total_tokens = summary.total_tokens or 0
summary.total_calls = old_total_calls + 1 new_total_calls = old_total_calls + 1
summary.total_tokens = old_total_tokens + tokens_total new_total_tokens = old_total_tokens + tokens_total
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log # Get plan details for unified log
limits = pricing.get_user_limits(user_id) limits = pricing.get_user_limits(user_id)
@@ -310,12 +468,52 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_images_before = getattr(summary, "stability_calls", 0) or 0 current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit() # CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log # Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f""" print(f"""
[SUBSCRIPTION] LLM Text Generation [SUBSCRIPTION] LLM Text Generation
├─ User: {user_id} ├─ User: {user_id}
@@ -407,7 +605,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
db_track = next(get_db()) db_track = next(get_db())
try: try:
# Estimate tokens from prompt and response # Estimate tokens from prompt and response
tokens_input = estimated_tokens # Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output tokens_total = tokens_input + tokens_output
@@ -418,6 +617,49 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
pricing = PricingService(db_track) pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter( summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id, UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period UsageSummary.billing_period == current_period
@@ -430,41 +672,68 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
) )
db_track.add(summary) db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log # Get "before" state for unified log (from raw SQL query)
provider_name = provider_enum.value logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
# Update provider-specific counters (sync operation) # Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1 new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls) setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers # Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]: if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0 logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
new_tokens = current_tokens_before + tokens_total
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens) setattr(summary, f"{provider_name}_tokens", new_tokens)
else: else:
current_tokens_before = 0 current_tokens_before = 0
new_tokens = 0 new_tokens = 0
# Update totals # Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1 summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log # Get plan details for unified log (limits already retrieved above)
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown' tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0 call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log # Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0 current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit() # CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens") db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback # UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral") # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")

View File

@@ -0,0 +1,725 @@
"""
Limit Validation Module
Handles subscription limit checking and validation logic.
Extracted from pricing_service.py for better modularity.
"""
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
from datetime import datetime, timedelta

from sqlalchemy import text
from loguru import logger

from models.subscription_models import (
    UserSubscription, UsageSummary, SubscriptionPlan,
    APIProvider, SubscriptionTier
)

if TYPE_CHECKING:
    from .pricing_service import PricingService
class LimitValidator:
    """Validates subscription limits for API usage.

    Two entry points are provided:

    * ``check_usage_limits`` - single-call check (calls, tokens, cost) whose
      verdict is cached for a short time in ``PricingService._limits_cache``.
    * ``check_comprehensive_limits`` - pre-flight check for a whole workflow;
      usage is projected cumulatively across the listed operations.

    In both, a limit value of 0 means "unlimited".
    """

    # Seconds a check_usage_limits() verdict stays in the shared limits cache.
    _CACHE_TTL_SECONDS = 30
    # Provider names that count against the unified AI text generation call limit.
    _LLM_PROVIDERS = ('gemini', 'openai', 'anthropic', 'mistral')
    # Whitelist of column names that may be interpolated into raw SQL.
    _TOKEN_COLUMNS = ('gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens')

    def __init__(self, pricing_service: 'PricingService'):
        """
        Initialize limit validator with reference to PricingService.

        Args:
            pricing_service: Instance of PricingService to access helper methods and cache
        """
        self.pricing_service = pricing_service
        self.db = pricing_service.db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _cache_result(self, cache_key: str, result: Tuple[bool, str, Dict[str, Any]],
                      now: datetime, tokens_requested: int) -> Tuple[bool, str, Dict[str, Any]]:
        """Store ``result`` in the shared limits cache and return it unchanged.

        ``tokens_requested`` is recorded next to the verdict because the verdict
        depends on it; a cached entry is only reused for the same request size.
        """
        self.pricing_service._limits_cache[cache_key] = {
            'result': result,
            'expires_at': now + timedelta(seconds=self._CACHE_TTL_SECONDS),
            'tokens_requested': tokens_requested,
        }
        return result

    def _free_tier_limits(self) -> Optional[Dict[str, Any]]:
        """Return the limits dict of the active free plan, or None if there is none."""
        free_plan = self.db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.FREE,
            SubscriptionPlan.is_active == True
        ).first()
        if not free_plan:
            return None
        return self.pricing_service._plan_to_limits_dict(free_plan)

    def _load_usage_summary(self, user_id: str, billing_period: str) -> Optional[UsageSummary]:
        """Load the usage summary for ``billing_period`` with all session state expired first."""
        # Expire everything so the query below re-reads attributes from the DB
        # (important right after a subscription renewal).
        self.db.expire_all()
        usage = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()
        if usage:
            self.db.refresh(usage)  # Ensure fresh data
        return usage

    def _read_provider_tokens(self, user_id: str, billing_period: str,
                              tokens_column: str) -> Tuple[int, str]:
        """Read the current value of one ``*_tokens`` column for the user's billing period.

        Tries a raw SQL read first and falls back to a fresh ORM query.

        Returns:
            (tokens, method) where ``method`` names the path that produced the value
            ('Raw SQL', 'ORM', 'Failed - using 0' or 'Skipped - invalid column').
        """
        # The column name is interpolated into SQL, so only whitelisted names are allowed.
        if tokens_column not in self._TOKEN_COLUMNS:
            logger.error(f" └─ Invalid provider tokens key: {tokens_column}")
            return 0, 'Skipped - invalid column'

        # Method 1: raw SQL.
        try:
            logger.debug(f" └─ Attempting raw SQL query for {tokens_column}")
            sql_query = text(f"""
                SELECT {tokens_column}
                FROM usage_summaries
                WHERE user_id = :user_id
                AND billing_period = :period
                LIMIT 1
            """)
            logger.debug(f" └─ SQL: SELECT {tokens_column} FROM usage_summaries WHERE user_id={user_id} AND billing_period={billing_period}")
            row = self.db.execute(sql_query, {
                'user_id': user_id,
                'period': billing_period
            }).first()
            if row:
                tokens = row[0] if row[0] is not None else 0
                logger.debug(f"[Pre-flight Check] ✅ Raw SQL query returned result: {row[0]} -> {tokens}")
            else:
                tokens = 0
                logger.warning(f"[Pre-flight Check] ⚠️ Raw SQL query returned None (no rows found)")
            logger.debug(f"[Pre-flight Check] ✅ Raw SQL query succeeded for {tokens_column}: {tokens}")
            return tokens, 'Raw SQL'
        except Exception as sql_error:
            # loguru: opt(exception=True) attaches the traceback; passing exc_info=True
            # would instead be treated as a str.format() argument for the message.
            logger.opt(exception=True).error(f" └─ Raw SQL query failed for {tokens_column}: {type(sql_error).__name__}: {sql_error}")

        # Method 2: fresh ORM query.
        try:
            fresh_usage = self._load_usage_summary(user_id, billing_period)
            tokens = (getattr(fresh_usage, tokens_column, 0) or 0) if fresh_usage else 0
            logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {tokens_column}: {tokens}")
            return tokens, 'ORM'
        except Exception as orm_error:
            logger.opt(exception=True).error(f" └─ ORM query also failed: {orm_error}")

        # Deliberate best-effort: assume 0 if the value cannot be read at all.
        logger.warning(f" └─ Both query methods failed, using 0 as fallback")
        return 0, 'Failed - using 0'

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def check_usage_limits(self, user_id: str, provider: APIProvider,
                           tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
        """Check if user can make an API call within their limits.

        Checks, in order: subscription expiry, call limit, token limit (LLM
        providers only) and monthly cost limit. Failures to read the
        subscription, limits or usage deny the request (fail closed); an error
        inside an individual limit check is logged and that check is skipped.

        Args:
            user_id: User ID
            provider: APIProvider enum (may be MISTRAL for HuggingFace)
            tokens_requested: Estimated tokens for the request
            actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)

        Returns:
            (can_proceed, error_message, usage_info)
        """
        try:
            # Use actual_provider_name if provided, otherwise use enum value.
            # This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors.
            display_provider_name = actual_provider_name or provider.value
            # provider.value is used for DB field names (e.g., "mistral_calls", "mistral_tokens").
            provider_name = provider.value
            is_llm_provider = provider_name in self._LLM_PROVIDERS
            logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")

            # Short TTL cache to reduce DB reads under sustained traffic.
            cache_key = f"{user_id}:{provider.value}"
            now = datetime.utcnow()
            cached = self.pricing_service._limits_cache.get(cache_key)
            if (
                cached
                and cached.get('expires_at')
                and cached['expires_at'] > now
                # The verdict depends on the request size, so only reuse it for the same
                # size. Entries without the field (written elsewhere) are still honoured.
                and cached.get('tokens_requested', tokens_requested) == tokens_requested
            ):
                logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
                return tuple(cached['result'])  # type: ignore

            # Get user subscription first to check expiration
            subscription = self.db.query(UserSubscription).filter(
                UserSubscription.user_id == user_id,
                UserSubscription.is_active == True
            ).first()
            if subscription:
                logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
            else:
                logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")

            # Check subscription expiration (STRICT: deny if expired)
            if subscription and subscription.current_period_end < now:
                logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
                if not getattr(subscription, 'auto_renew', False):
                    # Expired and no auto-renew - deny access
                    logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
                    result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
                        'expired': True,
                        'period_end': subscription.current_period_end.isoformat()
                    })
                    return self._cache_result(cache_key, result, now, tokens_requested)
                # Try to auto-renew
                if not self.pricing_service._ensure_subscription_current(subscription):
                    # Auto-renew failed - deny access
                    result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
                        'expired': True,
                        'auto_renew_failed': True
                    })
                    return self._cache_result(cache_key, result, now, tokens_requested)

            # Get user limits with error handling (STRICT: fail on errors)
            try:
                # Expire the plan object so limits are re-read after a renewal.
                if subscription and subscription.plan_id:
                    plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
                    if plan_obj:
                        self.db.expire(plan_obj)
                        logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
                limits = self.pricing_service.get_user_limits(user_id)
                if limits:
                    logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
                    # Log token limits for debugging
                    token_limits = limits.get('limits', {})
                    logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
                else:
                    logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
            except Exception as e:
                logger.opt(exception=True).error(f"[Subscription Check] Error getting user limits for {user_id}: {e}")
                # STRICT: Fail closed - deny request if we can't check limits
                return False, f"Failed to retrieve subscription limits: {str(e)}", {}

            if not limits:
                # No subscription found - check for free tier
                limits = self._free_tier_limits()
                if limits:
                    logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
                else:
                    # No subscription and no free tier - deny access
                    logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
                    return False, "No subscription plan found. Please subscribe to a plan.", {}

            # Get (or create) current usage for this billing period.
            try:
                current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
                # A single fresh ORM read is enough: the session is expired first, so
                # every attribute comes from the database.
                usage = self._load_usage_summary(user_id, current_period)
                if not usage:
                    # First usage this period, create summary
                    try:
                        usage = UsageSummary(
                            user_id=user_id,
                            billing_period=current_period
                        )
                        self.db.add(usage)
                        self.db.commit()
                    except Exception as create_error:
                        logger.error(f"Error creating usage summary: {create_error}")
                        self.db.rollback()
                        # STRICT: Fail closed on DB error
                        return False, f"Failed to create usage summary: {str(create_error)}", {}
            except Exception as e:
                logger.error(f"Error getting usage summary for {user_id}: {e}")
                self.db.rollback()
                # STRICT: Fail closed on DB error
                return False, f"Failed to retrieve usage summary: {str(e)}", {}

            # Check call limits with error handling
            # NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
            call_stats: Optional[Tuple[int, int]] = None  # (current calls, call limit) once computed
            try:
                if is_llm_provider:
                    # Use unified AI text generation limit (total_calls across all LLM providers)
                    ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
                    # If unified limit not set, fall back to provider-specific limit for backwards compatibility
                    if ai_text_gen_limit == 0:
                        ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
                    # Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
                    current_total_llm_calls = (
                        (usage.gemini_calls or 0) +
                        (usage.openai_calls or 0) +
                        (usage.anthropic_calls or 0) +
                        (usage.mistral_calls or 0)
                    )
                    call_stats = (current_total_llm_calls, ai_text_gen_limit)
                    # Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
                    if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
                        logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
                        result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
                            'current_calls': current_total_llm_calls,
                            'limit': ai_text_gen_limit,
                            'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100,
                            'provider': display_provider_name,  # Use display name for consistency
                            'usage_info': {
                                'provider': display_provider_name,  # Use display name for user-facing info
                                'current_calls': current_total_llm_calls,
                                'limit': ai_text_gen_limit,
                                'type': 'ai_text_generation',
                                'breakdown': {
                                    'gemini': usage.gemini_calls or 0,
                                    'openai': usage.openai_calls or 0,
                                    'anthropic': usage.anthropic_calls or 0,
                                    'mistral': usage.mistral_calls or 0  # DB field name (not display name)
                                }
                            }
                        })
                        return self._cache_result(cache_key, result, now, tokens_requested)
                    logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
                else:
                    # For non-LLM providers, check provider-specific limit
                    current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
                    call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
                    call_stats = (current_calls, call_limit)
                    # Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
                    if call_limit > 0 and current_calls >= call_limit:
                        logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
                        result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
                            'current_calls': current_calls,
                            'limit': call_limit,
                            'usage_percentage': 100.0,
                            'provider': display_provider_name  # Use display name for consistency
                        })
                        return self._cache_result(cache_key, result, now, tokens_requested)
                    logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
            except Exception as e:
                logger.error(f"Error checking call limits: {e}")
                # Continue to next check

            # Check token limits for LLM providers with error handling
            # NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
            try:
                if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
                    current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
                    token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
                    # Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
                    if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
                        result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
                            'current_tokens': current_tokens,
                            'requested_tokens': tokens_requested,
                            'limit': token_limit,
                            'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
                            'provider': display_provider_name,  # Use display name in error details
                            'usage_info': {
                                'provider': display_provider_name,
                                'current_tokens': current_tokens,
                                'requested_tokens': tokens_requested,
                                'limit': token_limit,
                                'type': 'tokens'
                            }
                        })
                        return self._cache_result(cache_key, result, now, tokens_requested)
            except Exception as e:
                logger.error(f"Error checking token limits: {e}")
                # Continue to next check

            # Check cost limits with error handling
            # NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
            cost_limit: Optional[float] = None
            current_cost = 0.0
            try:
                cost_limit = limits['limits'].get('monthly_cost', 0) or 0
                # Guard against a NULL total_cost, like every other usage counter.
                current_cost = usage.total_cost if usage.total_cost is not None else 0.0
                # Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
                if cost_limit > 0 and current_cost >= cost_limit:
                    result = (False, f"Monthly cost limit reached. Current cost: ${current_cost:.2f}, Limit: ${cost_limit:.2f}", {
                        'current_cost': current_cost,
                        'limit': cost_limit,
                        'usage_percentage': 100.0
                    })
                    return self._cache_result(cache_key, result, now, tokens_requested)
            except Exception as e:
                logger.error(f"Error checking cost limits: {e}")
                # Continue to success case

            # One of the best-effort checks above failed before producing its numbers
            # (already logged): report basic success without details and without caching.
            if call_stats is None or cost_limit is None:
                return True, "Within limits", {}

            # Calculate usage percentages for warnings
            try:
                current_call_count, call_limit_value = call_stats
                call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
                cost_usage_pct = (current_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
                result = (True, "Within limits", {
                    'current_calls': current_call_count,
                    'call_limit': call_limit_value,
                    'call_usage_percentage': call_usage_pct,
                    'current_cost': current_cost,
                    'cost_limit': cost_limit,
                    'cost_usage_percentage': cost_usage_pct
                })
                return self._cache_result(cache_key, result, now, tokens_requested)
            except Exception as e:
                logger.error(f"Error calculating usage percentages: {e}")
                # Return basic success
                return True, "Within limits", {}
        except Exception as e:
            logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
            # STRICT: Fail closed - deny requests if subscription system fails
            return False, f"Subscription check error: {str(e)}", {}

    def check_comprehensive_limits(
        self,
        user_id: str,
        operations: List[Dict[str, Any]]
    ) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
        """
        Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.

        This prevents wasteful API calls by validating that ALL subsequent operations will succeed
        before making the first external API call. Usage is projected cumulatively: each
        operation is checked against stored usage plus everything earlier operations in the
        same list would consume.

        Args:
            user_id: User ID
            operations: List of operations to validate, each with:
                - 'provider': APIProvider enum
                - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
                - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
                - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")

        Returns:
            (can_proceed, error_message, error_details)
            If can_proceed is False, error_message explains which limit would be exceeded
        """
        try:
            logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
            logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")

            # Get current usage and limits once
            current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
            logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
            usage = self._load_usage_summary(user_id, current_period)

            # Log what we actually read from database
            if usage:
                logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
                logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
                logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
                logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
                logger.info(f" └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
            else:
                logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")
                # First usage this period, create summary
                try:
                    usage = UsageSummary(
                        user_id=user_id,
                        billing_period=current_period
                    )
                    self.db.add(usage)
                    self.db.commit()
                except Exception as create_error:
                    logger.error(f"Error creating usage summary: {create_error}")
                    self.db.rollback()
                    return False, f"Failed to create usage summary: {str(create_error)}", {}

            # Get user limits (falling back to the free tier when there is no subscription)
            limits_dict = self.pricing_service.get_user_limits(user_id)
            if not limits_dict:
                limits_dict = self._free_tier_limits()
                if not limits_dict:
                    return False, "No subscription plan found. Please subscribe to a plan.", {}
            limits = limits_dict.get('limits', {})

            # Track cumulative usage across all operations
            total_llm_calls = (
                (usage.gemini_calls or 0) +
                (usage.openai_calls or 0) +
                (usage.anthropic_calls or 0) +
                (usage.mistral_calls or 0)
            )
            # Tokens projected by earlier operations in THIS run, per "<provider>_tokens" column.
            total_llm_tokens: Dict[str, int] = {}
            # Calls projected by earlier operations in THIS run, per "<provider>_calls" column
            # (providers that are neither LLM nor image generation).
            projected_other_calls: Dict[str, int] = {}
            total_images = usage.stability_calls or 0

            # Log current usage summary
            logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
            logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
            logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
            logger.info(f" └─ Image Calls: {total_images}")

            # Validate each operation
            for op_idx, operation in enumerate(operations):
                provider = operation.get('provider')
                provider_name = provider.value if hasattr(provider, 'value') else str(provider)
                tokens_requested = operation.get('tokens_requested', 0)
                actual_provider_name = operation.get('actual_provider_name')
                operation_type = operation.get('operation_type', 'unknown')
                display_provider_name = actual_provider_name or provider_name

                logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
                logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
                logger.info(f" ├─ Operation Index: {op_idx}")
                logger.info(f" └─ Estimated Tokens Requested: {tokens_requested}")

                if provider_name in self._LLM_PROVIDERS:
                    # --- Unified AI text generation call limit ---
                    ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
                    if ai_text_gen_limit == 0:
                        # Fallback to provider-specific limit
                        ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
                    # Count this operation as an LLM call
                    projected_total_llm_calls = total_llm_calls + 1
                    if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
                        error_info = {
                            'current_calls': total_llm_calls,
                            'limit': ai_text_gen_limit,
                            'provider': display_provider_name,
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
                            'error_type': 'call_limit',
                            'usage_info': error_info
                        }

                    # --- Token limit for this provider ---
                    # The stored value is re-read for every operation so renewals are picked up.
                    provider_tokens_key = f"{provider_name}_tokens"
                    base_current_tokens, query_method = self._read_provider_tokens(
                        user_id, current_period, provider_tokens_key
                    )
                    logger.info(f"[Pre-flight Check] 🔍 Fresh DB Query for {display_provider_name}:")
                    logger.info(f" ├─ Column: {provider_tokens_key}")
                    logger.info(f" ├─ Billing Period: {current_period}")
                    logger.info(f" ├─ User ID: {user_id}")
                    logger.info(f" ├─ Method: {query_method}")
                    logger.info(f" └─ Value from DB: {base_current_tokens}")

                    # total_llm_tokens tracks ONLY tokens projected in this run, not the DB value.
                    projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)
                    # Current tokens = base from DB + projected from previous operations in this run
                    current_provider_tokens = base_current_tokens + projected_from_previous
                    logger.info(f"[Pre-flight Check] 📊 Token Calculation for {display_provider_name}:")
                    logger.info(f" ├─ Base from DB (fresh query): {base_current_tokens}")
                    logger.info(f" ├─ Projected from previous ops in this run: {projected_from_previous}")
                    logger.info(f" └─ Total current tokens (base + projected): {current_provider_tokens}")
                    # Diagnostic only: the value on the initially loaded object is not used below.
                    if usage and hasattr(usage, provider_tokens_key):
                        initial_usage_value = getattr(usage, provider_tokens_key, 0) or 0
                        logger.debug(f" ⚠️ Initial usage object value: {initial_usage_value} (this should NOT be used for fresh query)")

                    token_limit = limits.get(provider_tokens_key, 0) or 0
                    if token_limit > 0 and tokens_requested > 0:
                        projected_tokens = current_provider_tokens + tokens_requested
                        logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
                        if projected_tokens > token_limit:
                            usage_percentage = (projected_tokens / token_limit) * 100
                            error_info = {
                                'current_tokens': current_provider_tokens,
                                'base_tokens_from_db': base_current_tokens,
                                'projected_from_previous_ops': projected_from_previous,
                                'requested_tokens': tokens_requested,
                                'limit': token_limit,
                                'provider': display_provider_name,
                                'operation_type': operation_type,
                                'operation_index': op_idx
                            }
                            # Make error message clearer: show actual DB usage vs projected
                            if projected_from_previous > 0:
                                error_msg = (
                                    f"Token limit exceeded for {display_provider_name} "
                                    f"({operation_type}). "
                                    f"Base usage: {base_current_tokens}/{token_limit}, "
                                    f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
                                    f"This operation would add: {tokens_requested}, "
                                    f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
                                )
                            else:
                                error_msg = (
                                    f"Token limit exceeded for {display_provider_name} "
                                    f"({operation_type}). "
                                    f"Current: {current_provider_tokens}/{token_limit}, "
                                    f"Requested: {tokens_requested}, "
                                    f"Would exceed by: {projected_tokens - token_limit} tokens "
                                    f"({usage_percentage:.1f}% of limit)"
                                )
                            logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
                            return False, error_msg, {
                                'error_type': 'token_limit',
                                'usage_info': error_info
                            }
                        logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")

                    # Update cumulative counts for next operation
                    total_llm_calls = projected_total_llm_calls
                    if tokens_requested > 0:
                        # Add this operation's tokens to cumulative projected tokens
                        total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
                        logger.debug(f"[Pre-flight Check] 📝 Updated cumulative projected tokens for {display_provider_name}:")
                        logger.debug(f" ├─ Previous projected: {projected_from_previous}")
                        logger.debug(f" ├─ This operation requested: {tokens_requested}")
                        logger.debug(f" └─ New cumulative projected: {total_llm_tokens[provider_tokens_key]}")
                    else:
                        # No tokens requested, keep existing projected tokens (or 0 if first operation)
                        total_llm_tokens[provider_tokens_key] = projected_from_previous
                        logger.debug(f"[Pre-flight Check] 📝 No tokens requested, keeping projected at: {projected_from_previous}")

                elif provider == APIProvider.STABILITY:
                    # --- Image generation limit ---
                    image_limit = limits.get('stability_calls', 0) or 0
                    projected_images = total_images + 1
                    if image_limit > 0 and projected_images > image_limit:
                        error_info = {
                            'current_images': total_images,
                            'limit': image_limit,
                            'provider': 'stability',
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
                            'error_type': 'image_limit',
                            'usage_info': error_info
                        }
                    total_images = projected_images

                else:
                    # --- Other provider-specific call limits ---
                    provider_calls_key = f"{provider_name}_calls"
                    calls_from_previous_ops = projected_other_calls.get(provider_calls_key, 0)
                    # Stored calls plus calls already claimed by earlier operations in this run.
                    current_provider_calls = (getattr(usage, provider_calls_key, 0) or 0) + calls_from_previous_ops
                    call_limit = limits.get(provider_calls_key, 0) or 0
                    if call_limit > 0:
                        projected_calls = current_provider_calls + 1
                        if projected_calls > call_limit:
                            error_info = {
                                'current_calls': current_provider_calls,
                                'limit': call_limit,
                                'provider': display_provider_name,
                                'operation_type': operation_type,
                                'operation_index': op_idx
                            }
                            return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
                                'error_type': 'call_limit',
                                'usage_info': error_info
                            }
                    projected_other_calls[provider_calls_key] = calls_from_previous_ops + 1

            # All checks passed
            logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
            logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
            return True, None, None
        except Exception as e:
            error_type = type(e).__name__
            error_message = str(e)
            logger.opt(exception=True).error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {error_message}")
            logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
            return False, f"Failed to validate limits: {error_type}: {error_message}", {}

View File

@@ -44,15 +44,17 @@ def validate_research_operations(
llm_provider_name = "gemini" llm_provider_name = "gemini"
# Estimate tokens for each operation in research workflow # Estimate tokens for each operation in research workflow
# Google Grounding call: ~2000 tokens (input + output) # Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results)
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) # Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON) # Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles) # Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
# Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation
# to prevent wasteful API calls if the workflow would exceed limits.
operations_to_validate = [ operations_to_validate = [
{ {
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini 'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
'tokens_requested': 2000, 'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate
'actual_provider_name': 'gemini', 'actual_provider_name': 'gemini',
'operation_type': 'google_grounding' 'operation_type': 'google_grounding'
}, },
@@ -126,6 +128,120 @@ def validate_research_operations(
) )
def validate_exa_research_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate all operations for an Exa research workflow before making ANY API calls.

    This prevents wasteful external API calls (like Exa search) if subsequent
    LLM calls would fail due to token or call limits. The workflow is modelled
    as one Exa search call followed by three LLM calls (keyword analysis,
    competitor analysis, content angle generation), each estimated at
    1000 tokens.

    Args:
        pricing_service: PricingService instance; its
            ``check_comprehensive_limits`` method performs the actual check.
        user_id: User ID for subscription checking.
        gpt_provider: GPT provider from env var (defaults to "google").
            ``"huggingface"`` (case-insensitive) selects the MISTRAL enum with
            the display name "huggingface"; any other value, including
            ``None`` or an empty string, selects Gemini.

    Returns:
        None. Returning normally means every operation passed validation.

    Raises:
        HTTPException: status 429 when any limit would be exceeded; status 500
            when the validation itself fails unexpectedly.
    """
    try:
        # Determine actual provider for LLM calls based on GPT_PROVIDER env var.
        # A missing/blank value falls back to the documented default instead of
        # raising AttributeError on None (which used to surface as a 500).
        gpt_provider_lower = (gpt_provider or "google").strip().lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Estimate tokens for each operation in Exa research workflow
        # Exa Search call: 1 Exa API call (not token-based)
        # Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON)
        # Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON)
        # Content angle generator: ~1000 tokens (input: research results, output: list of angles)
        # Note: These are conservative estimates for pre-flight validation
        operations_to_validate = [
            {
                'provider': APIProvider.EXA,  # Exa API call
                'tokens_requested': 0,
                'actual_provider_name': 'exa',
                'operation_type': 'exa_neural_search'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'keyword_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'competitor_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'content_angle_generation'
            }
        ]

        logger.info("[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
        logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            # Normalise to a dict so a missing or None 'usage_info' cannot turn
            # the intended 429 response into an AttributeError (and thus a 500).
            usage_info = (error_details or {}).get('usage_info') or {}
            provider = usage_info.get('provider', llm_provider_name)
            operation_type = usage_info.get('operation_type', 'unknown')

            logger.error("[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
            logger.error(f" ├─ User: {user_id}")
            logger.error(f" ├─ Blocked at: {operation_type}")
            logger.error(f" ├─ Provider: {provider}")
            logger.error(f" └─ Reason: {message}")

            # Raise HTTPException immediately - frontend gets immediate response, no API calls made
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info("[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
        # Validation passed - no return needed (function raises HTTPException if validation fails)
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the 429 above); pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
        failure_message = f"Failed to validate operations: {str(e)}"
        # Chain the original exception so the root cause stays visible in tracebacks.
        raise HTTPException(
            status_code=500,
            detail={
                'error': failure_message,
                'message': failure_message
            }
        ) from e
def validate_image_generation_operations( def validate_image_generation_operations(
pricing_service: PricingService, pricing_service: PricingService,
user_id: str user_id: str

View File

@@ -258,6 +258,12 @@ class PricingService:
"model_name": "stable-diffusion", "model_name": "stable-diffusion",
"cost_per_image": 0.04, # $0.04 per image "cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation" "description": "Stability AI Image Generation"
},
{
"provider": APIProvider.EXA,
"model_name": "exa-search",
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
"description": "Exa Neural Search API"
} }
] ]
@@ -296,6 +302,7 @@ class PricingService:
"metaphor_calls_limit": 10, "metaphor_calls_limit": 10,
"firecrawl_calls_limit": 10, "firecrawl_calls_limit": 10,
"stability_calls_limit": 5, "stability_calls_limit": 5,
"exa_calls_limit": 100,
"gemini_tokens_limit": 100000, "gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0, "monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"], "features": ["basic_content_generation", "limited_research"],
@@ -316,10 +323,11 @@ class PricingService:
"metaphor_calls_limit": 100, "metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100, "firecrawl_calls_limit": 100,
"stability_calls_limit": 5, "stability_calls_limit": 5,
"gemini_tokens_limit": 2000, "exa_calls_limit": 500,
"openai_tokens_limit": 2000, "gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 2000, "openai_tokens_limit": 20000, # Increased from 5000 for better stability
"mistral_tokens_limit": 2000, "anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
"monthly_cost_limit": 50.0, "monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"], "features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams" "description": "Great for individuals and small teams"
@@ -338,6 +346,7 @@ class PricingService:
"metaphor_calls_limit": 500, "metaphor_calls_limit": 500,
"firecrawl_calls_limit": 500, "firecrawl_calls_limit": 500,
"stability_calls_limit": 200, "stability_calls_limit": 200,
"exa_calls_limit": 2000,
"gemini_tokens_limit": 5000000, "gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000, "openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000, "anthropic_tokens_limit": 1000000,
@@ -360,6 +369,7 @@ class PricingService:
"metaphor_calls_limit": 0, "metaphor_calls_limit": 0,
"firecrawl_calls_limit": 0, "firecrawl_calls_limit": 0,
"stability_calls_limit": 0, "stability_calls_limit": 0,
"exa_calls_limit": 0, # Unlimited
"gemini_tokens_limit": 0, "gemini_tokens_limit": 0,
"openai_tokens_limit": 0, "openai_tokens_limit": 0,
"anthropic_tokens_limit": 0, "anthropic_tokens_limit": 0,
@@ -423,11 +433,14 @@ class PricingService:
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]: def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get usage limits for a user based on their subscription.""" """Get usage limits for a user based on their subscription."""
# CRITICAL: Expire all objects first to ensure fresh data after renewal
self.db.expire_all()
subscription = self.db.query(UserSubscription).filter( subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id, UserSubscription.user_id == user_id,
UserSubscription.is_active == True UserSubscription.is_active == True
).first() ).first()
if not subscription: if not subscription:
# Return free tier limits # Return free tier limits
free_plan = self.db.query(SubscriptionPlan).filter( free_plan = self.db.query(SubscriptionPlan).filter(
@@ -439,7 +452,23 @@ class PricingService:
# Ensure current period before returning limits # Ensure current period before returning limits
self._ensure_subscription_current(subscription) self._ensure_subscription_current(subscription)
return self._plan_to_limits_dict(subscription.plan)
# CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship
self.db.refresh(subscription)
# Re-query plan directly to ensure fresh data (bypass relationship cache)
plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.id == subscription.plan_id
).first()
if not plan:
logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
return None
# Refresh plan to ensure fresh limits
self.db.refresh(plan)
return self._plan_to_limits_dict(plan)
def _ensure_ai_text_gen_column_detection(self) -> None: def _ensure_ai_text_gen_column_detection(self) -> None:
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result.""" """Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
@@ -508,290 +537,20 @@ class PricingService:
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]: tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits. """Check if user can make an API call within their limits.
Delegates to LimitValidator for actual validation logic.
Args: Args:
user_id: User ID user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace) provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL) actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
"""
try:
# Use actual_provider_name if provided, otherwise use enum value
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
display_provider_name = actual_provider_name or provider.value
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
try:
limits = self.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
try:
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
except Exception as e: Returns:
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}") (can_proceed, error_message, usage_info)
# STRICT: Fail closed - deny requests if subscription system fails """
return False, f"Subscription check error: {str(e)}", {} from .limit_validation import LimitValidator
validator = LimitValidator(self)
return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)
def estimate_tokens(self, text: str, provider: APIProvider) -> int: def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider.""" """Estimate token count for text based on provider."""
@@ -827,6 +586,16 @@ class PricingService:
if not pricing: if not pricing:
return None return None
# Return pricing info as dict
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'description': pricing.description
}
def check_comprehensive_limits( def check_comprehensive_limits(
self, self,
user_id: str, user_id: str,
@@ -835,6 +604,7 @@ class PricingService:
""" """
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls. Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
Delegates to LimitValidator for actual validation logic.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call. before making the first external API call.
@@ -850,202 +620,9 @@ class PricingService:
(can_proceed, error_message, error_details) (can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded If can_proceed is False, error_message explains which limit would be exceeded
""" """
try: from .limit_validation import LimitValidator
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}") validator = LimitValidator(self)
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls") return validator.check_comprehensive_limits(user_id, operations)
# Get current usage and limits once
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
return False, f"Failed to create usage summary: {str(create_error)}", {}
# Get user limits
limits_dict = self.get_user_limits(user_id)
if not limits_dict:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
limits_dict = self._plan_to_limits_dict(free_plan)
else:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
# Track cumulative usage across all operations
total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
total_llm_tokens = {}
total_images = usage.stability_calls or 0
# Log current usage summary
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
logger.info(f" └─ Image Calls: {total_images}")
# Validate each operation
for op_idx, operation in enumerate(operations):
provider = operation.get('provider')
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
tokens_requested = operation.get('tokens_requested', 0)
actual_provider_name = operation.get('actual_provider_name')
operation_type = operation.get('operation_type', 'unknown')
display_provider_name = actual_provider_name or provider_name
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
# Check if this is an LLM provider
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
# Check unified AI text generation limit for LLM providers
if is_llm_provider:
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
if ai_text_gen_limit == 0:
# Fallback to provider-specific limit
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# Check token limits for this provider
# Use cumulative projected tokens from previous operations, or current from DB if first operation
provider_tokens_key = f"{provider_name}_tokens"
if provider_tokens_key in total_llm_tokens:
# Use cumulative projected tokens from previous operations
current_provider_tokens = total_llm_tokens[provider_tokens_key]
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
else:
# First operation for this provider - get current from database
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
total_llm_tokens[provider_tokens_key] = current_provider_tokens
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
if projected_tokens > token_limit:
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
error_info = {
'current_tokens': current_provider_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Current: {current_provider_tokens}/{token_limit}, "
f"Requested: {tokens_requested}, "
f"Would exceed by: {projected_tokens - token_limit} tokens "
f"({usage_percentage:.1f}% of limit)"
)
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
return False, error_msg, {
'error_type': 'token_limit',
'usage_info': error_info
}
else:
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
# Update cumulative counts for next operation
total_llm_calls = projected_total_llm_calls
total_llm_tokens[provider_tokens_key] += tokens_requested
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
# Check image generation limits
elif provider == APIProvider.STABILITY:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
'provider': 'stability',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
'error_type': 'image_limit',
'usage_info': error_info
}
total_images = projected_images
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
call_limit = limits.get(provider_calls_key, 0) or 0
if call_limit > 0:
projected_calls = current_provider_calls + 1
if projected_calls > call_limit:
error_info = {
'current_calls': current_provider_calls,
'limit': call_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
return True, None, None
except Exception as e:
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
return False, f"Failed to validate limits: {str(e)}", {}
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]: def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
"""Get pricing configuration for a specific provider and model.""" """Get pricing configuration for a specific provider and model."""

View File

@@ -0,0 +1,39 @@
from typing import Set
from sqlalchemy.orm import Session
_checked_subscription_plan_columns: bool = False
def ensure_subscription_plan_columns(db: Session) -> None:
    """Ensure required columns exist on subscription_plans for runtime safety.

    This is a defensive guard for environments where migrations have not yet
    been applied. If columns are missing (e.g., exa_calls_limit), we add them
    with a safe default so ORM queries do not fail.

    The check runs at most once per process: the module-level flag is set in
    ``finally`` whether the check succeeded or failed, so a database that
    rejects the statements is not probed again on every call.

    Args:
        db: SQLAlchemy session used to inspect and, if needed, alter the table.

    Returns:
        None. Failures are rolled back and intentionally not re-raised.
    """
    global _checked_subscription_plan_columns

    if _checked_subscription_plan_columns:
        return

    # Imported locally so this function stays self-contained. Raw SQL strings
    # must be wrapped in text(): SQLAlchemy 2.x rejects plain ``str`` in
    # Session.execute(), and because errors are swallowed below that would
    # silently turn this guard into a no-op.
    from sqlalchemy import text

    try:
        # Discover existing columns. PRAGMA table_info is SQLite syntax; each
        # result row is (cid, name, type, notnull, dflt_value, pk), so index 1
        # is the column name.
        result = db.execute(text("PRAGMA table_info(subscription_plans)"))
        cols: Set[str] = {row[1] for row in result}

        # Columns we may reference in models but might be missing in older DBs
        required_columns = {
            "exa_calls_limit": "INTEGER DEFAULT 0",
        }

        for col_name, ddl in required_columns.items():
            if col_name not in cols:
                # col_name/ddl come from the hard-coded mapping above, never
                # from user input, so interpolating them into DDL is safe.
                db.execute(text(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}"))
                db.commit()
    except Exception:
        # Do not block app if pragma/alter fails; let normal errors surface
        db.rollback()
    finally:
        _checked_subscription_plan_columns = True

View File

@@ -7,6 +7,7 @@ Handles authentication, permission checking, and blog publishing to Wix websites
import os import os
import json import json
import requests import requests
import uuid
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from loguru import logger from loguru import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
@@ -19,6 +20,9 @@ from services.integrations.wix.media import WixMediaService
from services.integrations.wix.utils import extract_meta_from_token, normalize_token_string, extract_member_id_from_access_token as utils_extract_member from services.integrations.wix.utils import extract_meta_from_token, normalize_token_string, extract_member_id_from_access_token as utils_extract_member
from services.integrations.wix.content import convert_content_to_ricos as ricos_builder from services.integrations.wix.content import convert_content_to_ricos as ricos_builder
from services.integrations.wix.auth import WixAuthService from services.integrations.wix.auth import WixAuthService
from services.integrations.wix.seo import build_seo_data
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
from services.integrations.wix.blog_publisher import create_blog_post as publish_blog_post
class WixService: class WixService:
"""Service for interacting with Wix APIs""" """Service for interacting with Wix APIs"""
@@ -237,13 +241,35 @@ class WixService:
logger.error(f"Failed to import image to Wix: {e}") logger.error(f"Failed to import image to Wix: {e}")
raise raise
def convert_content_to_ricos(self, content: str, images: List[str] = None) -> Dict[str, Any]: def convert_content_to_ricos(self, content: str, images: List[str] = None,
use_wix_api: bool = False, access_token: str = None) -> Dict[str, Any]:
"""
Convert markdown content to Ricos JSON format.
Args:
content: Markdown content to convert
images: Optional list of image URLs
use_wix_api: If True, use Wix's official Ricos Documents API (requires access_token)
access_token: Wix access token (required if use_wix_api=True)
Returns:
Ricos JSON document
"""
if use_wix_api and access_token:
try:
return convert_via_wix_api(content, access_token, self.base_url)
except Exception as e:
logger.warning(f"Failed to convert via Wix API, falling back to custom parser: {e}")
# Fall back to custom parser
# Use custom parser (current implementation)
return ricos_builder(content, images) return ricos_builder(content, images)
def create_blog_post(self, access_token: str, title: str, content: str, def create_blog_post(self, access_token: str, title: str, content: str,
cover_image_url: str = None, category_ids: List[str] = None, cover_image_url: str = None, category_ids: List[str] = None,
tag_ids: List[str] = None, publish: bool = True, tag_ids: List[str] = None, publish: bool = True,
member_id: str = None) -> Dict[str, Any]: member_id: str = None, seo_metadata: Dict[str, Any] = None) -> Dict[str, Any]:
""" """
Create and optionally publish a blog post on Wix Create and optionally publish a blog post on Wix
@@ -256,101 +282,33 @@ class WixService:
tag_ids: Optional list of tag IDs tag_ids: Optional list of tag IDs
publish: Whether to publish immediately or save as draft publish: Whether to publish immediately or save as draft
member_id: Required for third-party apps - the member ID of the post author member_id: Required for third-party apps - the member ID of the post author
seo_metadata: Optional SEO metadata dict with fields like:
- seo_title: SEO optimized title
- meta_description: Meta description
- focus_keyword: Main keyword
- blog_tags: List of tag strings (for keywords)
- open_graph: Open Graph data
- canonical_url: Canonical URL
Returns: Returns:
Created blog post information Created blog post information
""" """
if not member_id: return publish_blog_post(
raise ValueError("memberId is required for third-party apps creating blog posts") blog_service=self.blog_service,
access_token=access_token,
headers = { title=title,
'Authorization': f'Bearer {access_token}', content=content,
'Content-Type': 'application/json' member_id=member_id,
} cover_image_url=cover_image_url,
category_ids=category_ids,
# Build valid Ricos rich content (minimum: one paragraph with text) tag_ids=tag_ids,
ricos_content = self.convert_content_to_ricos(content or "This is a post from ALwrity.", None) publish=publish,
seo_metadata=seo_metadata,
# Minimal payload per Wix docs: title, memberId, and richContent import_image_func=self.import_image_to_wix,
blog_data = { lookup_categories_func=self.lookup_or_create_categories,
'draftPost': { lookup_tags_func=self.lookup_or_create_tags,
'title': title, base_url=self.base_url
'memberId': member_id, # Required for third-party apps )
'richContent': ricos_content,
'excerpt': (content or '').strip()[:200]
},
'publish': publish,
'fieldsets': ['URL'] # Simplified fieldsets
}
# Add cover image if provided
if cover_image_url:
try:
media_id = self.import_image_to_wix(access_token, cover_image_url, f'Cover: {title}')
blog_data['draftPost']['media'] = {
'wixMedia': {
'image': {'id': media_id}
},
'displayed': True,
'custom': True
}
except Exception as e:
logger.warning(f"Failed to import cover image: {e}")
# Add categories if provided
if category_ids:
blog_data['draftPost']['categoryIds'] = category_ids
# Add tags if provided
if tag_ids:
blog_data['draftPost']['tagIds'] = tag_ids
try:
# Check what permissions we have in the token
logger.info("DEBUG: Checking token permissions...")
try:
import jwt
# Extract token string manually since _normalize_access_token doesn't exist
token_str = str(access_token)
if token_str and token_str.startswith('OauthNG.JWS.'):
jwt_part = token_str[12:]
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
logger.info(f"DEBUG: Full token payload: {payload}")
# Check for permissions in various possible locations
data_payload = payload.get('data', {})
if isinstance(data_payload, str):
try:
data_payload = json.loads(data_payload)
except:
pass
instance_data = data_payload.get('instance', {})
permissions = instance_data.get('permissions', '')
scopes = instance_data.get('scopes', [])
meta_site_id = instance_data.get('metaSiteId')
if isinstance(meta_site_id, str) and meta_site_id:
headers['wix-site-id'] = meta_site_id
logger.info(f"DEBUG: Added wix-site-id header: {meta_site_id}")
logger.info(f"DEBUG: Token permissions: {permissions}")
logger.info(f"DEBUG: Token scopes: {scopes}")
else:
logger.info("DEBUG: Could not decode token for permission check")
except Exception as perm_e:
logger.warning(f"DEBUG: Failed to check permissions: {perm_e}")
logger.info(f"DEBUG: Sending simplified blog data: {json.dumps(blog_data, indent=2)}")
extra_headers = {}
if 'wix-site-id' in headers:
extra_headers['wix-site-id'] = headers['wix-site-id']
result = self.blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
logger.info(f"DEBUG: Create draft result: {result}")
return result
except requests.RequestException as e:
logger.error(f"Failed to create blog post: {e}")
if hasattr(e, 'response') and e.response is not None:
logger.error(f"Response body: {e.response.text}")
raise
def get_blog_categories(self, access_token: str) -> List[Dict[str, Any]]: def get_blog_categories(self, access_token: str) -> List[Dict[str, Any]]:
""" """
@@ -383,6 +341,138 @@ class WixService:
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Failed to get blog tags: {e}") logger.error(f"Failed to get blog tags: {e}")
raise raise
def lookup_or_create_categories(self, access_token: str, category_names: List[str],
                                extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
    """
    Lookup existing categories by name or create new ones, return their IDs.

    Names are matched case-insensitively against the stripped ``label`` of the
    categories returned by ``self.blog_service.list_categories``. Names with no
    match are created through ``self.blog_service.create_category``. A failure
    to create one category is logged and does not abort the others.

    Args:
        access_token: Valid access token
        category_names: List of category name strings
        extra_headers: Optional extra headers (e.g., wix-site-id)

    Returns:
        List of category IDs in first-seen order, without duplicates. Empty
        when ``category_names`` is empty or when the lookup raises
        ``requests.RequestException``.
    """
    if not category_names:
        return []

    try:
        # Get existing categories (tolerate a None/empty listing)
        existing_categories = self.blog_service.list_categories(access_token, extra_headers) or []

        # Create name -> ID mapping (case-insensitive)
        category_map = {}
        for cat in existing_categories:
            # ``or ''`` guards against an explicit null label, which would
            # otherwise raise AttributeError on .strip()
            cat_label = (cat.get('label') or '').strip()
            cat_id = cat.get('id')
            if cat_label and cat_id:
                category_map[cat_label.lower()] = cat_id

        category_ids = []
        for category_name in category_names:
            # str(None) would otherwise create a category literally named "None"
            if category_name is None:
                continue
            category_name_clean = str(category_name).strip()
            if not category_name_clean:
                continue

            lookup_key = category_name_clean.lower()
            # Lookup existing category (case-insensitive)
            category_id = category_map.get(lookup_key)

            if category_id:
                logger.info(f"Found existing category '{category_name_clean}' with ID: {category_id}")
            else:
                # Create new category
                try:
                    logger.info(f"Creating new category: {category_name_clean}")
                    result = self.blog_service.create_category(
                        access_token,
                        label=category_name_clean,
                        extra_headers=extra_headers
                    )
                    new_category = (result or {}).get('category') or {}
                    category_id = new_category.get('id')
                    if category_id:
                        # Update map to avoid duplicate creates
                        category_map[lookup_key] = category_id
                        logger.info(f"Created category '{category_name_clean}' with ID: {category_id}")
                    else:
                        logger.warning(f"Create response for category '{category_name_clean}' contained no ID; skipping")
                except Exception as create_error:
                    logger.warning(f"Failed to create category '{category_name_clean}': {create_error}")
                    # Continue with other categories
                    continue

            # Repeated or case-variant names resolve to the same ID; emit it once
            if category_id and category_id not in category_ids:
                category_ids.append(category_id)

        return category_ids

    except requests.RequestException as e:
        logger.error(f"Failed to lookup/create categories: {e}")
        return []
def lookup_or_create_tags(self, access_token: str, tag_names: List[str],
                          extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
    """
    Lookup existing tags by name or create new ones, return their IDs.

    Names are matched case-insensitively against the stripped ``label`` of the
    tags returned by ``self.blog_service.list_tags``. Names with no match are
    created through ``self.blog_service.create_tag``. A failure to create one
    tag is logged and does not abort the others.

    Args:
        access_token: Valid access token
        tag_names: List of tag name strings
        extra_headers: Optional extra headers (e.g., wix-site-id)

    Returns:
        List of tag IDs in first-seen order, without duplicates. Empty when
        ``tag_names`` is empty or when the lookup raises
        ``requests.RequestException``.
    """
    if not tag_names:
        return []

    try:
        # Get existing tags (tolerate a None/empty listing)
        existing_tags = self.blog_service.list_tags(access_token, extra_headers) or []

        # Create name -> ID mapping (case-insensitive)
        tag_map = {}
        for tag in existing_tags:
            # ``or ''`` guards against an explicit null label, which would
            # otherwise raise AttributeError on .strip()
            tag_label = (tag.get('label') or '').strip()
            tag_id = tag.get('id')
            if tag_label and tag_id:
                tag_map[tag_label.lower()] = tag_id

        tag_ids = []
        for tag_name in tag_names:
            # str(None) would otherwise create a tag literally named "None"
            if tag_name is None:
                continue
            tag_name_clean = str(tag_name).strip()
            if not tag_name_clean:
                continue

            lookup_key = tag_name_clean.lower()
            # Lookup existing tag (case-insensitive)
            tag_id = tag_map.get(lookup_key)

            if tag_id:
                logger.info(f"Found existing tag '{tag_name_clean}' with ID: {tag_id}")
            else:
                # Create new tag
                try:
                    logger.info(f"Creating new tag: {tag_name_clean}")
                    result = self.blog_service.create_tag(
                        access_token,
                        label=tag_name_clean,
                        extra_headers=extra_headers
                    )
                    new_tag = (result or {}).get('tag') or {}
                    tag_id = new_tag.get('id')
                    if tag_id:
                        # Update map to avoid duplicate creates
                        tag_map[lookup_key] = tag_id
                        logger.info(f"Created tag '{tag_name_clean}' with ID: {tag_id}")
                    else:
                        logger.warning(f"Create response for tag '{tag_name_clean}' contained no ID; skipping")
                except Exception as create_error:
                    logger.warning(f"Failed to create tag '{tag_name_clean}': {create_error}")
                    # Continue with other tags
                    continue

            # Repeated or case-variant names resolve to the same ID; emit it once
            if tag_id and tag_id not in tag_ids:
                tag_ids.append(tag_id)

        return tag_ids

    except requests.RequestException as e:
        logger.error(f"Failed to lookup/create tags: {e}")
        return []
def publish_draft_post(self, access_token: str, draft_post_id: str) -> Dict[str, Any]: def publish_draft_post(self, access_token: str, draft_post_id: str) -> Dict[str, Any]:
""" """

View File

@@ -0,0 +1,335 @@
# Research Component Integration Guide
## Overview
The modular Research component has been implemented as a standalone, testable wizard that can be integrated into the blog writer or used independently. This document outlines the architecture, usage, and integration steps.
## Architecture
### Backend Strategy Pattern
The research service now supports multiple research modes through a strategy pattern:
```python
# Research modes
- Basic: Quick keyword-focused analysis
- Comprehensive: Full analysis with all components
- Targeted: Customizable components based on config
# Strategy implementation
backend/services/blog_writer/research/research_strategies.py
- ResearchStrategy (base class)
- BasicResearchStrategy
- ComprehensiveResearchStrategy
- TargetedResearchStrategy
```
### Frontend Component Structure
```
frontend/src/components/Research/
├── index.tsx # Main exports
├── ResearchWizard.tsx # Main wizard container
├── steps/
│ ├── StepKeyword.tsx # Step 1: Keyword input
│ ├── StepOptions.tsx # Step 2: Mode selection
│ ├── StepProgress.tsx # Step 3: Progress display
│ └── StepResults.tsx # Step 4: Results display
├── hooks/
│ ├── useResearchWizard.ts # Wizard state management
│ └── useResearchExecution.ts # API calls and polling
├── types/
│ └── research.types.ts # TypeScript interfaces
└── utils/
└── researchUtils.ts # Utility functions
```
## Test Page
A dedicated test page is available at `/research-test` for testing the research wizard independently.
**Features:**
- Quick preset keywords for testing
- Debug panel with JSON export
- Performance metrics display
- Cache state visualization
## Usage
### Standalone Usage
```typescript
import { ResearchWizard } from '../components/Research';
<ResearchWizard
onComplete={(results) => {
console.log('Research complete:', results);
}}
onCancel={() => {
console.log('Cancelled');
}}
initialKeywords={['AI', 'marketing']}
initialIndustry="Technology"
/>
```
### Integration with Blog Writer
The component is designed to be easily integrated into the BlogWriter research phase:
**Current Implementation:**
- Uses CopilotKit sidebar for research input
- Displays results in `ResearchResults` component
- Manual fallback via `ManualResearchForm`
**Proposed Integration:**
Replace the CopilotKit/manual form with the wizard:
```typescript
// In BlogWriter.tsx
{currentPhase === 'research' && (
<ResearchWizard
onComplete={(results) => setResearch(results)}
onCancel={() => navigate('blog-writer')}
/>
)}
```
## Backend API Changes
### New Models
The `BlogResearchRequest` model now supports:
```python
class BlogResearchRequest(BaseModel):
keywords: List[str]
topic: Optional[str] = None
industry: Optional[str] = None
target_audience: Optional[str] = None
tone: Optional[str] = None
word_count_target: Optional[int] = 1500
persona: Optional[PersonaInfo] = None
research_mode: Optional[ResearchMode] = ResearchMode.BASIC # NEW
config: Optional[ResearchConfig] = None # NEW
```
### Backward Compatibility
The API remains backward compatible:
- If `research_mode` is not provided, defaults to `BASIC`
- If `config` is not provided, defaults to standard configuration
- Existing requests continue to work unchanged
## Research Modes
### Basic Mode
- Quick keyword analysis
- Primary & secondary keywords
- Current trends overview
- Top 5 content angles
- Key statistics
### Comprehensive Mode
- All basic features plus:
- Expert quotes & opinions
- Competitor analysis
- Market forecasts
- Best practices & case studies
- Content gaps identification
### Targeted Mode
- Selectable components:
- Statistics
- Expert quotes
- Competitors
- Trends
- Always includes: Keywords & content angles
## Configuration Options
### ResearchConfig Model
```python
class ResearchConfig(BaseModel):
mode: ResearchMode = ResearchMode.BASIC
date_range: Optional[DateRange] = None
source_types: List[SourceType] = []
max_sources: int = 10
include_statistics: bool = True
include_expert_quotes: bool = True
include_competitors: bool = True
include_trends: bool = True
```
### Date Range Options
- `last_week`
- `last_month`
- `last_3_months`
- `last_6_months`
- `last_year`
- `all_time`
### Source Types
- `web` - Web articles
- `academic` - Academic papers
- `news` - News articles
- `industry` - Industry reports
- `expert` - Expert opinions
## Caching
The research component uses the existing cache infrastructure:
- Cache keys include research mode
- The same cache store serves basic/comprehensive/targeted modes; because keys include the mode, a result cached for one mode is not returned for another
- Cache invalidation handled automatically
## Testing
### Test the Wizard
1. Navigate to `/research-test`
2. Use quick presets or enter custom keywords
3. Select research mode
4. Monitor progress
5. Review results
6. Export JSON for analysis
### Integration Testing
To test integration with BlogWriter:
1. Start backend: `python start_alwrity_backend.py`
2. Navigate to `/blog-writer` (current implementation)
3. Or navigate to `/research-test` (new wizard)
4. Compare results and UI
## Migration Path
### Phase 1: Parallel Testing (Current)
- `/research-test` - New wizard available
- `/blog-writer` - Current implementation unchanged
- Users can test both
### Phase 2: Integration
1. Add wizard as option in BlogWriter
2. A/B test user preference
3. Monitor performance metrics
### Phase 3: Replacement (Optional)
1. Replace CopilotKit/manual form with wizard
2. Remove old implementation
3. Update documentation
## API Endpoints
All existing endpoints remain unchanged:
```
POST /api/blog/research/start
- Supports new research_mode and config parameters
- Backward compatible with existing requests
GET /api/blog/research/status/{task_id}
- No changes required
```
## Benefits
1. **Modularity**: Component works standalone
2. **Testability**: Dedicated test page for experimentation
3. **Backward Compatibility**: Existing functionality unchanged
4. **Progressive Enhancement**: Can add features incrementally
5. **Reusability**: Can be used in other parts of the app
## Future Enhancements
Potential future improvements:
1. **Multi-stage Research**: Sequential research with refinement
2. **Source Quality Validation**: Advanced credibility scoring
3. **Interactive Query Builder**: Dynamic search refinement
4. **Advanced Prompting**: Few-shot examples, reasoning chains
5. **Custom Strategy Plugins**: User-defined research strategies
## Troubleshooting
### Research Results Not Showing
Check:
1. Backend logs for API errors
2. Network tab for failed requests
3. Browser console for JavaScript errors
4. Verify user authentication
### Cache Issues
Clear cache:
```typescript
import { researchCache } from '../services/researchCache';
researchCache.clearCache();
```
### Type Errors
Ensure all imports are correct:
```typescript
import {
ResearchWizard,
useResearchWizard,
WizardState
} from '../components/Research';
import {
BlogResearchRequest,
BlogResearchResponse,
ResearchMode,
ResearchConfig
} from '../services/blogWriterApi';
```
## Examples
### Basic Integration
```typescript
import { ResearchWizard } from './components/Research';
import { BlogResearchResponse } from './services/blogWriterApi';
const MyComponent: React.FC = () => {
const [results, setResults] = useState<BlogResearchResponse | null>(null);
return (
<ResearchWizard
onComplete={(res) => setResults(res)}
onCancel={() => console.log('Cancelled')}
/>
);
};
```
### Advanced Integration with Custom Config
```typescript
const request: BlogResearchRequest = {
keywords: ['AI', 'automation'],
industry: 'Technology',
research_mode: 'targeted',
config: {
mode: 'targeted',
include_statistics: true,
include_competitors: true,
include_trends: false,
max_sources: 20,
}
};
```
## Support
For issues or questions:
1. Check this documentation
2. Review test page examples
3. Inspect backend logs
4. Check frontend console

View File

@@ -0,0 +1,346 @@
# Research Wizard Implementation Summary
## Implementation Complete
A modular, pluggable research component has been successfully implemented with wizard-based UI that can be tested independently and integrated into the blog writer.
---
## Backend Implementation
### 1. Research Models (blog_models.py)
**New Enums:**
- `ResearchMode`: `BASIC`, `COMPREHENSIVE`, `TARGETED`
- `SourceType`: `WEB`, `ACADEMIC`, `NEWS`, `INDUSTRY`, `EXPERT`
- `DateRange`: `LAST_WEEK` through `ALL_TIME`
**New Models:**
```python
class ResearchConfig(BaseModel):
mode: ResearchMode = ResearchMode.BASIC
date_range: Optional[DateRange] = None
source_types: List[SourceType] = []
max_sources: int = 10
include_statistics: bool = True
include_expert_quotes: bool = True
include_competitors: bool = True
include_trends: bool = True
```
**Enhanced BlogResearchRequest:**
- Added `research_mode: Optional[ResearchMode]`
- Added `config: Optional[ResearchConfig]`
- **Backward compatible** - defaults to existing behavior
### 2. Strategy Pattern (research_strategies.py)
**New file:** `backend/services/blog_writer/research/research_strategies.py`
**Three Strategy Classes:**
1. **BasicResearchStrategy**: Quick keyword-focused analysis
2. **ComprehensiveResearchStrategy**: Full analysis with all components
3. **TargetedResearchStrategy**: Customizable components based on config
**Factory Function:**
```python
get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy
```
### 3. Service Integration (research_service.py)
**Key Changes:**
- Imports strategy factory and models
- Uses strategy pattern in both `research()` and `research_with_progress()` methods
- Automatically selects strategy based on `research_mode`
- Backward compatible - defaults to BASIC if not specified
**Line Changes:**
```python
# Lines 88-96: Determine research mode and get appropriate strategy
research_mode = request.research_mode or ResearchMode.BASIC
config = request.config or ResearchConfig(mode=research_mode)
strategy = get_strategy_for_mode(research_mode)
logger.info(f"Using research mode: {research_mode.value}")
# Build research prompt based on strategy
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
```
---
## Frontend Implementation
### 4. Component Structure
**New Directory:** `frontend/src/components/Research/`
```
Research/
├── index.tsx # Main exports
├── ResearchWizard.tsx # Main wizard container
├── steps/
│ ├── StepKeyword.tsx # Step 1: Keyword input
│ ├── StepOptions.tsx # Step 2: Mode selection (3 cards)
│ ├── StepProgress.tsx # Step 3: Progress display
│ └── StepResults.tsx # Step 4: Results display
├── hooks/
│ ├── useResearchWizard.ts # Wizard state management
│ └── useResearchExecution.ts # API calls and polling
├── types/
│ └── research.types.ts # TypeScript interfaces
├── utils/
│ └── researchUtils.ts # Utility functions
└── integrations/
└── BlogWriterAdapter.tsx # Blog writer integration adapter
```
### 5. Wizard Components
**ResearchWizard.tsx:**
- Main container with progress bar
- Step indicators (Setup → Options → Research → Results)
- Navigation footer with Back/Next buttons
- Responsive layout
**StepKeyword.tsx:**
- Keywords textarea
- Industry dropdown (16 options)
- Target audience input
- Validation for keyword requirements
**StepOptions.tsx:**
- Three mode cards (Basic, Comprehensive, Targeted)
- Visual selection feedback
- Feature lists per mode
- Hover effects
**StepProgress.tsx:**
- Real-time progress updates
- Progress messages display
- Cancel button
- Auto-advance to results on completion
**StepResults.tsx:**
- Displays research results using existing `ResearchResults` component
- Export JSON button
- Start new research button
### 6. Hooks
**useResearchWizard.ts:**
- State management for wizard steps
- localStorage persistence
- Step navigation (next/back)
- Validation per step
- Reset functionality
**useResearchExecution.ts:**
- Research execution via API
- Cache checking
- Polling integration
- Error handling
- Progress tracking
### 7. Test Page (ResearchTest.tsx)
**Location:** `frontend/src/pages/ResearchTest.tsx`
**Route:** `/research-test`
**Features:**
- Quick preset buttons (3 samples)
- Debug panel with JSON export
- Performance metrics display
- Cache state visualization
- Research statistics summary
**Sample Presets:**
1. AI Marketing Tools
2. Small Business SEO
3. Content Strategy
### 8. Type Definitions
**research.types.ts:**
- `WizardState`
- `WizardStepProps`
- `ResearchWizardProps`
- `ModeCardInfo`
**blogWriterApi.ts:**
- `ResearchMode` type union
- `SourceType` type union
- `DateRange` type union
- `ResearchConfig` interface
- Updated `BlogResearchRequest` interface
---
## Integration
### 9. Blog Writer API (blogWriterApi.ts)
**Enhanced Interface:**
```typescript
export interface BlogResearchRequest {
keywords: string[];
topic?: string;
industry?: string;
target_audience?: string;
tone?: string;
word_count_target?: number;
persona?: PersonaInfo;
research_mode?: ResearchMode; // NEW
config?: ResearchConfig; // NEW
}
```
### 10. App Routing (App.tsx)
**New Route:**
```typescript
<Route path="/research-test" element={<ResearchTest />} />
```
### 11. Integration Adapter
**BlogWriterAdapter.tsx:**
- Wrapper component for easy integration
- Usage examples included
- Clean interface for BlogWriter
---
## Documentation
### 12. Integration Guide
**File:** `docs/RESEARCH_COMPONENT_INTEGRATION.md`
**Contents:**
- Architecture overview
- Usage examples
- Backend API details
- Research modes explained
- Configuration options
- Testing instructions
- Migration path
- Troubleshooting guide
---
## Key Features
### Research Modes
**Basic Mode:**
- Quick keyword analysis
- Primary & secondary keywords
- Trends overview
- Top 5 content angles
- Key statistics
**Comprehensive Mode:**
- All basic features
- Expert quotes & opinions
- Competitor analysis
- Market forecasts
- Best practices & case studies
- Content gaps identification
**Targeted Mode:**
- Selectable components
- Customizable filters
- Date range options
- Source type filtering
### User Experience
1. **Step-by-step wizard** with clear progress
2. **Visual mode selection** with cards
3. **Real-time progress** with live updates
4. **Comprehensive results** with export capability
5. **Error handling** with retry options
6. **Cache integration** for instant results
### Developer Experience
1. **Modular architecture** - standalone components
2. **Type safety** - full TypeScript interfaces
3. **Reusable hooks** - state and execution management
4. **Test page** - isolated testing environment
5. **Documentation** - comprehensive guides
---
## Testing
### Quick Test
1. Navigate to `http://localhost:3000/research-test`
2. Click "AI Marketing Tools" preset
3. Select "Comprehensive" mode
4. Watch progress updates
5. Review results with export
### Integration Test
1. Compare `/research-test` wizard UI
2. Compare `/blog-writer` current UI
3. Test both research workflows
4. Verify caching works across both
---
## Backward Compatibility
- Existing API calls continue working
- No breaking changes to BlogWriter
- Optional parameters default to current behavior
- Cache infrastructure shared
- All existing features preserved
---
## File Summary
**Backend (3 files):**
- Modified: `blog_models.py`, `research_service.py`
- Created: `research_strategies.py`
**Frontend (13 files):**
- Created: `ResearchWizard.tsx`, 4 step components, 2 hooks, types, utils, adapter, test page
- Modified: `App.tsx`, `blogWriterApi.ts`
**Documentation (2 files):**
- Created: `RESEARCH_COMPONENT_INTEGRATION.md`, `RESEARCH_WIZARD_IMPLEMENTATION.md`
---
## Next Steps
1. **Test the wizard** at `/research-test`
2. **Review integration guide** in docs
3. **Integrate into BlogWriter** using adapter (optional)
4. **Gather user feedback** on wizard vs CopilotKit UI
5. **Add more presets** if needed
---
## Benefits Delivered
- Modular & Pluggable: Standalone component
- Testable: Dedicated test page
- Backward Compatible: No breaking changes
- Reusable: Can be used anywhere in the app
- Extensible: Easy to add new modes or features
- Documented: Comprehensive guides
- Type Safe: Full TypeScript support
- Production Ready: No linting errors
---
Implementation Date: 2025-11-03 (date of the introducing commit)
Status: Complete & Ready for Testing

View File

@@ -0,0 +1,150 @@
# Complete Wix SEO Metadata Implementation
## 📊 SEO Metadata Generated vs Posted
### ✅ FULLY POSTED TO WIX
#### 1. **SEO Keywords** (in `seoData.settings.keywords`)
- `focus_keyword` → Main keyword (`isMain: true`)
- `blog_tags` → Additional keywords (`isMain: false`)
- `social_hashtags` → Additional keywords (`isMain: false`)
#### 2. **Meta Tags** (in `seoData.tags`)
- `meta_description` → `<meta name="description">`
- `seo_title` → `<meta name="title">`
#### 3. **Open Graph Tags** (in `seoData.tags`)
- `open_graph.title` → `og:title`
- `open_graph.description` → `og:description`
- `open_graph.image` → `og:image` (HTTP/HTTPS URLs only)
- `og:type` → Always set to `article`
- `open_graph.url` or `canonical_url` → `og:url`
#### 4. **Twitter Card Tags** (in `seoData.tags`)
- `twitter_card.title` → `twitter:title`
- `twitter_card.description` → `twitter:description`
- `twitter_card.image` → `twitter:image` (HTTP/HTTPS URLs only)
- `twitter_card.card` → `twitter:card` (default: `summary_large_image`)
#### 5. **Canonical URL** (in `seoData.tags`)
- `canonical_url` → `<link rel="canonical">`
#### 6. **Blog Categories** (in `draftPost.categoryIds`)
- `blog_categories` → Lookup/create categories → `categoryIds` (UUIDs)
- **Implementation**: `lookup_or_create_categories()` method
- **Behavior**: Case-insensitive lookup, auto-create if missing
#### 7. **Blog Tags** (in `draftPost.tagIds`)
- `blog_tags` → Lookup/create tags → `tagIds` (UUIDs)
- **Implementation**: `lookup_or_create_tags()` method
- **Behavior**: Case-insensitive lookup, auto-create if missing
- **Note**: `blog_tags` are used twice: as additional SEO keywords (in `seoData.settings.keywords`) and, separately, as post tags (in `draftPost.tagIds`)
### ❌ NOT POSTED (Optional/Future)
1. **JSON-LD Structured Data** (`json_ld_schema`)
- **Reason**: Wix doesn't support JSON-LD in backend API
- **Solution**: Would require frontend implementation using `@wix/site-seo` package
- **Status**: Not implemented (would need to be added to Wix site code)
2. **URL Slug** (`url_slug`)
- **Reason**: Wix auto-generates URLs from title
- **Status**: Could be implemented if Wix API supports custom slugs
3. **Reading Time** (`reading_time`)
- **Reason**: Metadata only, not part of Wix blog post structure
- **Status**: Not applicable
4. **Optimization Score** (`optimization_score`)
- **Reason**: Internal metadata for ALwrity, not Wix field
- **Status**: Not applicable
## 🔄 Conversion Methods
### Markdown to Ricos Conversion
**Primary Method**: Wix Official Ricos Documents API
- **Endpoint**: Tries multiple paths to find correct endpoint
- **Benefits**: Official conversion, handles all edge cases
- **Fallback**: Custom parser if API unavailable
**Fallback Method**: Custom Markdown Parser
- **Location**: `backend/services/integrations/wix/content.py`
- **Supports**: Headings, paragraphs, lists, bold, italic, links, images, blockquotes
## 📋 Complete Post Structure
When publishing to Wix, the blog post includes:
```json
{
"draftPost": {
"title": "SEO optimized title",
"memberId": "author-member-id",
"richContent": { /* Ricos JSON document */ },
"excerpt": "First 200 chars of content",
"categoryIds": ["uuid1", "uuid2"], // From blog_categories
"tagIds": ["uuid1", "uuid2"], // From blog_tags
"media": { /* Cover image if provided */ },
"seoData": {
"settings": {
"keywords": [
{ "term": "main keyword", "isMain": true },
{ "term": "tag1", "isMain": false },
{ "term": "tag2", "isMain": false }
]
},
"tags": [
{ "type": "meta", "props": { "name": "description", "content": "..." } },
{ "type": "meta", "props": { "name": "title", "content": "..." } },
{ "type": "meta", "props": { "property": "og:title", "content": "..." } },
{ "type": "meta", "props": { "property": "og:description", "content": "..." } },
{ "type": "meta", "props": { "property": "og:image", "content": "..." } },
{ "type": "meta", "props": { "property": "og:type", "content": "article" } },
{ "type": "meta", "props": { "property": "og:url", "content": "..." } },
{ "type": "meta", "props": { "name": "twitter:title", "content": "..." } },
{ "type": "meta", "props": { "name": "twitter:description", "content": "..." } },
{ "type": "meta", "props": { "name": "twitter:image", "content": "..." } },
{ "type": "meta", "props": { "name": "twitter:card", "content": "summary_large_image" } },
{ "type": "link", "props": { "rel": "canonical", "href": "..." } }
]
}
},
"publish": true
}
```
## ✅ Implementation Status
### Fully Implemented ✅
- SEO keywords (main + additional)
- Meta description and title
- Open Graph tags (all standard fields)
- Twitter Card tags (all standard fields)
- Canonical URL
- **Blog categories** (lookup/create)
- **Blog tags** (lookup/create)
- Wix Ricos API integration (with fallback)
### Partially Implemented ⚠️
- Image handling (only HTTP/HTTPS URLs, base64 skipped)
### Not Implemented ❌
- JSON-LD structured data (requires frontend)
- URL slug customization
- Reading time (not applicable)
- Optimization score (not applicable)
## 🎯 Summary
**All major SEO metadata fields are now being posted to Wix:**
- ✅ Keywords
- ✅ Meta tags
- ✅ Open Graph
- ✅ Twitter Cards
- ✅ Canonical URL
- ✅ Categories (auto-lookup/create)
- ✅ Tags (auto-lookup/create)
The only missing piece is JSON-LD structured data, which requires frontend implementation in the Wix site code using the `@wix/site-seo` package.

View File

@@ -0,0 +1,102 @@
# Wix SEO Metadata Review
## SEO Metadata We Generate (`BlogSEOMetadataResponse`)
### Available Fields:
1. **seo_title** - SEO optimized title
2. **meta_description** - Meta description
3. **url_slug** - URL slug for the blog post
4. **blog_tags** - Array of tag strings (NOW being used for Wix post tags via lookup/create)
5. **blog_categories** - Array of category strings (NOW being used for Wix post categories via lookup/create)
6. **social_hashtags** - Hashtags for social media
7. **open_graph** - Open Graph metadata object:
- title
- description
- image
- url
- type
8. **twitter_card** - Twitter Card metadata object:
- title
- description
- image
- card (type)
9. **canonical_url** - Canonical URL
10. **focus_keyword** - Main SEO keyword
11. **json_ld_schema** - JSON-LD structured data (NOT being posted - would need frontend implementation)
12. **schema** - Legacy schema field (NOT being used)
13. **reading_time** - Estimated reading time (NOT being posted)
14. **optimization_score** - SEO optimization score (NOT being posted)
15. **generated_at** - Generation timestamp (NOT being posted)
## What We're Currently Posting to Wix
### ✅ Posted via `seoData`:
- **Keywords** (from `focus_keyword`, `blog_tags`, `social_hashtags`)
- Main keyword: `focus_keyword` → `isMain: true`
- Additional keywords: `blog_tags` and `social_hashtags` → `isMain: false`
- **Meta Tags**:
- `meta_description` → `<meta name="description">`
- `seo_title` → `<meta name="title">`
- **Open Graph Tags**:
- `og:title`, `og:description`, `og:image`, `og:type`, `og:url`
- **Twitter Card Tags**:
- `twitter:title`, `twitter:description`, `twitter:image`, `twitter:card`
- **Canonical URL**:
- `<link rel="canonical">`
### ✅ NOW Being Posted (Recently Implemented):
1. **Blog Categories** (`blog_categories`)
- **Implemented**: `lookup_or_create_categories()` method
- **Behavior**: Case-insensitive lookup, auto-create if missing
- **Result**: Categories from SEO metadata are posted as `categoryIds` (UUIDs)
2. **Blog Tags** (`blog_tags` for post organization)
- **Implemented**: `lookup_or_create_tags()` method
- **Behavior**: Case-insensitive lookup, auto-create if missing
- **Result**: Tags from SEO metadata are posted as `tagIds` (UUIDs)
- **Note**: `blog_tags` are used BOTH for SEO keywords AND for Wix post tags
3. **JSON-LD Structured Data** (`json_ld_schema`) — ❌ NOT posted
- **Issue**: Wix doesn't support JSON-LD in backend API
- **Solution**: Would need frontend implementation using `@wix/site-seo` package
- **Status**: Not implemented
4. **URL Slug** (`url_slug`) — ❌ NOT posted
- **Issue**: Not being passed to Wix
- **Status**: Wix generates the URL automatically, but we could potentially set it
## Implementation Status
### ✅ Fully Implemented:
- SEO keywords in `seoData.settings.keywords`
- Meta description tag
- SEO title tag
- Open Graph tags (title, description, image, type, url)
- Twitter Card tags (title, description, image, card type)
- Canonical URL link tag
### ✅ Fully Implemented (Recently Added):
- **Blog Categories**: Auto-lookup/create from `blog_categories`
- **Blog Tags**: Auto-lookup/create from `blog_tags`
- **Wix Ricos API Integration**: Uses official Wix API with fallback to custom parser
### ❌ Not Implemented (Optional):
- JSON-LD structured data (frontend only - requires `@wix/site-seo` package)
- URL slug setting (Wix auto-generates URLs)
- Reading time (metadata only, not applicable)
- Optimization score (metadata only, not applicable)
## Summary
**All major SEO metadata is now being posted to Wix:**
- SEO keywords (main + additional)
- Meta tags (description, title)
- Open Graph tags (title, description, image, type, url)
- Twitter Card tags (title, description, image, card type)
- Canonical URL
- **Blog Categories** (auto-lookup/create)
- **Blog Tags** (auto-lookup/create)
The only missing piece is JSON-LD structured data, which requires frontend implementation in the Wix site code using the `@wix/site-seo` package (not a backend concern).

View File

@@ -17,13 +17,16 @@ import WixCallbackPage from './components/WixCallbackPage/WixCallbackPage';
import WordPressCallbackPage from './components/WordPressCallbackPage/WordPressCallbackPage'; import WordPressCallbackPage from './components/WordPressCallbackPage/WordPressCallbackPage';
import BingCallbackPage from './components/BingCallbackPage/BingCallbackPage'; import BingCallbackPage from './components/BingCallbackPage/BingCallbackPage';
import BingAnalyticsStorage from './components/BingAnalyticsStorage/BingAnalyticsStorage'; import BingAnalyticsStorage from './components/BingAnalyticsStorage/BingAnalyticsStorage';
import ResearchTest from './pages/ResearchTest';
import ProtectedRoute from './components/shared/ProtectedRoute'; import ProtectedRoute from './components/shared/ProtectedRoute';
import GSCAuthCallback from './components/SEODashboard/components/GSCAuthCallback'; import GSCAuthCallback from './components/SEODashboard/components/GSCAuthCallback';
import Landing from './components/Landing/Landing'; import Landing from './components/Landing/Landing';
import ErrorBoundary from './components/shared/ErrorBoundary'; import ErrorBoundary from './components/shared/ErrorBoundary';
import ErrorBoundaryTest from './components/shared/ErrorBoundaryTest'; import ErrorBoundaryTest from './components/shared/ErrorBoundaryTest';
import CopilotKitDegradedBanner from './components/shared/CopilotKitDegradedBanner';
import { OnboardingProvider } from './contexts/OnboardingContext'; import { OnboardingProvider } from './contexts/OnboardingContext';
import { SubscriptionProvider, useSubscription } from './contexts/SubscriptionContext'; import { SubscriptionProvider, useSubscription } from './contexts/SubscriptionContext';
import { CopilotKitHealthProvider } from './contexts/CopilotKitHealthContext';
import { setAuthTokenGetter } from './api/client'; import { setAuthTokenGetter } from './api/client';
import { useOnboarding } from './contexts/OnboardingContext'; import { useOnboarding } from './contexts/OnboardingContext';
@@ -397,6 +400,7 @@ const App: React.FC = () => {
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} /> <Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} /> <Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
<Route path="/pricing" element={<PricingPage />} /> <Route path="/pricing" element={<PricingPage />} />
<Route path="/research-test" element={<ResearchTest />} />
<Route path="/wix-test" element={<WixTestPage />} /> <Route path="/wix-test" element={<WixTestPage />} />
<Route path="/wix-test-direct" element={<WixTestPage />} /> <Route path="/wix-test-direct" element={<WixTestPage />} />
<Route path="/wix/callback" element={<WixCallbackPage />} /> <Route path="/wix/callback" element={<WixCallbackPage />} />
@@ -411,14 +415,57 @@ const App: React.FC = () => {
// Only wrap with CopilotKit if we have a valid key // Only wrap with CopilotKit if we have a valid key
if (copilotApiKey && copilotApiKey.trim()) { if (copilotApiKey && copilotApiKey.trim()) {
// Enhanced error handler that updates health context
/**
 * Error handler passed to CopilotKit's `onError`.
 *
 * Logs the error, classifies it as fatal or transient by scanning the error
 * message for known markers, and broadcasts the result on `window` as a
 * `copilotkit-error` CustomEvent. An event is used because the health context
 * cannot be accessed directly from this scope.
 *
 * @param e - The error (or error-carrying event) reported by CopilotKit. The
 *   message is read from `e.error.message`, then `e.message`; a generic
 *   fallback is used when neither is present.
 */
const handleCopilotKitError = (e: any) => {
  console.error("CopilotKit Error:", e);

  const rawMessage = e?.error?.message || e?.message || 'CopilotKit error occurred';
  // Coerce defensively: a non-string `message` would otherwise throw on toLowerCase().
  const errorMessage = typeof rawMessage === 'string' ? rawMessage : String(rawMessage);
  const normalizedMessage = errorMessage.toLowerCase();

  // Markers MUST be lowercase because they are matched against the
  // lower-cased message. The Chrome net error code was previously written in
  // upper case and therefore could never match.
  const fatalMarkers = [
    'cors',
    'ssl',
    'certificate',
    '403',
    'forbidden',
    'err_cert_common_name_invalid',
  ];
  const isFatalError = fatalMarkers.some((marker) => normalizedMessage.includes(marker));

  // Dispatch event for health context to listen to
  window.dispatchEvent(new CustomEvent('copilotkit-error', {
    detail: {
      error: e,
      errorMessage,
      isFatal: isFatalError,
    }
  }));
};
return ( return (
<CopilotKit <ErrorBoundary
publicApiKey={copilotApiKey} context="CopilotKit"
showDevConsole={false} showDetails={process.env.NODE_ENV === 'development'}
onError={(e) => console.error("CopilotKit Error:", e)} fallback={
<Box sx={{ p: 3, textAlign: 'center' }}>
<Typography variant="h6" color="warning" gutterBottom>
Chat Unavailable
</Typography>
<Typography variant="body2" color="textSecondary">
CopilotKit encountered an error. The app continues to work with manual controls.
</Typography>
</Box>
}
> >
{appContent} <CopilotKit
</CopilotKit> publicApiKey={copilotApiKey}
showDevConsole={false}
onError={handleCopilotKitError}
>
{appContent}
</CopilotKit>
</ErrorBoundary>
); );
} }
@@ -426,6 +473,9 @@ const App: React.FC = () => {
return appContent; return appContent;
}; };
// Determine initial health status based on whether CopilotKit key is available
const hasCopilotKitKey = copilotApiKey && copilotApiKey.trim();
return ( return (
<ErrorBoundary <ErrorBoundary
context="Application Root" context="Application Root"
@@ -439,7 +489,10 @@ const App: React.FC = () => {
<ClerkProvider publishableKey={clerkPublishableKey}> <ClerkProvider publishableKey={clerkPublishableKey}>
<SubscriptionProvider> <SubscriptionProvider>
<OnboardingProvider> <OnboardingProvider>
{renderApp()} <CopilotKitHealthProvider initialHealthStatus={!!hasCopilotKitKey}>
<CopilotKitDegradedBanner />
{renderApp()}
</CopilotKitHealthProvider>
</OnboardingProvider> </OnboardingProvider>
</SubscriptionProvider> </SubscriptionProvider>
</ClerkProvider> </ClerkProvider>

View File

@@ -357,13 +357,23 @@ pollingApiClient.interceptors.response.use(
} }
// Check if it's a subscription-related error and handle it globally // Check if it's a subscription-related error and handle it globally
if (error.response?.status === 429 || error.response?.status === 402) { if (error.response?.status === 429 || error.response?.status === 402) {
console.log('Polling API Client: Detected subscription error, triggering global handler'); console.log('Polling API Client: Detected subscription error, triggering global handler', {
status: error.response?.status,
data: error.response?.data,
hasHandler: !!globalSubscriptionErrorHandler
});
if (globalSubscriptionErrorHandler) { if (globalSubscriptionErrorHandler) {
const wasHandled = globalSubscriptionErrorHandler(error); const wasHandled = globalSubscriptionErrorHandler(error);
console.log('Polling API Client: Global handler returned', wasHandled);
if (wasHandled) { if (wasHandled) {
console.log('Polling API Client: Subscription error handled by global handler'); console.log('Polling API Client: Subscription error handled by global handler - modal should be showing');
return Promise.reject(error); } else {
console.warn('Polling API Client: Global handler did not handle subscription error');
} }
// Always reject so the polling hook can also handle it
return Promise.reject(error);
} else {
console.warn('Polling API Client: No global subscription error handler registered');
} }
} }

View File

@@ -1,49 +1,39 @@
import React, { useState, useEffect, useRef, useCallback } from 'react'; import React, { useRef, useCallback } from 'react';
import { debug } from '../../utils/debug'; import { debug } from '../../utils/debug';
import { CopilotSidebar } from '@copilotkit/react-ui';
import { useCopilotChatHeadless_c } from '@copilotkit/react-core';
import { useCopilotAction } from '@copilotkit/react-core';
import '@copilotkit/react-ui/styles.css';
import WriterCopilotSidebar from './BlogWriterUtils/WriterCopilotSidebar'; import WriterCopilotSidebar from './BlogWriterUtils/WriterCopilotSidebar';
import { blogWriterApi, BlogSEOActionableRecommendation } from '../../services/blogWriterApi'; import { blogWriterApi } from '../../services/blogWriterApi';
import { useOutlinePolling, useMediumGenerationPolling, useResearchPolling, useRewritePolling } from '../../hooks/usePolling';
import { useClaimFixer } from '../../hooks/useClaimFixer'; import { useClaimFixer } from '../../hooks/useClaimFixer';
import { useMarkdownProcessor } from '../../hooks/useMarkdownProcessor'; import { useMarkdownProcessor } from '../../hooks/useMarkdownProcessor';
import { useBlogWriterState } from '../../hooks/useBlogWriterState'; import { useBlogWriterState } from '../../hooks/useBlogWriterState';
import { useSuggestions } from './SuggestionsGenerator';
import EnhancedOutlineEditor from './EnhancedOutlineEditor';
import ContinuityBadge from './ContinuityBadge';
import EnhancedTitleSelector from './EnhancedTitleSelector';
import SEOMiniPanel from './SEOMiniPanel';
import ResearchResults from './ResearchResults';
import KeywordInputForm from './KeywordInputForm';
import ResearchAction from './ResearchAction';
import { CustomOutlineForm } from './CustomOutlineForm';
import { ResearchDataActions } from './ResearchDataActions';
import { EnhancedOutlineActions } from './EnhancedOutlineActions';
import HallucinationChecker from './HallucinationChecker'; import HallucinationChecker from './HallucinationChecker';
import { RewriteFeedbackForm } from './RewriteFeedbackForm';
import Publisher from './Publisher'; import Publisher from './Publisher';
import OutlineGenerator from './OutlineGenerator'; import OutlineGenerator from './OutlineGenerator';
import OutlineRefiner from './OutlineRefiner'; import OutlineRefiner from './OutlineRefiner';
import { SEOProcessor } from './SEO'; import { SEOProcessor } from './SEO';
import BlogWriterLanding from './BlogWriterLanding';
import { OutlineProgressModal } from './OutlineProgressModal';
import TaskProgressModals from './BlogWriterUtils/TaskProgressModals'; import TaskProgressModals from './BlogWriterUtils/TaskProgressModals';
import OutlineFeedbackForm from './OutlineFeedbackForm';
import { BlogEditor } from './WYSIWYG';
import { SEOAnalysisModal } from './SEOAnalysisModal'; import { SEOAnalysisModal } from './SEOAnalysisModal';
import { SEOMetadataModal } from './SEOMetadataModal'; import { SEOMetadataModal } from './SEOMetadataModal';
import PhaseNavigation from './PhaseNavigation';
import { usePhaseNavigation } from '../../hooks/usePhaseNavigation'; import { usePhaseNavigation } from '../../hooks/usePhaseNavigation';
import HeaderBar from './BlogWriterUtils/HeaderBar'; import HeaderBar from './BlogWriterUtils/HeaderBar';
import PhaseContent from './BlogWriterUtils/PhaseContent'; import PhaseContent from './BlogWriterUtils/PhaseContent';
import useBlogWriterCopilotActions from './BlogWriterUtils/useBlogWriterCopilotActions'; import useBlogWriterCopilotActions from './BlogWriterUtils/useBlogWriterCopilotActions';
import { useCopilotKitHealth } from '../../hooks/useCopilotKitHealth';
// Type assertion for CopilotKit action import { useSEOManager } from './BlogWriterUtils/useSEOManager';
const useCopilotActionTyped = useCopilotAction as any; import { usePhaseActionHandlers } from './BlogWriterUtils/usePhaseActionHandlers';
import { useBlogWriterPolling } from './BlogWriterUtils/useBlogWriterPolling';
import { useCopilotSuggestions } from './BlogWriterUtils/useCopilotSuggestions';
import { usePhaseRestoration } from './BlogWriterUtils/usePhaseRestoration';
import { useModalVisibility } from './BlogWriterUtils/useModalVisibility';
import { useBlogWriterRefs } from './BlogWriterUtils/useBlogWriterRefs';
import { BlogWriterLandingSection } from './BlogWriterUtils/BlogWriterLandingSection';
import { CopilotKitComponents } from './BlogWriterUtils/CopilotKitComponents';
export const BlogWriter: React.FC = () => { export const BlogWriter: React.FC = () => {
// Check CopilotKit health status
const { isAvailable: copilotKitAvailable } = useCopilotKitHealth({
enabled: true, // Enable health checking
});
// Use custom hook for all state management // Use custom hook for all state management
const { const {
research, research,
@@ -91,17 +81,64 @@ export const BlogWriter: React.FC = () => {
handleContentSave handleContentSave
} = useBlogWriterState(); } = useBlogWriterState();
const [isSEOAnalysisModalOpen, setIsSEOAnalysisModalOpen] = useState(false); // SEO Manager - handles all SEO-related logic
const [isSEOMetadataModalOpen, setIsSEOMetadataModalOpen] = useState(false); // Initialize phase navigation with temporary false value for seoRecommendationsApplied
const [seoRecommendationsApplied, setSeoRecommendationsApplied] = useState(false); const [tempSeoRecommendationsApplied] = React.useState(false);
const lastSEOModalOpenRef = useRef<number>(0); const {
phases: tempPhases,
currentPhase: tempCurrentPhase,
navigateToPhase: tempNavigateToPhase,
setCurrentPhase: tempSetCurrentPhase,
resetUserSelection
} = usePhaseNavigation(
research,
outline,
outlineConfirmed,
Object.keys(sections).length > 0,
contentConfirmed,
seoAnalysis,
seoMetadata,
tempSeoRecommendationsApplied
);
// Phase navigation hook const {
isSEOAnalysisModalOpen,
setIsSEOAnalysisModalOpen,
isSEOMetadataModalOpen,
setIsSEOMetadataModalOpen,
seoRecommendationsApplied,
setSeoRecommendationsApplied,
lastSEOModalOpenRef,
runSEOAnalysisDirect,
handleApplySeoRecommendations,
handleSEOAnalysisComplete,
handleSEOModalClose,
confirmBlogContent,
} = useSEOManager({
sections,
research,
outline,
selectedTitle,
contentConfirmed,
seoAnalysis,
currentPhase: tempCurrentPhase,
navigateToPhase: tempNavigateToPhase,
setContentConfirmed,
setSeoAnalysis,
setSeoMetadata,
setSections,
setSelectedTitle: setSelectedTitle as (title: string | null) => void,
setContinuityRefresh,
setFlowAnalysisCompleted,
setFlowAnalysisResults,
});
// Phase navigation hook with correct seoRecommendationsApplied
const { const {
phases, phases,
currentPhase, currentPhase,
navigateToPhase, navigateToPhase,
resetUserSelection setCurrentPhase,
} = usePhaseNavigation( } = usePhaseNavigation(
research, research,
outline, outline,
@@ -113,204 +150,17 @@ export const BlogWriter: React.FC = () => {
seoRecommendationsApplied seoRecommendationsApplied
); );
// Helper: run same checks as analyzeSEO and open modal // Phase restoration logic
const runSEOAnalysisDirect = (): string => { usePhaseRestoration({
const hasSections = !!sections && Object.keys(sections).length > 0; copilotKitAvailable,
const hasResearch = !!research && !!(research as any).keyword_analysis; research,
if (!hasSections) return "No blog content available for SEO analysis. Please generate content first."; phases,
if (!hasResearch) return "Research data is required for SEO analysis. Please run research first."; currentPhase,
// Prevent rapid re-opens navigateToPhase,
const now = Date.now(); setCurrentPhase,
if (isSEOAnalysisModalOpen && now - lastSEOModalOpenRef.current < 1000) { });
return "SEO analysis is already open.";
}
// Mark content phase as done when user clicks "Next: Run SEO Analysis"
if (!contentConfirmed) {
setContentConfirmed(true);
debug.log('[BlogWriter] Content phase marked as done (SEO analysis triggered)');
}
setSeoRecommendationsApplied(false);
if (!isSEOAnalysisModalOpen) {
setIsSEOAnalysisModalOpen(true);
lastSEOModalOpenRef.current = now;
debug.log('[BlogWriter] SEO modal opened (direct)');
}
return "Running SEO analysis of your blog content. This will analyze content structure, keyword optimization, readability, and provide actionable recommendations.";
};
const handleApplySeoRecommendations = useCallback(async ( // All SEO management logic is now in useSEOManager hook above
recommendations: BlogSEOActionableRecommendation[]
) => {
if (!outline || outline.length === 0) {
throw new Error('An outline is required before applying recommendations.');
}
const sectionPayload = outline.map((section) => ({
id: section.id,
heading: section.heading,
content: sections[section.id] ?? '',
}));
const response = await blogWriterApi.applySeoRecommendations({
title: selectedTitle || outline[0]?.heading || 'Untitled Blog',
sections: sectionPayload,
outline,
research: (research as any) || {},
recommendations,
});
if (!response.success) {
throw new Error(response.error || 'Failed to apply recommendations.');
}
if (!response.sections || !Array.isArray(response.sections)) {
throw new Error('Recommendation response did not include updated sections.');
}
// Update sections - create new object reference to trigger React re-render
const newSections: Record<string, string> = {};
response.sections.forEach((section) => {
if (section.id && section.content) {
newSections[section.id] = section.content;
}
});
// Validate we have sections before updating
if (Object.keys(newSections).length === 0) {
throw new Error('No valid sections received from SEO recommendations application.');
}
// Validate sections have actual content
const sectionsWithContent = Object.values(newSections).filter(c => c && c.trim().length > 0);
if (sectionsWithContent.length === 0) {
throw new Error('SEO recommendations resulted in empty sections. Please try again.');
}
// Log detailed section info for debugging
const sectionIds = Object.keys(newSections);
const sectionSizes = sectionIds.map(id => ({ id, length: newSections[id]?.length || 0 }));
debug.log('[BlogWriter] Applied SEO recommendations: sections updated', {
sectionCount: sectionIds.length,
sectionsWithContent: sectionsWithContent.length,
sectionIds: sectionIds,
sectionSizes: sectionSizes,
totalContentLength: Object.values(newSections).reduce((sum, c) => sum + (c?.length || 0), 0)
});
// Update sections state
setSections(newSections);
// Force a delay to ensure React processes the state update before proceeding
// This gives React time to re-render with new sections before phase navigation checks
await new Promise(resolve => setTimeout(resolve, 200));
setContinuityRefresh(Date.now());
setFlowAnalysisCompleted(false);
setFlowAnalysisResults(null);
if (response.title && response.title !== selectedTitle) {
setSelectedTitle(response.title);
}
if (response.applied) {
setSeoAnalysis(prev => prev ? { ...prev, applied_recommendations: response.applied } : prev);
debug.log('[BlogWriter] SEO analysis state updated with applied recommendations');
}
// Mark recommendations as applied (this will trigger phase navigation check)
// But we'll stay in SEO phase to show updated content
setSeoRecommendationsApplied(true);
debug.log('[BlogWriter] seoRecommendationsApplied set to true');
// Ensure we stay in SEO phase to show updated content
// Force navigation to SEO phase if we're not already there (safeguard)
if (currentPhase !== 'seo') {
navigateToPhase('seo');
debug.log('[BlogWriter] Forced navigation to SEO phase after applying recommendations');
} else {
debug.log('[BlogWriter] Already in SEO phase, staying to show updated content');
}
}, [outline, sections, selectedTitle, research, setSections, setSelectedTitle, setContinuityRefresh, setFlowAnalysisCompleted, setFlowAnalysisResults, setSeoAnalysis, currentPhase, navigateToPhase]);
// Handle SEO analysis completion
const handleSEOAnalysisComplete = useCallback((analysis: any) => {
setSeoAnalysis(analysis);
debug.log('[BlogWriter] SEO analysis completed', { hasAnalysis: !!analysis });
}, [setSeoAnalysis]);
// Handle SEO modal close - mark SEO phase as done if not already marked
const handleSEOModalClose = useCallback(() => {
// Mark SEO phase as done when modal closes (even without applying recommendations)
if (!seoAnalysis) {
// Set a minimal valid seoAnalysis object to mark phase as complete
setSeoAnalysis({
success: true,
overall_score: 0,
category_scores: {},
analysis_summary: {
overall_grade: 'N/A',
status: 'Skipped',
strongest_category: 'N/A',
weakest_category: 'N/A',
key_strengths: [],
key_weaknesses: [],
ai_summary: 'SEO analysis was skipped by user'
},
actionable_recommendations: [],
generated_at: new Date().toISOString()
});
debug.log('[BlogWriter] SEO phase marked as done (modal closed without analysis)');
}
setIsSEOAnalysisModalOpen(false);
debug.log('[BlogWriter] SEO modal closed');
}, [seoAnalysis, setSeoAnalysis, setIsSEOAnalysisModalOpen]);
// Mark SEO phase as completed when recommendations are applied
useEffect(() => {
if (seoRecommendationsApplied && seoAnalysis) {
// SEO phase is considered complete when recommendations are applied
// But stay in SEO phase to show updated content
debug.log('[BlogWriter] SEO recommendations applied, SEO phase marked as complete');
// Ensure we stay in SEO phase to show updated content (override auto-progression)
if (currentPhase !== 'seo' && Object.keys(sections).length > 0) {
navigateToPhase('seo');
debug.log('[BlogWriter] Forced stay in SEO phase to show updated content');
}
}
}, [seoRecommendationsApplied, seoAnalysis, currentPhase, sections, navigateToPhase]);
// Track when outlines/content become available for the first time
const prevOutlineLenRef = useRef<number>(outline.length);
const prevOutlineConfirmedRef = useRef<boolean>(outlineConfirmed);
const prevContentConfirmedRef = useRef<boolean>(contentConfirmed);
useEffect(() => {
const prevLen = prevOutlineLenRef.current;
if (research && prevLen === 0 && outline.length > 0) {
resetUserSelection();
}
prevOutlineLenRef.current = outline.length;
}, [research, outline.length, resetUserSelection]);
// Only reset user selection when transitioning from not-confirmed to confirmed
useEffect(() => {
const wasConfirmed = prevOutlineConfirmedRef.current;
if (!wasConfirmed && outlineConfirmed && Object.keys(sections).length > 0) {
resetUserSelection(); // Allow auto-progression to content phase
}
prevOutlineConfirmedRef.current = outlineConfirmed;
}, [outlineConfirmed, sections, resetUserSelection]);
useEffect(() => {
const wasConfirmed = prevContentConfirmedRef.current;
if (!wasConfirmed && contentConfirmed && seoAnalysis) {
resetUserSelection(); // Allow auto-progression to SEO phase
}
prevContentConfirmedRef.current = contentConfirmed;
}, [contentConfirmed, seoAnalysis, resetUserSelection]);
// Custom hooks for complex functionality // Custom hooks for complex functionality
const { buildFullMarkdown, buildUpdatedMarkdownForClaim, applyClaimFix } = useClaimFixer( const { buildFullMarkdown, buildUpdatedMarkdownForClaim, applyClaimFix } = useClaimFixer(
@@ -324,68 +174,48 @@ export const BlogWriter: React.FC = () => {
sections sections
); );
// Research polling hook (for context awareness) // Polling hooks - extracted to useBlogWriterPolling
const researchPolling = useResearchPolling({ const {
onComplete: handleResearchComplete, researchPolling,
onError: (error) => console.error('Research polling error:', error) outlinePolling,
mediumPolling,
rewritePolling,
researchPollingState,
outlinePollingState,
mediumPollingState,
} = useBlogWriterPolling({
onResearchComplete: handleResearchComplete,
onOutlineComplete: handleOutlineComplete,
onOutlineError: handleOutlineError,
onSectionsUpdate: setSections,
}); });
// Outline polling hook // Modal visibility management - extracted to useModalVisibility
const outlinePolling = useOutlinePolling({ const {
onComplete: handleOutlineComplete, showModal,
onError: handleOutlineError showOutlineModal,
setShowOutlineModal,
isMediumGenerationStarting,
setIsMediumGenerationStarting,
} = useModalVisibility({
mediumPolling,
rewritePolling,
outlinePolling,
}); });
// Medium generation polling (used after confirm if short blog) // CopilotKit suggestions management - extracted to useCopilotSuggestions
const mediumPolling = useMediumGenerationPolling({ const hasContent = React.useMemo(() => Object.keys(sections).length > 0, [sections]);
onComplete: (result: any) => { const {
try { suggestions,
if (result && result.sections) { setSuggestionsRef,
const newSections: Record<string, string> = {}; } = useCopilotSuggestions({
result.sections.forEach((s: any) => {
newSections[String(s.id)] = s.content || '';
});
setSections(newSections);
}
} catch (e) {
console.error('Failed to apply medium generation result:', e);
}
},
onError: (err) => console.error('Medium generation failed:', err)
});
// Rewrite polling hook (used for blog rewrite operations)
const rewritePolling = useRewritePolling({
onComplete: (result: any) => {
try {
if (result && result.sections) {
const newSections: Record<string, string> = {};
result.sections.forEach((s: any) => {
newSections[String(s.id)] = s.content || '';
});
setSections(newSections);
}
} catch (e) {
console.error('Failed to apply rewrite result:', e);
}
},
onError: (err) => console.error('Rewrite failed:', err)
});
// Add minimum display time for modal
const [showModal, setShowModal] = useState(false);
const [modalStartTime, setModalStartTime] = useState<number | null>(null);
const [isMediumGenerationStarting, setIsMediumGenerationStarting] = useState(false);
const [showOutlineModal, setShowOutlineModal] = useState(false);
const suggestions = useSuggestions({
research, research,
outline, outline,
outlineConfirmed, outlineConfirmed,
researchPolling: { isPolling: researchPolling.isPolling, currentStatus: researchPolling.currentStatus }, researchPollingState,
outlinePolling: { isPolling: outlinePolling.isPolling, currentStatus: outlinePolling.currentStatus }, outlinePollingState,
mediumPolling: { isPolling: mediumPolling.isPolling, currentStatus: mediumPolling.currentStatus }, mediumPollingState,
hasContent: Object.keys(sections).length > 0, hasContent,
flowAnalysisCompleted, flowAnalysisCompleted,
contentConfirmed, contentConfirmed,
seoAnalysis, seoAnalysis,
@@ -393,29 +223,17 @@ export const BlogWriter: React.FC = () => {
seoRecommendationsApplied, seoRecommendationsApplied,
}); });
// Drive CopilotKit suggestions programmatically // Refs and tracking logic - extracted to useBlogWriterRefs
const copilotHeadless = (useCopilotChatHeadless_c as any)?.(); useBlogWriterRefs({
const setSuggestionsRef = useRef<any>(null); research,
useEffect(() => { outline,
setSuggestionsRef.current = copilotHeadless?.setSuggestions; outlineConfirmed,
}, [copilotHeadless]); contentConfirmed,
sections,
const suggestionsPayload = React.useMemo( currentPhase,
() => (Array.isArray(suggestions) ? suggestions.map((s: any) => ({ title: s.title, message: s.message })) : []), isSEOAnalysisModalOpen,
[suggestions] resetUserSelection,
); });
const prevSuggestionsRef = useRef<string>("__init__");
const suggestionsJson = React.useMemo(() => JSON.stringify(suggestionsPayload), [suggestionsPayload]);
useEffect(() => {
try {
if (!setSuggestionsRef.current) return;
if (suggestionsJson !== prevSuggestionsRef.current) {
setSuggestionsRef.current(suggestionsPayload);
debug.log('[BlogWriter] Copilot suggestions pushed', { count: suggestionsPayload.length });
prevSuggestionsRef.current = suggestionsJson;
}
} catch {}
}, [suggestionsJson, suggestionsPayload]);
const handlePhaseClick = useCallback((phaseId: string) => { const handlePhaseClick = useCallback((phaseId: string) => {
navigateToPhase(phaseId); navigateToPhase(phaseId);
@@ -427,40 +245,50 @@ export const BlogWriter: React.FC = () => {
runSEOAnalysisDirect(); runSEOAnalysisDirect();
} }
} }
}, [navigateToPhase, seoAnalysis, runSEOAnalysisDirect]); }, [navigateToPhase, seoAnalysis, runSEOAnalysisDirect, setIsSEOAnalysisModalOpen]);
const outlineGenRef = useRef<any>(null); const outlineGenRef = useRef<any>(null);
useEffect(() => { // Callback to handle cached outline completion
if ((mediumPolling.isPolling || rewritePolling.isPolling || isMediumGenerationStarting) && !showModal) { const handleCachedOutlineComplete = useCallback((result: { outline: any[], title_options?: string[] }) => {
setShowModal(true); if (result.outline && Array.isArray(result.outline)) {
setModalStartTime(Date.now()); handleOutlineComplete(result);
} else if (!mediumPolling.isPolling && !rewritePolling.isPolling && !isMediumGenerationStarting && showModal) {
const elapsed = Date.now() - (modalStartTime || 0);
const minDisplayTime = 2000; // 2 seconds minimum
if (elapsed < minDisplayTime) {
setTimeout(() => {
setShowModal(false);
setModalStartTime(null);
}, minDisplayTime - elapsed);
} else {
setShowModal(false);
setModalStartTime(null);
}
} }
}, [mediumPolling.isPolling, rewritePolling.isPolling, isMediumGenerationStarting, showModal, modalStartTime]); }, [handleOutlineComplete]);
// Handle outline modal visibility // Callback to handle cached content completion
useEffect(() => { const handleCachedContentComplete = useCallback((cachedSections: Record<string, string>) => {
if (outlinePolling.isPolling && !showOutlineModal) { if (cachedSections && Object.keys(cachedSections).length > 0) {
setShowOutlineModal(true); setSections(cachedSections);
} else if (!outlinePolling.isPolling && showOutlineModal) { debug.log('[BlogWriter] Cached content loaded into state', { sections: Object.keys(cachedSections).length });
// Add a small delay to ensure user sees completion message
setTimeout(() => {
setShowOutlineModal(false);
}, 1000);
} }
}, [outlinePolling.isPolling, showOutlineModal]); }, [setSections]);
// Phase action handlers for when CopilotKit is unavailable - extracted to usePhaseActionHandlers
const {
handleResearchAction,
handleOutlineAction,
handleContentAction,
handleSEOAction,
handlePublishAction,
} = usePhaseActionHandlers({
research,
outline,
selectedTitle,
contentConfirmed,
sections,
navigateToPhase,
handleOutlineConfirmed,
setIsMediumGenerationStarting,
mediumPolling,
outlineGenRef,
setOutline,
setContentConfirmed,
setIsSEOMetadataModalOpen,
runSEOAnalysisDirect,
onOutlineComplete: handleCachedOutlineComplete,
onContentComplete: handleCachedContentComplete,
});
// Handle medium generation start from OutlineFeedbackForm // Handle medium generation start from OutlineFeedbackForm
const handleMediumGenerationStarted = (taskId: string) => { const handleMediumGenerationStarted = (taskId: string) => {
@@ -475,77 +303,11 @@ export const BlogWriter: React.FC = () => {
setIsMediumGenerationStarting(true); setIsMediumGenerationStarting(true);
}; };
// Debug medium polling state
console.log('Medium polling state:', {
isPolling: mediumPolling.isPolling,
status: mediumPolling.currentStatus,
progressCount: mediumPolling.progressMessages.length
});
// Log critical state changes only (reduce noise)
const lastPhaseRef = useRef<string>('');
const lastSeoOpenRef = useRef<boolean>(false);
const lastSectionsLenRef = useRef<number>(0);
useEffect(() => {
if (currentPhase !== lastPhaseRef.current) {
debug.log('[BlogWriter] Phase changed', { currentPhase });
lastPhaseRef.current = currentPhase;
}
}, [currentPhase]);
useEffect(() => {
const open = isSEOAnalysisModalOpen;
if (open !== lastSeoOpenRef.current) {
debug.log('[BlogWriter] SEO modal', { isOpen: open });
lastSeoOpenRef.current = open;
}
}, [isSEOAnalysisModalOpen]);
useEffect(() => {
const len = Object.keys(sections || {}).length;
if (len !== lastSectionsLenRef.current) {
debug.log('[BlogWriter] Sections updated', { count: len });
lastSectionsLenRef.current = len;
}
}, [sections]);
useEffect(() => {
debug.log('[BlogWriter] Suggestions updated', { suggestions });
}, [suggestions]);
// Force-sync Copilot suggestions right after SEO recommendations applied (guarded by previous suggestions key)
useEffect(() => {
if (!seoAnalysis || !seoRecommendationsApplied || !setSuggestionsRef.current) return;
try {
if (suggestionsJson !== prevSuggestionsRef.current) {
setSuggestionsRef.current(suggestionsPayload);
debug.log('[BlogWriter] Forced Copilot suggestions sync after SEO recommendations applied', { count: suggestionsPayload.length });
prevSuggestionsRef.current = suggestionsJson;
}
} catch (e) {
console.error('Failed to push Copilot suggestions after SEO apply:', e);
}
}, [seoAnalysis, seoRecommendationsApplied, suggestionsJson, suggestionsPayload]);
const confirmBlogContentCb = useCallback(() => {
debug.log('[BlogWriter] Blog content confirmed by user');
setContentConfirmed(true);
resetUserSelection();
setSeoRecommendationsApplied(false);
navigateToPhase('seo');
setTimeout(() => {
setIsSEOAnalysisModalOpen(true);
debug.log('[BlogWriter] SEO modal opened (confirm→direct)');
}, 0);
return "✅ Blog content has been confirmed! Running SEO analysis now.";
}, [setContentConfirmed, resetUserSelection, navigateToPhase, setIsSEOAnalysisModalOpen]);
useBlogWriterCopilotActions({ useBlogWriterCopilotActions({
isSEOAnalysisModalOpen, isSEOAnalysisModalOpen,
lastSEOModalOpenRef, lastSEOModalOpenRef,
runSEOAnalysisDirect, runSEOAnalysisDirect,
confirmBlogContent: confirmBlogContentCb, confirmBlogContent,
sections, sections,
research, research,
openSEOMetadata: () => setIsSEOMetadataModalOpen(true), openSEOMetadata: () => setIsSEOMetadataModalOpen(true),
@@ -557,47 +319,22 @@ export const BlogWriter: React.FC = () => {
return ( return (
<div style={{ height: '100vh', display: 'flex', flexDirection: 'column' }}> <div style={{ height: '100vh', display: 'flex', flexDirection: 'column' }}>
{/* Extracted Components */} {/* CopilotKit-dependent components - extracted to CopilotKitComponents */}
<KeywordInputForm {copilotKitAvailable && (
onResearchComplete={handleResearchComplete} <CopilotKitComponents
onTaskStart={(taskId) => researchPolling.startPolling(taskId)} research={research}
/>
<CustomOutlineForm onOutlineCreated={setOutline} />
<ResearchAction onResearchComplete={handleResearchComplete} />
<ResearchDataActions
research={research}
onOutlineCreated={setOutline}
onTitleOptionsSet={setTitleOptions}
/>
<EnhancedOutlineActions
outline={outline}
onOutlineUpdated={setOutline}
/>
<OutlineFeedbackForm
outline={outline}
research={research!}
onOutlineConfirmed={handleOutlineConfirmed}
onOutlineRefined={handleOutlineRefined}
onMediumGenerationStarted={handleMediumGenerationStarted}
onMediumGenerationTriggered={handleMediumGenerationTriggered}
sections={sections}
blogTitle={selectedTitle}
onFlowAnalysisComplete={(analysis) => {
console.log('Flow analysis completed:', analysis);
setFlowAnalysisCompleted(true);
setFlowAnalysisResults(analysis);
// Trigger a refresh of continuity badges
setContinuityRefresh((prev: number) => (prev || 0) + 1);
}}
/>
{/* Rewrite Feedback Form - Only show when content exists */}
{Object.keys(sections).length > 0 && (
<RewriteFeedbackForm
research={research!}
outline={outline} outline={outline}
outlineConfirmed={outlineConfirmed}
sections={sections} sections={sections}
blogTitle={selectedTitle} selectedTitle={selectedTitle}
onResearchComplete={handleResearchComplete}
onOutlineCreated={setOutline}
onOutlineUpdated={setOutline}
onTitleOptionsSet={setTitleOptions}
onOutlineConfirmed={handleOutlineConfirmed}
onOutlineRefined={(feedback?: string) => handleOutlineRefined(feedback || '')}
onMediumGenerationStarted={handleMediumGenerationStarted}
onMediumGenerationTriggered={handleMediumGenerationTriggered}
onRewriteStarted={(taskId) => { onRewriteStarted={(taskId) => {
console.log('Starting rewrite polling for task:', taskId); console.log('Starting rewrite polling for task:', taskId);
rewritePolling.startPolling(taskId); rewritePolling.startPolling(taskId);
@@ -606,6 +343,10 @@ export const BlogWriter: React.FC = () => {
console.log('Rewrite triggered - showing modal immediately'); console.log('Rewrite triggered - showing modal immediately');
setIsMediumGenerationStarting(true); setIsMediumGenerationStarting(true);
}} }}
setFlowAnalysisCompleted={setFlowAnalysisCompleted}
setFlowAnalysisResults={setFlowAnalysisResults}
setContinuityRefresh={setContinuityRefresh}
researchPolling={researchPolling}
/> />
)} )}
@@ -638,19 +379,41 @@ export const BlogWriter: React.FC = () => {
seoMetadata={seoMetadata} seoMetadata={seoMetadata}
/> />
{!research ? ( {/* Always show HeaderBar when CopilotKit is unavailable, or when research exists */}
<BlogWriterLanding {(!copilotKitAvailable || research) && (
onStartWriting={() => { <HeaderBar
// Trigger the copilot to start the research process phases={phases}
currentPhase={currentPhase}
onPhaseClick={handlePhaseClick}
copilotKitAvailable={copilotKitAvailable}
actionHandlers={{
onResearchAction: handleResearchAction,
onOutlineAction: handleOutlineAction,
onContentAction: handleContentAction,
onSEOAction: handleSEOAction,
onPublishAction: handlePublishAction,
}} }}
hasResearch={!!research}
hasOutline={outline.length > 0}
outlineConfirmed={outlineConfirmed}
hasContent={Object.keys(sections).length > 0}
contentConfirmed={contentConfirmed}
hasSEOAnalysis={!!seoAnalysis}
hasSEOMetadata={!!seoMetadata}
/> />
) : ( )}
<>
<HeaderBar {/* Landing section - extracted to BlogWriterLandingSection */}
phases={phases} <BlogWriterLandingSection
research={research}
copilotKitAvailable={copilotKitAvailable}
currentPhase={currentPhase} currentPhase={currentPhase}
onPhaseClick={handlePhaseClick} navigateToPhase={navigateToPhase}
onResearchComplete={handleResearchComplete}
/> />
{research && (
<>
<PhaseContent <PhaseContent
currentPhase={currentPhase} currentPhase={currentPhase}
research={research} research={research}
@@ -679,6 +442,14 @@ export const BlogWriter: React.FC = () => {
seoMetadata={seoMetadata} seoMetadata={seoMetadata}
onTitleSelect={handleTitleSelect} onTitleSelect={handleTitleSelect}
onCustomTitle={handleCustomTitle} onCustomTitle={handleCustomTitle}
copilotKitAvailable={copilotKitAvailable}
onResearchComplete={handleResearchComplete}
onOutlineGenerationStart={(taskId) => {
setOutlineTaskId(taskId);
outlinePolling.startPolling(taskId);
setShowOutlineModal(true);
}}
onContentGenerationStart={handleMediumGenerationStarted}
/> />
</> </>
)} )}

View File

@@ -0,0 +1,46 @@
import React from 'react';
import BlogWriterLanding from '../BlogWriterLanding';
import ManualResearchForm from '../ManualResearchForm';
/** Props accepted by BlogWriterLandingSection. */
interface BlogWriterLandingSectionProps {
/** Research result. While this is falsy the landing UI is rendered; once it is truthy the component renders nothing. */
research: any;
/** Whether CopilotKit is available; selects between the copilot landing page and the manual (non-copilot) UI. */
copilotKitAvailable: boolean;
/** Id of the active phase. The manual research form is shown only when this is 'research' and CopilotKit is unavailable. */
currentPhase: string;
/** Switches the active phase; invoked with 'research' from the landing page when CopilotKit is unavailable. */
navigateToPhase: (phase: string) => void;
/** Passed through to ManualResearchForm as its onResearchComplete prop. */
onResearchComplete: (research: any) => void;
}
/**
 * Landing area of the blog writer, rendered only while no research data exists.
 *
 * - CopilotKit available: shows BlogWriterLanding with a no-op start handler.
 * - CopilotKit unavailable, phase is 'research': shows ManualResearchForm.
 * - CopilotKit unavailable, any other phase: shows BlogWriterLanding whose start
 *   button navigates to the 'research' phase.
 *
 * Returns null as soon as `research` is truthy.
 */
export const BlogWriterLandingSection: React.FC<BlogWriterLandingSectionProps> = (props) => {
  const { research, copilotKitAvailable, currentPhase, navigateToPhase, onResearchComplete } = props;

  // Nothing to show here once research has been produced.
  if (research) {
    return null;
  }

  const onResearchPhase = currentPhase === 'research';
  const showManualForm = !copilotKitAvailable && onResearchPhase;
  const showManualLanding = !copilotKitAvailable && !onResearchPhase;

  const startWithCopilot = () => {
    // Trigger the copilot to start the research process
  };

  const startManually = () => {
    // Navigate to research phase when CopilotKit unavailable
    navigateToPhase('research');
  };

  // The three slots stay in their original order inside the fragment.
  return (
    <>
      {showManualForm && <ManualResearchForm onResearchComplete={onResearchComplete} />}
      {copilotKitAvailable && <BlogWriterLanding onStartWriting={startWithCopilot} />}
      {showManualLanding && <BlogWriterLanding onStartWriting={startManually} />}
    </>
  );
};

View File

@@ -0,0 +1,103 @@
import React from 'react';
import KeywordInputForm from '../KeywordInputForm';
import ResearchAction from '../ResearchAction';
import { CustomOutlineForm } from '../CustomOutlineForm';
import { ResearchDataActions } from '../ResearchDataActions';
import { EnhancedOutlineActions } from '../EnhancedOutlineActions';
import OutlineFeedbackForm from '../OutlineFeedbackForm';
import { RewriteFeedbackForm } from '../RewriteFeedbackForm';
/** Props accepted by CopilotKitComponents; most are forwarded unchanged to the child forms/actions it mounts. */
interface CopilotKitComponentsProps {
/** Research data forwarded to ResearchDataActions, OutlineFeedbackForm and RewriteFeedbackForm. */
research: any;
/** Current outline forwarded to EnhancedOutlineActions, OutlineFeedbackForm and RewriteFeedbackForm. */
outline: any[];
/** NOTE(review): destructured by the component but not read anywhere in its visible body. */
outlineConfirmed: boolean;
/** Section content keyed by section id; RewriteFeedbackForm is rendered only when this has at least one key. */
sections: Record<string, string>;
/** Selected blog title; null becomes undefined for OutlineFeedbackForm and falls back to 'Untitled' for RewriteFeedbackForm. */
selectedTitle: string | null;
/** Forwarded to KeywordInputForm and ResearchAction as onResearchComplete. */
onResearchComplete: (research: any) => void;
/** Forwarded to CustomOutlineForm and ResearchDataActions as onOutlineCreated. */
onOutlineCreated: (outline: any[]) => void;
/** Forwarded to EnhancedOutlineActions as onOutlineUpdated. */
onOutlineUpdated: (outline: any[]) => void;
/** Forwarded to ResearchDataActions as onTitleOptionsSet. */
onTitleOptionsSet: (titles: any[]) => void;
/** Forwarded to OutlineFeedbackForm as onOutlineConfirmed. */
onOutlineConfirmed: () => void;
/** Forwarded to OutlineFeedbackForm as onOutlineRefined. */
onOutlineRefined: (feedback?: string) => void;
/** Forwarded to OutlineFeedbackForm as onMediumGenerationStarted. */
onMediumGenerationStarted: (taskId: string) => void;
/** Forwarded to OutlineFeedbackForm as onMediumGenerationTriggered. */
onMediumGenerationTriggered: () => void;
/** Forwarded to RewriteFeedbackForm as onRewriteStarted. */
onRewriteStarted: (taskId: string) => void;
/** Forwarded to RewriteFeedbackForm as onRewriteTriggered. */
onRewriteTriggered: () => void;
/** Called with true when OutlineFeedbackForm reports a completed flow analysis. */
setFlowAnalysisCompleted: (completed: boolean) => void;
/** Called with the analysis object reported by OutlineFeedbackForm. */
setFlowAnalysisResults: (results: any) => void;
/** Called with an updater that increments the previous value after a flow analysis completes. */
setContinuityRefresh: (refresh: number | ((prev: number) => number)) => void;
/** Research polling controller; only its startPolling(taskId) method is called here. */
researchPolling: any;
}
/**
 * Mounts all CopilotKit-dependent forms and action components of the blog writer
 * in a single fragment and wires the parent's callbacks into them.
 *
 * RewriteFeedbackForm is only mounted once `sections` contains at least one entry.
 */
export const CopilotKitComponents: React.FC<CopilotKitComponentsProps> = (props) => {
  const {
    research,
    outline,
    sections,
    selectedTitle,
    onResearchComplete,
    onOutlineCreated,
    onOutlineUpdated,
    onTitleOptionsSet,
    onOutlineConfirmed,
    onOutlineRefined,
    onMediumGenerationStarted,
    onMediumGenerationTriggered,
    onRewriteStarted,
    onRewriteTriggered,
    setFlowAnalysisCompleted,
    setFlowAnalysisResults,
    setContinuityRefresh,
    researchPolling,
  } = props;

  // The rewrite form is pointless without generated content.
  const hasGeneratedSections = Object.keys(sections).length > 0;
  // The outline form receives undefined instead of null for a missing title.
  const feedbackFormTitle = selectedTitle ?? undefined;

  return (
    <>
      <KeywordInputForm
        onResearchComplete={onResearchComplete}
        onTaskStart={(taskId) => researchPolling.startPolling(taskId)}
      />
      <CustomOutlineForm onOutlineCreated={onOutlineCreated} />
      <ResearchAction onResearchComplete={onResearchComplete} />
      <ResearchDataActions
        research={research}
        onOutlineCreated={onOutlineCreated}
        onTitleOptionsSet={onTitleOptionsSet}
      />
      <EnhancedOutlineActions outline={outline} onOutlineUpdated={onOutlineUpdated} />
      <OutlineFeedbackForm
        outline={outline}
        research={research!}
        onOutlineConfirmed={onOutlineConfirmed}
        onOutlineRefined={onOutlineRefined}
        onMediumGenerationStarted={onMediumGenerationStarted}
        onMediumGenerationTriggered={onMediumGenerationTriggered}
        sections={sections}
        blogTitle={feedbackFormTitle}
        onFlowAnalysisComplete={(analysis) => {
          console.log('Flow analysis completed:', analysis);
          setFlowAnalysisCompleted(true);
          setFlowAnalysisResults(analysis);
          // Bump the counter so continuity badges refresh.
          setContinuityRefresh((prev: number) => (prev || 0) + 1);
        }}
      />

      {/* Rewrite feedback form: mounted only when content exists */}
      {hasGeneratedSections && (
        <RewriteFeedbackForm
          research={research!}
          outline={outline}
          sections={sections}
          blogTitle={selectedTitle || 'Untitled'}
          onRewriteStarted={onRewriteStarted}
          onRewriteTriggered={onRewriteTriggered}
        />
      )}
    </>
  );
};

View File

@@ -1,13 +1,35 @@
import React from 'react'; import React from 'react';
import PhaseNavigation from '../PhaseNavigation'; import PhaseNavigation, { PhaseActionHandlers } from '../PhaseNavigation';
interface HeaderBarProps { interface HeaderBarProps {
phases: any[]; phases: any[];
currentPhase: string; currentPhase: string;
onPhaseClick: (phaseId: string) => void; onPhaseClick: (phaseId: string) => void;
copilotKitAvailable?: boolean;
actionHandlers?: PhaseActionHandlers;
hasResearch?: boolean;
hasOutline?: boolean;
outlineConfirmed?: boolean;
hasContent?: boolean;
contentConfirmed?: boolean;
hasSEOAnalysis?: boolean;
hasSEOMetadata?: boolean;
} }
export const HeaderBar: React.FC<HeaderBarProps> = ({ phases, currentPhase, onPhaseClick }) => { export const HeaderBar: React.FC<HeaderBarProps> = ({
phases,
currentPhase,
onPhaseClick,
copilotKitAvailable = true,
actionHandlers,
hasResearch = false,
hasOutline = false,
outlineConfirmed = false,
hasContent = false,
contentConfirmed = false,
hasSEOAnalysis = false,
hasSEOMetadata = false,
}) => {
return ( return (
<div style={{ padding: 16, borderBottom: '1px solid #eee' }}> <div style={{ padding: 16, borderBottom: '1px solid #eee' }}>
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}> <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
@@ -31,6 +53,15 @@ export const HeaderBar: React.FC<HeaderBarProps> = ({ phases, currentPhase, onPh
phases={phases} phases={phases}
currentPhase={currentPhase} currentPhase={currentPhase}
onPhaseClick={onPhaseClick} onPhaseClick={onPhaseClick}
copilotKitAvailable={copilotKitAvailable}
actionHandlers={actionHandlers}
hasResearch={hasResearch}
hasOutline={hasOutline}
outlineConfirmed={outlineConfirmed}
hasContent={hasContent}
contentConfirmed={contentConfirmed}
hasSEOAnalysis={hasSEOAnalysis}
hasSEOMetadata={hasSEOMetadata}
/> />
</div> </div>
); );

View File

@@ -4,6 +4,9 @@ import EnhancedTitleSelector from '../EnhancedTitleSelector';
import EnhancedOutlineEditor from '../EnhancedOutlineEditor'; import EnhancedOutlineEditor from '../EnhancedOutlineEditor';
import { BlogEditor } from '../WYSIWYG'; import { BlogEditor } from '../WYSIWYG';
import OutlineCtaBanner from './OutlineCtaBanner'; import OutlineCtaBanner from './OutlineCtaBanner';
import ManualResearchForm from '../ManualResearchForm';
import ManualOutlineButton from '../ManualOutlineButton';
import ManualContentButton from '../ManualContentButton';
interface PhaseContentProps { interface PhaseContentProps {
currentPhase: string; currentPhase: string;
@@ -33,6 +36,10 @@ interface PhaseContentProps {
onCustomTitle: any; onCustomTitle: any;
sectionImages?: Record<string, string>; sectionImages?: Record<string, string>;
setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void; setSectionImages?: (images: Record<string, string> | ((prev: Record<string, string>) => Record<string, string>)) => void;
copilotKitAvailable?: boolean; // Whether CopilotKit is available
onResearchComplete?: (research: any) => void; // Callback when research completes (for manual form)
onOutlineGenerationStart?: (taskId: string) => void; // Callback when outline generation starts
onContentGenerationStart?: (taskId: string) => void; // Callback when content generation starts
} }
export const PhaseContent: React.FC<PhaseContentProps> = ({ export const PhaseContent: React.FC<PhaseContentProps> = ({
@@ -62,7 +69,11 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
onTitleSelect, onTitleSelect,
onCustomTitle, onCustomTitle,
sectionImages, sectionImages,
setSectionImages setSectionImages,
copilotKitAvailable = true,
onResearchComplete,
onOutlineGenerationStart,
onContentGenerationStart,
}) => { }) => {
return ( return (
<div style={{ display: 'flex', flex: 1, overflow: 'hidden' }}> <div style={{ display: 'flex', flex: 1, overflow: 'hidden' }}>
@@ -72,10 +83,16 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
{research ? ( {research ? (
<ResearchResults research={research} /> <ResearchResults research={research} />
) : ( ) : (
<div style={{ padding: '20px', textAlign: 'center' }}> <>
<h3>Start Your Research</h3> {copilotKitAvailable ? (
<p>Use the copilot to begin researching your blog topic.</p> <div style={{ padding: '20px', textAlign: 'center' }}>
</div> <h3>Start Your Research</h3>
<p>Use the copilot to begin researching your blog topic.</p>
</div>
) : (
<ManualResearchForm onResearchComplete={onResearchComplete} />
)}
</>
)} )}
</> </>
)} )}
@@ -83,7 +100,17 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
{currentPhase === 'outline' && research && ( {currentPhase === 'outline' && research && (
<> <>
{outline.length === 0 && ( {outline.length === 0 && (
<OutlineCtaBanner onGenerate={() => outlineGenRef.current?.generateNow()} /> <>
{copilotKitAvailable ? (
<OutlineCtaBanner onGenerate={() => outlineGenRef.current?.generateNow()} />
) : (
<ManualOutlineButton
outlineGenRef={outlineGenRef}
hasResearch={!!research}
onGenerationStart={onOutlineGenerationStart}
/>
)}
</>
)} )}
{outline.length > 0 ? ( {outline.length > 0 ? (
<> <>
@@ -108,6 +135,12 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
setSectionImages={setSectionImages} setSectionImages={setSectionImages}
/> />
</> </>
) : !copilotKitAvailable ? (
<ManualOutlineButton
outlineGenRef={outlineGenRef}
hasResearch={!!research}
onGenerationStart={onOutlineGenerationStart}
/>
) : ( ) : (
<div style={{ padding: '20px', textAlign: 'center' }}> <div style={{ padding: '20px', textAlign: 'center' }}>
<h3>Create Your Outline</h3> <h3>Create Your Outline</h3>
@@ -135,10 +168,22 @@ export const PhaseContent: React.FC<PhaseContentProps> = ({
sectionImages={sectionImages} sectionImages={sectionImages}
/> />
) : ( ) : (
<div style={{ padding: '20px', textAlign: 'center' }}> <>
<h3>Confirm Your Outline</h3> {copilotKitAvailable ? (
<p>Review and confirm your outline before generating content.</p> <div style={{ padding: '20px', textAlign: 'center' }}>
</div> <h3>Confirm Your Outline</h3>
<p>Review and confirm your outline before generating content.</p>
</div>
) : (
<ManualContentButton
outline={outline}
research={research}
blogTitle={selectedTitle || undefined}
sections={sections}
onGenerationStart={onContentGenerationStart}
/>
)}
</>
)} )}
</> </>
)} )}

View File

@@ -81,6 +81,16 @@ export const WixConnectModal: React.FC<WixConnectModalProps> = ({
try { try {
setIsConnecting(true); setIsConnecting(true);
setError(null); setError(null);
// Store current page URL so we can redirect back after OAuth completes
// This MUST be stored before calling handleConnect to ensure it's available after redirect
// We ALWAYS override any existing redirect URL since we know the exact page we're on (Blog Writer)
const currentUrl = window.location.href;
try {
sessionStorage.setItem('wix_oauth_redirect', currentUrl);
console.log('[WixConnectModal] Stored redirect URL (overriding any existing):', currentUrl);
} catch (e) {
console.warn('[WixConnectModal] Failed to store redirect URL:', e);
}
await handleConnect('wix'); await handleConnect('wix');
// OAuth will redirect, so we don't need to do anything else here // OAuth will redirect, so we don't need to do anything else here
// The postMessage handler or URL param handler will close the modal // The postMessage handler or URL param handler will close the modal

View File

@@ -0,0 +1,101 @@
import React from 'react';
import {
useResearchPolling,
useOutlinePolling,
useMediumGenerationPolling,
useRewritePolling,
} from '../../../hooks/usePolling';
import { blogWriterCache } from '../../../services/blogWriterCache';
/** Callbacks the polling hooks report back through. */
interface UseBlogWriterPollingProps {
/** Passed to useResearchPolling as its onComplete handler. */
onResearchComplete: (research: any) => void;
/** Passed to useOutlinePolling as its onComplete handler. */
onOutlineComplete: (outline: any) => void;
/** Passed to useOutlinePolling as its onError handler. */
onOutlineError: (error: any) => void;
/** Called with an id -> content map built from a medium-generation or rewrite result that carries `sections`. */
onSectionsUpdate: (sections: Record<string, string>) => void;
}
/**
 * Builds an id -> content map from a polling result's `sections` list.
 * Returns null when the result is falsy or has no `sections`, so callers can skip the update.
 * Missing/empty section content is stored as an empty string.
 */
const collectSectionContent = (result: any): Record<string, string> | null => {
  if (!(result && result.sections)) {
    return null;
  }
  const contentById: Record<string, string> = {};
  result.sections.forEach((section: any) => {
    contentById[String(section.id)] = section.content || '';
  });
  return contentById;
};

/**
 * Creates the four polling controllers used by the blog writer (research, outline,
 * medium generation, rewrite) and wires their completion/error handlers.
 *
 * Medium-generation and rewrite results are converted to an id -> content map and
 * handed to `onSectionsUpdate`; medium-generation results are additionally written
 * to blogWriterCache when non-empty. Failures while applying a result are logged,
 * never thrown.
 *
 * @returns the raw polling controllers plus memoized `{ isPolling, currentStatus }`
 *          snapshots for research, outline and medium generation.
 */
export const useBlogWriterPolling = ({
  onResearchComplete,
  onOutlineComplete,
  onOutlineError,
  onSectionsUpdate,
}: UseBlogWriterPollingProps) => {
  // Research polling (for context awareness)
  const researchPolling = useResearchPolling({
    onComplete: onResearchComplete,
    onError: (error) => console.error('Research polling error:', error)
  });

  // Outline polling
  const outlinePolling = useOutlinePolling({
    onComplete: onOutlineComplete,
    onError: onOutlineError
  });

  // Medium generation polling (used after confirm if short blog)
  const mediumPolling = useMediumGenerationPolling({
    onComplete: (result: any) => {
      try {
        const generated = collectSectionContent(result);
        if (generated === null) {
          return;
        }
        onSectionsUpdate(generated);

        // Cache the generated content (shared utility), but only if there is any.
        const generatedIds = Object.keys(generated);
        if (generatedIds.length > 0) {
          blogWriterCache.cacheContent(generated, generatedIds);
        }
      } catch (e) {
        console.error('Failed to apply medium generation result:', e);
      }
    },
    onError: (err) => console.error('Medium generation failed:', err)
  });

  // Rewrite polling (used for blog rewrite operations); results are not cached.
  const rewritePolling = useRewritePolling({
    onComplete: (result: any) => {
      try {
        const rewritten = collectSectionContent(result);
        if (rewritten !== null) {
          onSectionsUpdate(rewritten);
        }
      } catch (e) {
        console.error('Failed to apply rewrite result:', e);
      }
    },
    onError: (err) => console.error('Rewrite failed:', err)
  });

  // Stable snapshots so consumers only see a new object when a field changes.
  const researchPollingState = React.useMemo(
    () => ({ isPolling: researchPolling.isPolling, currentStatus: researchPolling.currentStatus }),
    [researchPolling.isPolling, researchPolling.currentStatus]
  );
  const outlinePollingState = React.useMemo(
    () => ({ isPolling: outlinePolling.isPolling, currentStatus: outlinePolling.currentStatus }),
    [outlinePolling.isPolling, outlinePolling.currentStatus]
  );
  const mediumPollingState = React.useMemo(
    () => ({ isPolling: mediumPolling.isPolling, currentStatus: mediumPolling.currentStatus }),
    [mediumPolling.isPolling, mediumPolling.currentStatus]
  );

  return {
    researchPolling,
    outlinePolling,
    mediumPolling,
    rewritePolling,
    researchPollingState,
    outlinePollingState,
    mediumPollingState,
  };
};

View File

@@ -0,0 +1,83 @@
import { useRef, useEffect } from 'react';
import { debug } from '../../../utils/debug';
/** State observed by useBlogWriterRefs to detect transitions and log changes. */
interface UseBlogWriterRefsProps {
/** Research data; the outline-appeared reset only fires while this is truthy. */
research: any;
/** Current outline; only its length is inspected. */
outline: any[];
/** Whether the outline has been confirmed; a false -> true change (with existing sections) triggers a reset. */
outlineConfirmed: boolean;
/** Whether the content has been confirmed; a false -> true change triggers a reset. */
contentConfirmed: boolean;
/** Section content keyed by section id; the number of keys is checked and logged. */
sections: Record<string, string>;
/** Active phase id; changes are logged. */
currentPhase: string;
/** Open state of the SEO analysis modal; changes are logged. */
isSEOAnalysisModalOpen: boolean;
/** Invoked on each detected transition to allow auto-progression to the next phase. */
resetUserSelection: () => void;
}
/**
 * Side-effect-only hook that watches blog writer state for transitions.
 *
 * It calls `resetUserSelection` when:
 *  - the outline goes from empty to non-empty while research exists,
 *  - the outline becomes confirmed and sections already exist,
 *  - the content becomes confirmed.
 *
 * It also emits a debug log whenever the phase, the SEO modal open state, or the
 * number of sections changes. The hook returns nothing.
 */
export const useBlogWriterRefs = ({
  research,
  outline,
  outlineConfirmed,
  contentConfirmed,
  sections,
  currentPhase,
  isSEOAnalysisModalOpen,
  resetUserSelection,
}: UseBlogWriterRefsProps) => {
  // Previous-render values, used to detect first-time availability / confirmation.
  const prevOutlineLenRef = useRef<number>(outline.length);
  const prevOutlineConfirmedRef = useRef<boolean>(outlineConfirmed);
  const prevContentConfirmedRef = useRef<boolean>(contentConfirmed);

  useEffect(() => {
    const outlineJustAppeared = research && prevOutlineLenRef.current === 0 && outline.length > 0;
    if (outlineJustAppeared) {
      resetUserSelection();
    }
    prevOutlineLenRef.current = outline.length;
  }, [research, outline.length, resetUserSelection]);

  // Reset only on the not-confirmed -> confirmed edge, and only once content exists.
  useEffect(() => {
    const justConfirmed = !prevOutlineConfirmedRef.current && outlineConfirmed;
    if (justConfirmed && Object.keys(sections).length > 0) {
      resetUserSelection(); // Allow auto-progression to content phase
    }
    prevOutlineConfirmedRef.current = outlineConfirmed;
  }, [outlineConfirmed, sections, resetUserSelection]);

  useEffect(() => {
    const justConfirmed = !prevContentConfirmedRef.current && contentConfirmed;
    if (justConfirmed) {
      resetUserSelection(); // Allow auto-progression to SEO phase
    }
    prevContentConfirmedRef.current = contentConfirmed;
  }, [contentConfirmed, resetUserSelection]);

  // Last logged values: log critical state changes only (reduce noise).
  const lastPhaseRef = useRef<string>('');
  const lastSeoOpenRef = useRef<boolean>(false);
  const lastSectionsLenRef = useRef<number>(0);

  useEffect(() => {
    if (currentPhase === lastPhaseRef.current) {
      return;
    }
    debug.log('[BlogWriter] Phase changed', { currentPhase });
    lastPhaseRef.current = currentPhase;
  }, [currentPhase]);

  useEffect(() => {
    const isOpen = isSEOAnalysisModalOpen;
    if (isOpen === lastSeoOpenRef.current) {
      return;
    }
    debug.log('[BlogWriter] SEO modal', { isOpen: isOpen });
    lastSeoOpenRef.current = isOpen;
  }, [isSEOAnalysisModalOpen]);

  useEffect(() => {
    const sectionCount = Object.keys(sections || {}).length;
    if (sectionCount === lastSectionsLenRef.current) {
      return;
    }
    debug.log('[BlogWriter] Sections updated', { count: sectionCount });
    lastSectionsLenRef.current = sectionCount;
  }, [sections]);
};

View File

@@ -0,0 +1,94 @@
import React, { useRef, useEffect, useMemo } from 'react';
import { useCopilotChatHeadless_c } from '@copilotkit/react-core';
import { debug } from '../../../utils/debug';
import { useSuggestions } from '../SuggestionsGenerator';
/**
 * Workflow state used to derive Copilot suggestions.
 * Every field is forwarded to useSuggestions; the *PollingState fields are passed
 * under the names researchPolling / outlinePolling / mediumPolling.
 */
interface UseCopilotSuggestionsProps {
research: any;
outline: any[];
outlineConfirmed: boolean;
/** Snapshot of research polling, forwarded as `researchPolling`. */
researchPollingState: { isPolling: boolean; currentStatus: any };
/** Snapshot of outline polling, forwarded as `outlinePolling`. */
outlinePollingState: { isPolling: boolean; currentStatus: any };
/** Snapshot of medium-generation polling, forwarded as `mediumPolling`. */
mediumPollingState: { isPolling: boolean; currentStatus: any };
hasContent: boolean;
flowAnalysisCompleted: boolean;
contentConfirmed: boolean;
/** SEO analysis result; together with seoRecommendationsApplied it gates the forced suggestions sync. */
seoAnalysis: any;
seoMetadata: any;
/** Whether SEO recommendations were applied; together with seoAnalysis it gates the forced suggestions sync. */
seoRecommendationsApplied: boolean;
}
/**
 * Derives Copilot suggestions from the current workflow state (via useSuggestions)
 * and pushes them into CopilotKit through the headless chat hook's `setSuggestions`.
 *
 * Suggestions are reduced to `{ title, message }` pairs and only pushed when their
 * JSON serialization differs from the last pushed value. Pushing is best-effort:
 * failures are logged and never thrown.
 *
 * @returns the derived `suggestions` and the ref holding CopilotKit's `setSuggestions`.
 */
export const useCopilotSuggestions = ({
  research,
  outline,
  outlineConfirmed,
  researchPollingState,
  outlinePollingState,
  mediumPollingState,
  hasContent,
  flowAnalysisCompleted,
  contentConfirmed,
  seoAnalysis,
  seoMetadata,
  seoRecommendationsApplied,
}: UseCopilotSuggestionsProps) => {
  const suggestions = useSuggestions({
    research,
    outline,
    outlineConfirmed,
    researchPolling: researchPollingState,
    outlinePolling: outlinePollingState,
    mediumPolling: mediumPollingState,
    hasContent,
    flowAnalysisCompleted,
    contentConfirmed,
    seoAnalysis,
    seoMetadata,
    seoRecommendationsApplied,
  });

  // Drive CopilotKit suggestions programmatically.
  // The optional call tolerates the headless hook export being undefined.
  const copilotHeadless = (useCopilotChatHeadless_c as any)?.();
  const setSuggestionsRef = useRef<any>(null);
  useEffect(() => {
    setSuggestionsRef.current = copilotHeadless?.setSuggestions;
  }, [copilotHeadless]);

  // Only title/message are sent to CopilotKit; anything else on a suggestion is dropped.
  const suggestionsPayload = useMemo(
    () => (Array.isArray(suggestions) ? suggestions.map((s: any) => ({ title: s.title, message: s.message })) : []),
    [suggestions]
  );

  // Serialized form of the last pushed payload; "__init__" guarantees the first push happens.
  const prevSuggestionsRef = useRef<string>("__init__");
  const suggestionsJson = useMemo(() => JSON.stringify(suggestionsPayload), [suggestionsPayload]);

  useEffect(() => {
    try {
      if (!setSuggestionsRef.current) return;
      if (suggestionsJson !== prevSuggestionsRef.current) {
        setSuggestionsRef.current(suggestionsPayload);
        debug.log('[BlogWriter] Copilot suggestions pushed', { count: suggestionsPayload.length });
        prevSuggestionsRef.current = suggestionsJson;
      }
    } catch (e) {
      // Still best-effort, but report the failure like the forced-sync effect below
      // instead of silently discarding it.
      console.error('Failed to push Copilot suggestions:', e);
    }
  }, [suggestionsJson, suggestionsPayload]);

  // Force-sync Copilot suggestions right after SEO recommendations applied
  // (guarded by the previously pushed payload, so it never pushes duplicates).
  useEffect(() => {
    if (!seoAnalysis || !seoRecommendationsApplied || !setSuggestionsRef.current) return;
    try {
      if (suggestionsJson !== prevSuggestionsRef.current) {
        setSuggestionsRef.current(suggestionsPayload);
        debug.log('[BlogWriter] Forced Copilot suggestions sync after SEO recommendations applied', { count: suggestionsPayload.length });
        prevSuggestionsRef.current = suggestionsJson;
      }
    } catch (e) {
      console.error('Failed to push Copilot suggestions after SEO apply:', e);
    }
  }, [seoAnalysis, seoRecommendationsApplied, suggestionsJson, suggestionsPayload]);

  return {
    suggestions,
    setSuggestionsRef,
  };
};

View File

@@ -0,0 +1,61 @@
import { useState, useEffect } from 'react';
/** Polling flags that drive modal visibility; only `isPolling` is read from each. */
interface UseModalVisibilityProps {
/** Medium-generation polling; while polling, the generation modal is shown. */
mediumPolling: { isPolling: boolean };
/** Rewrite polling; while polling, the generation modal is shown. */
rewritePolling: { isPolling: boolean };
/** Outline polling; while polling, the outline modal is shown. */
outlinePolling: { isPolling: boolean };
}
/**
 * Controls visibility of the generation modal and the outline modal.
 *
 * Generation modal: opens as soon as medium generation or rewrite is polling (or a
 * medium generation has been flagged as starting) and stays open for at least
 * 2 seconds after it was opened.
 * Outline modal: opens while outline polling is active and closes 1 second after
 * polling stops so the completion message remains visible.
 *
 * Pending hide timers are cancelled when the effect's inputs change or the
 * component unmounts, so a stale timer can neither hide a modal after work has
 * restarted nor update state after unmount.
 *
 * @returns visibility flags and their setters, plus the "medium generation starting" flag and setter.
 */
export const useModalVisibility = ({
  mediumPolling,
  rewritePolling,
  outlinePolling,
}: UseModalVisibilityProps) => {
  const [showModal, setShowModal] = useState(false);
  const [modalStartTime, setModalStartTime] = useState<number | null>(null);
  const [isMediumGenerationStarting, setIsMediumGenerationStarting] = useState(false);
  const [showOutlineModal, setShowOutlineModal] = useState(false);

  // Generation modal with a minimum display time.
  useEffect(() => {
    const isBusy = mediumPolling.isPolling || rewritePolling.isPolling || isMediumGenerationStarting;

    if (isBusy && !showModal) {
      setShowModal(true);
      setModalStartTime(Date.now());
      return undefined;
    }
    if (isBusy || !showModal) {
      // Either still working with the modal already open, or idle and already hidden.
      return undefined;
    }

    const hideModal = () => {
      setShowModal(false);
      setModalStartTime(null);
    };

    const elapsed = Date.now() - (modalStartTime || 0);
    const minDisplayTime = 2000; // 2 seconds minimum
    if (elapsed >= minDisplayTime) {
      hideModal();
      return undefined;
    }

    // Keep the modal up for the remainder of the minimum display time.
    const hideTimer = setTimeout(hideModal, minDisplayTime - elapsed);
    // Cancel the pending hide if work restarts or the component unmounts.
    return () => clearTimeout(hideTimer);
  }, [mediumPolling.isPolling, rewritePolling.isPolling, isMediumGenerationStarting, showModal, modalStartTime]);

  // Outline modal visibility.
  useEffect(() => {
    if (outlinePolling.isPolling) {
      if (!showOutlineModal) {
        setShowOutlineModal(true);
      }
      return undefined;
    }
    if (!showOutlineModal) {
      return undefined;
    }

    // Small delay so the user sees the completion message before the modal closes.
    const hideTimer = setTimeout(() => {
      setShowOutlineModal(false);
    }, 1000);
    // Cancel the pending hide if polling starts again or the component unmounts.
    return () => clearTimeout(hideTimer);
  }, [outlinePolling.isPolling, showOutlineModal]);

  return {
    showModal,
    setShowModal,
    showOutlineModal,
    setShowOutlineModal,
    isMediumGenerationStarting,
    setIsMediumGenerationStarting,
  };
};

View File

@@ -0,0 +1,182 @@
import { useCallback } from 'react';
import { debug } from '../../../utils/debug';
import { mediumBlogApi } from '../../../services/blogWriterApi';
import { researchCache } from '../../../services/researchCache';
import { blogWriterCache } from '../../../services/blogWriterCache';
/** Inputs for usePhaseActionHandlers, the manual phase actions used when CopilotKit is unavailable. */
interface UsePhaseActionHandlersProps {
/** Research data; required before outline or content generation. Its original_keywords (or keyword_analysis.primary) are used for the outline cache lookup. */
research: any;
/** Current outline; content generation is refused while it is empty. */
outline: any[];
selectedTitle: string | null;
contentConfirmed: boolean;
sections: Record<string, string>;
/** Switches the active phase (called with 'research' and 'outline' by the handlers). */
navigateToPhase: (phase: string) => void;
handleOutlineConfirmed: () => void;
setIsMediumGenerationStarting: (starting: boolean) => void;
/** NOTE(review): presumably the medium-generation polling controller — confirm against the caller. */
mediumPolling: any;
/** Ref to the outline generator; when set, its generateNow() is awaited and a result without `success` is reported to the user. */
outlineGenRef: React.RefObject<any>;
/** Receives the cached outline when one is found for the research keywords. */
setOutline: (outline: any[]) => void;
setContentConfirmed: (confirmed: boolean) => void;
setIsSEOMetadataModalOpen: (open: boolean) => void;
runSEOAnalysisDirect: () => string;
/** Optional; called with `{ outline, title_options }` when a cached outline is used. */
onOutlineComplete?: (outline: any) => void;
/** Optional; NOTE(review): presumably called with cached sections — its call site is not in the visible part of the hook. */
onContentComplete?: (sections: Record<string, string>) => void;
}
/**
 * Builds the memoized click handlers for the five blog-writer phases
 * (research, outline, content, SEO, publish).
 *
 * Every handler navigates to its phase. The outline and content handlers also
 * consult `blogWriterCache` before triggering any generation, and the content
 * handler only auto-starts medium generation when the target length is at
 * most 1000 words.
 */
export const usePhaseActionHandlers = ({
  research,
  outline,
  selectedTitle,
  contentConfirmed,
  sections,
  navigateToPhase,
  handleOutlineConfirmed,
  setIsMediumGenerationStarting,
  mediumPolling,
  outlineGenRef,
  setOutline,
  setContentConfirmed,
  setIsSEOMetadataModalOpen,
  runSEOAnalysisDirect,
  onOutlineComplete,
  onContentComplete,
}: UsePhaseActionHandlersProps) => {
  /** Research: just switch phase; caching is owned by ManualResearchForm. */
  const handleResearchAction = useCallback(() => {
    navigateToPhase('research');
    debug.log('[BlogWriter] Research action triggered - navigating to research phase');
  }, [navigateToPhase]);

  /** Outline: reuse a cached outline if one exists, otherwise ask the generator ref to produce one. */
  const handleOutlineAction = useCallback(async () => {
    if (!research) {
      alert('Please complete research first before generating an outline.');
      return;
    }

    // Cache lookup is keyed by the research keywords (shared utility).
    const cacheKeywords = research.original_keywords || research.keyword_analysis?.primary || [];
    const outlineFromCache = blogWriterCache.getCachedOutline(cacheKeywords);
    if (outlineFromCache) {
      debug.log('[BlogWriter] Using cached outline from localStorage', { sections: outlineFromCache.outline.length });
      setOutline(outlineFromCache.outline);
      onOutlineComplete?.({ outline: outlineFromCache.outline, title_options: outlineFromCache.title_options });
      navigateToPhase('outline');
      return;
    }

    navigateToPhase('outline');

    const generator = outlineGenRef.current;
    if (generator) {
      try {
        const generation = await generator.generateNow();
        if (!generation.success) {
          alert(generation.message || 'Failed to generate outline');
        }
      } catch (err) {
        console.error('Outline generation failed:', err);
        alert(`Outline generation failed: ${err instanceof Error ? err.message : 'Unknown error'}`);
      }
    }
    debug.log('[BlogWriter] Outline action triggered');
  }, [research, navigateToPhase, outlineGenRef, setOutline, onOutlineComplete]);

  /** Content: confirm the outline, reuse cached/existing content, else auto-start medium generation for short blogs. */
  const handleContentAction = useCallback(async () => {
    const outlineMissing = !outline || outline.length === 0;
    if (outlineMissing) {
      alert('Please generate and confirm an outline first.');
      return;
    }
    if (!research) {
      alert('Research data is required for content generation.');
      return;
    }

    navigateToPhase('content');
    // The outline is confirmed before any cache or generation decision.
    handleOutlineConfirmed();

    // Cached content wins (shared utility).
    const sectionIds = outline.map((section) => String(section.id));
    const contentFromCache = blogWriterCache.getCachedContent(sectionIds);
    if (contentFromCache) {
      debug.log('[BlogWriter] Using cached content from localStorage', { sections: Object.keys(contentFromCache).length });
      onContentComplete?.(contentFromCache);
      return;
    }

    // Content already held in state also short-circuits generation (shared utility).
    if (blogWriterCache.contentExistsInState(sections || {}, sectionIds)) {
      debug.log('[BlogWriter] Content already exists in state, skipping generation', { sections: Object.keys(sections || {}).length });
      return;
    }

    // Target length: research settings first, then the persisted value, then 0.
    const configuredLength = research?.keyword_analysis?.blog_length || (research as any)?.word_count_target;
    const target = Number(configuredLength || localStorage.getItem('blog_length_target') || 0);

    if (!target || target > 1000) {
      // For longer blogs, just confirm outline - user will use manual button
      debug.log('[BlogWriter] Content action triggered - outline confirmed (manual content generation required)');
      return;
    }

    // Short/medium blog (<=1000 words): trigger content generation automatically.
    try {
      setIsMediumGenerationStarting(true);
      const payload = {
        title:
          selectedTitle ||
          (typeof window !== 'undefined' ? localStorage.getItem('blog_selected_title') : '') ||
          outline[0]?.heading ||
          'Untitled',
        sections: outline.map((section) => ({
          id: section.id,
          heading: section.heading,
          keyPoints: section.key_points,
          subheadings: section.subheadings,
          keywords: section.keywords,
          targetWords: section.target_words,
          references: section.references,
        })),
        globalTargetWords: target,
        researchKeywords: research.original_keywords || research.keyword_analysis?.primary || [],
      };
      const { task_id: taskId } = await mediumBlogApi.startMediumGeneration(payload as any);
      setIsMediumGenerationStarting(false);
      mediumPolling.startPolling(taskId);
      debug.log('[BlogWriter] Content action triggered - medium generation started', { task_id: taskId });
    } catch (err) {
      console.error('Content generation failed:', err);
      setIsMediumGenerationStarting(false);
      alert(`Content generation failed: ${err instanceof Error ? err.message : 'Unknown error'}`);
    }
  }, [outline, research, selectedTitle, sections, navigateToPhase, handleOutlineConfirmed, setIsMediumGenerationStarting, mediumPolling, onContentComplete]);

  /** SEO: mark content confirmed if needed, switch phase, and kick off the analysis. */
  const handleSEOAction = useCallback(() => {
    // Clicking the SEO action counts as confirming the content.
    if (!contentConfirmed) {
      setContentConfirmed(true);
    }
    navigateToPhase('seo');
    runSEOAnalysisDirect();
    debug.log('[BlogWriter] SEO action triggered');
  }, [contentConfirmed, setContentConfirmed, navigateToPhase, runSEOAnalysisDirect]);

  /** Publish: switch phase and open the SEO metadata modal. */
  const handlePublishAction = useCallback(() => {
    navigateToPhase('publish');
    setIsSEOMetadataModalOpen(true);
    debug.log('[BlogWriter] Publish action triggered - opening SEO metadata modal');
  }, [navigateToPhase, setIsSEOMetadataModalOpen]);

  return {
    handleResearchAction,
    handleOutlineAction,
    handleContentAction,
    handleSEOAction,
    handlePublishAction,
  };
};

View File

@@ -0,0 +1,67 @@
import { useEffect } from 'react';
import { debug } from '../../../utils/debug';
/** Inputs for usePhaseRestoration. */
interface UsePhaseRestorationProps {
/** When false and there is no research, the hook forces the research phase. */
copilotKitAvailable: boolean;
/** Research result; only its truthiness is checked here. */
research: any;
/** Known phases; entries are read for `id` and `disabled`. Nothing is restored while this is empty. */
phases: any[];
/** Id of the phase currently shown. */
currentPhase: string;
/** Used for the automatic jump to the research phase. */
navigateToPhase: (phase: string) => void;
/** Used to apply a phase id restored from localStorage. */
setCurrentPhase: (phase: string) => void;
}
/**
 * Keeps the blog writer on a sensible phase:
 *  1. Without CopilotKit and without research, jump to the research phase.
 *  2. Re-apply a phase id persisted in localStorage (written by the pricing
 *     page before redirecting back), provided the user had explicitly
 *     selected it and the phase exists and is enabled.
 */
export const usePhaseRestoration = ({
  copilotKitAvailable,
  research,
  phases,
  currentPhase,
  navigateToPhase,
  setCurrentPhase,
}: UsePhaseRestorationProps) => {
  // When CopilotKit is unavailable and there's no research, ensure we're on research phase
  useEffect(() => {
    const needsResearch = !copilotKitAvailable && !research;
    const phasesReady = phases.length > 0;
    const alreadyOnResearch = currentPhase === 'research';
    if (!needsResearch || !phasesReady || alreadyOnResearch) {
      return;
    }
    navigateToPhase('research');
    debug.log('[BlogWriter] Auto-navigating to research phase (CopilotKit unavailable)');
  }, [copilotKitAvailable, research, phases.length, currentPhase, navigateToPhase]);

  // Restore the phase saved before a round-trip to the pricing page
  // (subscription renewal). Persistence of later changes is left to the
  // phase-navigation hook.
  useEffect(() => {
    try {
      // Phases must be initialized before anything can be matched.
      if (phases.length === 0) {
        return;
      }

      const savedPhaseId = localStorage.getItem('blogwriter_current_phase');
      const savedByUser = localStorage.getItem('blogwriter_user_selected_phase') === 'true';

      // Restore only when a phase was saved, the user had picked it manually,
      // and it differs from what is already shown.
      if (!savedPhaseId || !savedByUser || savedPhaseId === currentPhase) {
        return;
      }

      const matchingPhase = phases.find((phase) => phase.id === savedPhaseId);
      if (!matchingPhase || matchingPhase.disabled) {
        console.log('[BlogWriter] Restored phase is disabled or not found, keeping current phase:', {
          restoredPhase: savedPhaseId,
          currentPhase,
          targetPhaseExists: !!matchingPhase,
          targetPhaseDisabled: matchingPhase?.disabled
        });
        return;
      }

      console.log('[BlogWriter] Restoring phase from navigation state:', savedPhaseId);
      setCurrentPhase(savedPhaseId);
    } catch (err) {
      console.error('[BlogWriter] Failed to restore phase from navigation state:', err);
    }
  }, [phases, currentPhase, setCurrentPhase]);
};

View File

@@ -0,0 +1,245 @@
import { useState, useRef, useEffect, useCallback } from 'react';
import { debug } from '../../../utils/debug';
import { blogWriterApi, BlogSEOActionableRecommendation } from '../../../services/blogWriterApi';
/** Inputs for useSEOManager: blog state read by the SEO flow plus the setters it drives. */
interface UseSEOManagerProps {
/** Section content keyed by section id; must be non-empty for SEO analysis to open. */
sections: Record<string, string>;
/** Research result; `keyword_analysis` must be present for SEO analysis to open. */
research: any;
/** Outline sections; read for `id` and `heading` when applying recommendations. */
outline: any[];
/** Current blog title; falls back to the first outline heading when empty. */
selectedTitle: string | null;
/** Whether the content phase is already confirmed. */
contentConfirmed: boolean;
/** Latest SEO analysis, or a falsy value when none exists yet. */
seoAnalysis: any;
/** Id of the phase currently shown. */
currentPhase: string;
/** Switches the visible phase; called here with 'seo'. */
navigateToPhase: (phase: string) => void;
/** Marks the content phase as confirmed. */
setContentConfirmed: (confirmed: boolean) => void;
/** Stores the SEO analysis. NOTE(review): the hook also calls this with an updater function, so it presumably is a React state setter — confirm. */
setSeoAnalysis: (analysis: any) => void;
/** Not called anywhere inside the hook body in this file. */
setSeoMetadata: (metadata: any) => void;
/** Replaces the section map after recommendations are applied. */
setSections: (sections: Record<string, string>) => void;
/** Updates the title when the recommendations response returns a different one. */
setSelectedTitle: (title: string | null) => void;
/** Receives `Date.now()` after recommendations are applied. */
setContinuityRefresh: (timestamp: number) => void;
/** Reset to false after recommendations are applied. */
setFlowAnalysisCompleted: (completed: boolean) => void;
/** Reset to null after recommendations are applied. */
setFlowAnalysisResults: (results: any) => void;
}
/**
 * Owns the SEO-phase UI state (analysis modal, metadata modal, "recommendations
 * applied" flag) and exposes the handlers that open the analysis, apply
 * recommendations, and close the modal.
 *
 * Returns the state values, their setters, the last-open timestamp ref, and
 * the handlers listed in the final return statement.
 */
export const useSEOManager = ({
sections,
research,
outline,
selectedTitle,
contentConfirmed,
seoAnalysis,
currentPhase,
navigateToPhase,
setContentConfirmed,
setSeoAnalysis,
setSeoMetadata,
setSections,
setSelectedTitle,
setContinuityRefresh,
setFlowAnalysisCompleted,
setFlowAnalysisResults,
}: UseSEOManagerProps) => {
const [isSEOAnalysisModalOpen, setIsSEOAnalysisModalOpen] = useState(false);
const [isSEOMetadataModalOpen, setIsSEOMetadataModalOpen] = useState(false);
const [seoRecommendationsApplied, setSeoRecommendationsApplied] = useState(false);
// Timestamp (ms) of the last time runSEOAnalysisDirect opened the modal; used for the re-open guard below.
const lastSEOModalOpenRef = useRef<number>(0);
// Helper: run same checks as analyzeSEO and open modal
// Returns a user-facing status message in every case (including the precondition failures).
const runSEOAnalysisDirect = useCallback((): string => {
const hasSections = !!sections && Object.keys(sections).length > 0;
const hasResearch = !!research && !!(research as any).keyword_analysis;
if (!hasSections) return "No blog content available for SEO analysis. Please generate content first.";
if (!hasResearch) return "Research data is required for SEO analysis. Please run research first.";
// Prevent rapid re-opens
// Only triggers while the modal is open and was opened less than 1000 ms ago.
const now = Date.now();
if (isSEOAnalysisModalOpen && now - lastSEOModalOpenRef.current < 1000) {
return "SEO analysis is already open.";
}
// Mark content phase as done when user clicks "Next: Run SEO Analysis"
if (!contentConfirmed) {
setContentConfirmed(true);
debug.log('[BlogWriter] Content phase marked as done (SEO analysis triggered)');
}
// A new analysis run clears the "applied" flag from any previous run.
setSeoRecommendationsApplied(false);
if (!isSEOAnalysisModalOpen) {
setIsSEOAnalysisModalOpen(true);
lastSEOModalOpenRef.current = now;
debug.log('[BlogWriter] SEO modal opened (direct)');
}
return "Running SEO analysis of your blog content. This will analyze content structure, keyword optimization, readability, and provide actionable recommendations.";
}, [sections, research, isSEOAnalysisModalOpen, contentConfirmed, setContentConfirmed]);
/**
 * Sends the current outline, sections and research together with the chosen
 * recommendations to the API and replaces local section content with the response.
 *
 * Throws an Error when there is no outline, when the response reports failure,
 * when it carries no sections array, or when no returned section has usable content.
 */
const handleApplySeoRecommendations = useCallback(async (
recommendations: BlogSEOActionableRecommendation[]
) => {
if (!outline || outline.length === 0) {
throw new Error('An outline is required before applying recommendations.');
}
// One payload entry per outline section; sections without content are sent as ''.
const sectionPayload = outline.map((section) => ({
id: section.id,
heading: section.heading,
content: sections[section.id] ?? '',
}));
const response = await blogWriterApi.applySeoRecommendations({
title: selectedTitle || outline[0]?.heading || 'Untitled Blog',
sections: sectionPayload,
outline,
research: (research as any) || {},
recommendations,
});
if (!response.success) {
throw new Error(response.error || 'Failed to apply recommendations.');
}
if (!response.sections || !Array.isArray(response.sections)) {
throw new Error('Recommendation response did not include updated sections.');
}
// Update sections - create new object reference to trigger React re-render
// NOTE(review): only returned sections with both an id and non-empty content are kept, and the
// resulting map replaces the whole sections state below — any section the response omits or
// returns empty is dropped. Confirm this is intended rather than merging with existing sections.
const newSections: Record<string, string> = {};
response.sections.forEach((section) => {
if (section.id && section.content) {
newSections[section.id] = section.content;
}
});
// Validate we have sections before updating
if (Object.keys(newSections).length === 0) {
throw new Error('No valid sections received from SEO recommendations application.');
}
// Validate sections have actual content
const sectionsWithContent = Object.values(newSections).filter(c => c && c.trim().length > 0);
if (sectionsWithContent.length === 0) {
throw new Error('SEO recommendations resulted in empty sections. Please try again.');
}
// Log detailed section info for debugging
const sectionIds = Object.keys(newSections);
const sectionSizes = sectionIds.map(id => ({ id, length: newSections[id]?.length || 0 }));
debug.log('[BlogWriter] Applied SEO recommendations: sections updated', {
sectionCount: sectionIds.length,
sectionsWithContent: sectionsWithContent.length,
sectionIds: sectionIds,
sectionSizes: sectionSizes,
totalContentLength: Object.values(newSections).reduce((sum, c) => sum + (c?.length || 0), 0)
});
// Update sections state
setSections(newSections);
// Force a delay to ensure React processes the state update before proceeding
// This gives React time to re-render with new sections before phase navigation checks
await new Promise(resolve => setTimeout(resolve, 200));
// Flow/continuity analysis was computed for the old content, so it is reset here.
setContinuityRefresh(Date.now());
setFlowAnalysisCompleted(false);
setFlowAnalysisResults(null);
if (response.title && response.title !== selectedTitle) {
setSelectedTitle(response.title);
}
if (response.applied) {
// Updater form: keeps a missing analysis as-is instead of creating one.
setSeoAnalysis((prev: any) => prev ? { ...prev, applied_recommendations: response.applied } : prev);
debug.log('[BlogWriter] SEO analysis state updated with applied recommendations');
}
// Mark recommendations as applied (this will trigger phase navigation check)
// But we'll stay in SEO phase to show updated content
setSeoRecommendationsApplied(true);
debug.log('[BlogWriter] seoRecommendationsApplied set to true');
// Ensure we stay in SEO phase to show updated content
// Force navigation to SEO phase if we're not already there (safeguard)
// NOTE(review): `currentPhase` here is the value captured when this callback was created,
// not necessarily the phase after the 200 ms wait above.
if (currentPhase !== 'seo') {
navigateToPhase('seo');
debug.log('[BlogWriter] Forced navigation to SEO phase after applying recommendations');
} else {
debug.log('[BlogWriter] Already in SEO phase, staying to show updated content');
}
}, [outline, sections, selectedTitle, research, setSections, setSelectedTitle, setContinuityRefresh, setFlowAnalysisCompleted, setFlowAnalysisResults, setSeoAnalysis, currentPhase, navigateToPhase]);
// Handle SEO analysis completion
const handleSEOAnalysisComplete = useCallback((analysis: any) => {
setSeoAnalysis(analysis);
debug.log('[BlogWriter] SEO analysis completed', { hasAnalysis: !!analysis });
}, [setSeoAnalysis]);
// Handle SEO modal close - mark SEO phase as done if not already marked
const handleSEOModalClose = useCallback(() => {
// Mark SEO phase as done when modal closes (even without applying recommendations)
if (!seoAnalysis) {
// Set a minimal valid seoAnalysis object to mark phase as complete
// The placeholder is labelled 'Skipped' with a score of 0 and no recommendations.
setSeoAnalysis({
success: true,
overall_score: 0,
category_scores: {},
analysis_summary: {
overall_grade: 'N/A',
status: 'Skipped',
strongest_category: 'N/A',
weakest_category: 'N/A',
key_strengths: [],
key_weaknesses: [],
ai_summary: 'SEO analysis was skipped by user'
},
actionable_recommendations: [],
generated_at: new Date().toISOString()
});
debug.log('[BlogWriter] SEO phase marked as done (modal closed without analysis)');
}
setIsSEOAnalysisModalOpen(false);
debug.log('[BlogWriter] SEO modal closed');
}, [seoAnalysis, setSeoAnalysis]);
// Mark SEO phase as completed when recommendations are applied
useEffect(() => {
if (seoRecommendationsApplied && seoAnalysis) {
// SEO phase is considered complete when recommendations are applied
// But stay in SEO phase to show updated content
debug.log('[BlogWriter] SEO recommendations applied, SEO phase marked as complete');
// Ensure we stay in SEO phase to show updated content (override auto-progression)
if (currentPhase !== 'seo' && Object.keys(sections).length > 0) {
navigateToPhase('seo');
debug.log('[BlogWriter] Forced stay in SEO phase to show updated content');
}
}
}, [seoRecommendationsApplied, seoAnalysis, currentPhase, sections, navigateToPhase]);
// Confirms the content, moves to the SEO phase and opens the analysis modal on the next
// timer tick. Unlike runSEOAnalysisDirect, it performs no section/research precondition
// checks and does not update lastSEOModalOpenRef.
const confirmBlogContent = useCallback(() => {
debug.log('[BlogWriter] Blog content confirmed by user');
setContentConfirmed(true);
setSeoRecommendationsApplied(false);
navigateToPhase('seo');
setTimeout(() => {
setIsSEOAnalysisModalOpen(true);
debug.log('[BlogWriter] SEO modal opened (confirm→direct)');
}, 0);
return "✅ Blog content has been confirmed! Running SEO analysis now.";
}, [setContentConfirmed, navigateToPhase]);
return {
isSEOAnalysisModalOpen,
setIsSEOAnalysisModalOpen,
isSEOMetadataModalOpen,
setIsSEOMetadataModalOpen,
seoRecommendationsApplied,
setSeoRecommendationsApplied,
lastSEOModalOpenRef,
runSEOAnalysisDirect,
handleApplySeoRecommendations,
handleSEOAnalysisComplete,
handleSEOModalClose,
confirmBlogContent,
};
};
/** Shape of the object returned by useSEOManager, derived from the hook so the two cannot drift apart. */
export type SEOManagerReturn = ReturnType<typeof useSEOManager>;

View File

@@ -0,0 +1,113 @@
import React, { useState } from 'react';
import { Button, CircularProgress } from '@mui/material';
import { mediumBlogApi } from '../../services/blogWriterApi';
import { BlogOutlineSection, BlogResearchResponse } from '../../services/blogWriterApi';
/** Props for ManualContentButton. */
interface ManualContentButtonProps {
/**
 * The confirmed outline sections. The button is disabled while this is empty;
 * the first section's heading is the title fallback.
 */
outline: BlogOutlineSection[];
/**
 * The research data. The button is disabled while this is falsy.
 */
research: BlogResearchResponse;
/**
 * Blog title (optional). Falls back to the first outline heading, then 'Blog Post'.
 */
blogTitle?: string;
/**
 * Existing sections content (optional). Sent as `existing_sections` in the request payload.
 */
sections?: Record<string, string>;
/**
 * Callback invoked with the task id returned by the API once generation has started.
 */
onGenerationStart?: (taskId: string) => void;
}
/**
 * Manual content generation button that works independently of CopilotKit.
 * Triggers medium blog generation via mediumBlogApi and hands the resulting
 * task id to `onGenerationStart` so the parent can poll for progress.
 *
 * Failures are shown both through `alert` and as inline text under the button.
 */
export const ManualContentButton: React.FC<ManualContentButtonProps> = ({
  outline,
  research,
  blogTitle,
  sections,
  onGenerationStart,
}) => {
  const [isGenerating, setIsGenerating] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const handleGenerate = async () => {
    if (!outline || outline.length === 0) {
      alert('Please confirm an outline first before generating content.');
      return;
    }
    if (!research) {
      alert('Research data is required for content generation.');
      return;
    }

    setIsGenerating(true);
    setError(null);
    try {
      // NOTE(review): the other startMediumGeneration caller in this change set
      // (usePhaseActionHandlers) sends { title, sections, globalTargetWords,
      // researchKeywords }; this payload uses a different shape and the `as any`
      // cast hides any mismatch. Confirm which shape the API expects.
      const payload = {
        outline,
        research,
        title: blogTitle || outline[0]?.heading || 'Blog Post',
        existing_sections: sections || {},
      };
      const { task_id } = await mediumBlogApi.startMediumGeneration(payload as any);
      if (!task_id) {
        throw new Error('Failed to start content generation - no task ID returned');
      }
      onGenerationStart?.(task_id);
    } catch (err) {
      // Named `err` so it does not shadow the `error` state value.
      const errorMessage = err instanceof Error ? err.message : 'Unknown error';
      setError(errorMessage);
      alert(`Content generation failed: ${errorMessage}`);
    } finally {
      // Reset on success as well as failure. Previously the flag was only
      // cleared on failure, leaving the button disabled with a spinner for as
      // long as the component stayed mounted after a successful start.
      setIsGenerating(false);
    }
  };

  return (
    <div style={{ padding: '20px', textAlign: 'center' }}>
      <h3 style={{ margin: '0 0 16px 0', color: '#333' }}>Generate Blog Content</h3>
      <p style={{ margin: '0 0 20px 0', color: '#666', fontSize: '14px' }}>
        Generate full content for all sections in your confirmed outline.
      </p>
      <Button
        variant="contained"
        color="primary"
        size="large"
        onClick={handleGenerate}
        disabled={!outline || outline.length === 0 || !research || isGenerating}
        startIcon={isGenerating ? <CircularProgress size={20} color="inherit" /> : null}
        sx={{
          minWidth: 200,
          py: 1.5,
          px: 4,
        }}
      >
        {isGenerating ? 'Generating Content...' : '📝 Generate Content'}
      </Button>
      {error && (
        <p style={{ margin: '12px 0 0 0', color: '#d32f2f', fontSize: '14px' }}>
          {error}
        </p>
      )}
    </div>
  );
};
export default ManualContentButton;

View File

@@ -0,0 +1,111 @@
import React, { useState } from 'react';
import { Button, CircularProgress } from '@mui/material';
/** Props for ManualOutlineButton. */
interface ManualOutlineButtonProps {
/**
 * Ref to OutlineGenerator component with generateNow() method.
 * The resolved object reports `success`; on failure `message` is shown to the
 * user, on success either `cached` + `outline` (cache hit) or `task_id`
 * (new generation task) is inspected by the button.
 */
outlineGenRef: React.RefObject<{
generateNow: () => Promise<{
success: boolean;
message?: string;
task_id?: string;
cached?: boolean;
outline?: any[];
title_options?: string[];
}>
}>;
/**
 * Whether research is available (required for outline generation).
 * The button is disabled while this is false.
 */
hasResearch: boolean;
/**
 * Callback invoked with the task id when a new (non-cached) outline generation starts.
 */
onGenerationStart?: (taskId: string) => void;
}
/**
 * Outline generation button usable without CopilotKit.
 *
 * Delegates to `generateNow()` on the OutlineGenerator ref. A new task id is
 * forwarded to `onGenerationStart`; a cache hit is only logged here because
 * applying the cached outline is left to the parent.
 */
export const ManualOutlineButton: React.FC<ManualOutlineButtonProps> = ({
  outlineGenRef,
  hasResearch,
  onGenerationStart,
}) => {
  const [isGenerating, setIsGenerating] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const handleGenerate = async () => {
    if (!hasResearch) {
      alert('Please complete research first before generating an outline.');
      return;
    }
    const generator = outlineGenRef.current;
    if (!generator) {
      alert('Outline generator is not available. Please refresh the page.');
      return;
    }

    setIsGenerating(true);
    setError(null);
    try {
      const outcome = await generator.generateNow();

      if (!outcome.success) {
        setError(outcome.message || 'Failed to generate outline');
        alert(outcome.message || 'Failed to generate outline. Please try again.');
        return;
      }

      if (outcome.cached && outcome.outline) {
        // Cache hit: the outline already exists, so there is nothing to poll.
        console.log('[ManualOutlineButton] Cached outline used', { sections: outcome.outline.length });
      } else if (outcome.task_id) {
        onGenerationStart?.(outcome.task_id);
      }
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : 'Unknown error';
      setError(errorMessage);
      alert(`Outline generation failed: ${errorMessage}`);
    } finally {
      setIsGenerating(false);
    }
  };

  const buttonLabel = isGenerating ? 'Generating Outline...' : '🧩 Generate Outline';
  const spinner = isGenerating ? <CircularProgress size={20} color="inherit" /> : null;
  const isDisabled = !hasResearch || isGenerating;

  return (
    <div style={{ padding: '20px', textAlign: 'center' }}>
      <h3 style={{ margin: '0 0 16px 0', color: '#333' }}>Create Your Outline</h3>
      <p style={{ margin: '0 0 20px 0', color: '#666', fontSize: '14px' }}>
        Generate an AI-powered outline based on your research.
      </p>
      <Button
        variant="contained"
        color="primary"
        size="large"
        onClick={handleGenerate}
        disabled={isDisabled}
        startIcon={spinner}
        sx={{
          minWidth: 200,
          py: 1.5,
          px: 4,
        }}
      >
        {buttonLabel}
      </Button>
      {error && (
        <p style={{ margin: '12px 0 0 0', color: '#d32f2f', fontSize: '14px' }}>
          {error}
        </p>
      )}
    </div>
  );
};
export default ManualOutlineButton;

View File

@@ -0,0 +1,184 @@
import React, { useState, useRef } from 'react';
import { blogWriterApi, BlogResearchRequest, BlogResearchResponse } from '../../services/blogWriterApi';
import { useResearchPolling } from '../../hooks/usePolling';
import ResearchProgressModal from './ResearchProgressModal';
import { researchCache } from '../../services/researchCache';
/** Props for ManualResearchForm. */
interface ManualResearchFormProps {
/** Optional; called with the research result, either from the cache or once polling completes. */
onResearchComplete?: (research: BlogResearchResponse) => void;
}
/**
 * Manual research form component that works independently of CopilotKit.
 * Extracted from ResearchAction.tsx for use when CopilotKit is unavailable.
 *
 * Flow: read keywords and blog length from the uncontrolled inputs, return a
 * cached result if one exists, otherwise start a research task and poll it
 * while showing ResearchProgressModal. The result is passed to
 * `onResearchComplete`.
 */
export const ManualResearchForm: React.FC<ManualResearchFormProps> = ({ onResearchComplete }) => {
  // Only the setters are used; the values are never read in this component.
  const [, setCurrentTaskId] = useState<string | null>(null);
  const [, setCurrentMessage] = useState<string>('');
  const [showProgressModal, setShowProgressModal] = useState<boolean>(false);
  const [isSubmitting, setIsSubmitting] = useState(false);

  // Refs for form inputs (uncontrolled, avoids typing issues)
  const keywordsRef = useRef<HTMLInputElement | null>(null);
  const blogLengthRef = useRef<HTMLSelectElement | null>(null);

  /** Clears task tracking, hides the progress modal and re-enables the form. */
  const resetProgressState = () => {
    setCurrentTaskId(null);
    setCurrentMessage('');
    setShowProgressModal(false);
    setIsSubmitting(false);
  };

  const polling = useResearchPolling({
    onProgress: (message) => {
      setCurrentMessage(message);
    },
    onComplete: (result) => {
      if (result && result.keywords) {
        researchCache.cacheResult(
          result.keywords,
          result.industry || 'General',
          result.target_audience || 'General',
          result
        );
      }
      onResearchComplete?.(result);
      resetProgressState();
    },
    onError: (error) => {
      // NOTE(review): the modal is closed here, so the polling error is only
      // logged unless useResearchPolling surfaces it elsewhere — confirm.
      console.error('Research polling error:', error);
      resetProgressState();
    }
  });

  const handleSubmit = async () => {
    const keywords = (keywordsRef.current?.value || '').trim();
    const blogLength = blogLengthRef.current?.value || '1000';

    // Split before validating: input made only of commas/whitespace (e.g. ",")
    // is non-empty as a string but yields no keywords and must be rejected.
    const keywordList = keywords.split(',').map(k => k.trim()).filter(Boolean);
    if (keywordList.length === 0) {
      alert('Please enter keywords or a topic for research.');
      return;
    }

    setIsSubmitting(true);
    try {
      // Check cache first
      const cachedResult = researchCache.getCachedResult(keywordList, 'General', 'General');
      if (cachedResult) {
        onResearchComplete?.(cachedResult);
        setIsSubmitting(false);
        return;
      }

      const payload: BlogResearchRequest = {
        keywords: keywordList,
        industry: 'General',
        target_audience: 'General',
        // Explicit radix so the value is always parsed as decimal.
        word_count_target: parseInt(blogLength, 10)
      };
      const { task_id } = await blogWriterApi.startResearch(payload);
      setCurrentTaskId(task_id);
      setShowProgressModal(true);
      polling.startPolling(task_id);
    } catch (error) {
      console.error('Research failed:', error);
      alert(`Research failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
      setIsSubmitting(false);
    }
  };

  return (
    <>
      <div style={{ padding: '20px', backgroundColor: '#f8f9fa', borderRadius: '12px', border: '1px solid #e0e0e0', margin: '8px 0' }}>
        <h4 style={{ margin: '0 0 16px 0', color: '#333' }}>🔍 Let's Research Your Blog Topic</h4>
        <p style={{ margin: '0 0 16px 0', color: '#666', fontSize: '14px' }}>
          What keywords and information would you like to use for your research? Please also specify the desired length of the blog post.
        </p>
        <div style={{ marginBottom: '16px' }}>
          {/* htmlFor ties the label to the input for screen readers and click-to-focus. */}
          <label htmlFor="research-keywords-input" style={{ display: 'block', marginBottom: '8px', fontWeight: '500', color: '#333' }}>Keywords or Topic *</label>
          <input
            type="text"
            id="research-keywords-input"
            placeholder="e.g., artificial intelligence, machine learning, AI trends"
            ref={keywordsRef}
            disabled={isSubmitting}
            style={{
              width: '100%',
              padding: '12px',
              border: '1px solid #ddd',
              borderRadius: '6px',
              fontSize: '14px',
              boxSizing: 'border-box',
              opacity: isSubmitting ? 0.6 : 1
            }}
          />
        </div>
        <div style={{ marginBottom: '16px' }}>
          <label htmlFor="research-blog-length-select" style={{ display: 'block', marginBottom: '8px', fontWeight: '500', color: '#333' }}>Blog Length (words)</label>
          <select
            id="research-blog-length-select"
            defaultValue="1000"
            ref={blogLengthRef}
            disabled={isSubmitting}
            style={{
              width: '100%',
              padding: '12px',
              border: '1px solid #ddd',
              borderRadius: '6px',
              fontSize: '14px',
              boxSizing: 'border-box',
              opacity: isSubmitting ? 0.6 : 1
            }}
          >
            <option value="500">500 words (Short blog)</option>
            <option value="1000">1000 words (Medium blog)</option>
            <option value="1500">1500 words (Long blog)</option>
            <option value="2000">2000 words (Comprehensive blog)</option>
          </select>
        </div>
        <div style={{ display: 'flex', gap: '12px', justifyContent: 'flex-end' }}>
          <button
            type="button"
            onClick={handleSubmit}
            disabled={isSubmitting}
            style={{
              padding: '12px 24px',
              backgroundColor: isSubmitting ? '#ccc' : '#1976d2',
              color: 'white',
              border: 'none',
              borderRadius: '6px',
              fontSize: '14px',
              fontWeight: '500',
              cursor: isSubmitting ? 'not-allowed' : 'pointer',
              opacity: isSubmitting ? 0.7 : 1
            }}
          >
            {isSubmitting ? ' Starting Research...' : '🚀 Start Research'}
          </button>
        </div>
      </div>
      {showProgressModal && (
        <ResearchProgressModal
          open={showProgressModal}
          title="Research in progress"
          status={polling.currentStatus}
          messages={polling.progressMessages}
          error={polling.error}
          onClose={() => setShowProgressModal(false)}
        />
      )}
    </>
  );
};
export default ManualResearchForm;

View File

@@ -363,6 +363,21 @@ export const OutlineFeedbackForm: React.FC<OutlineFeedbackFormProps> = ({
); );
if (target && target <= 1000) { if (target && target <= 1000) {
// Check cache first (shared utility)
const { blogWriterCache } = await import('../../services/blogWriterCache');
const outlineIds = outline.map(s => String(s.id));
const cachedContent = blogWriterCache.getCachedContent(outlineIds);
if (cachedContent) {
console.log('[OutlineFeedbackForm] Using cached content', { sections: Object.keys(cachedContent).length });
// Content is already cached, skip API call
return {
success: true,
message: 'Content is already available from cache.',
cached: true
};
}
// Show modal immediately when medium generation is triggered // Show modal immediately when medium generation is triggered
onMediumGenerationTriggered?.(); onMediumGenerationTriggered?.();
// Build payload for medium generation // Build payload for medium generation
@@ -386,13 +401,61 @@ export const OutlineFeedbackForm: React.FC<OutlineFeedbackFormProps> = ({
// Notify parent to start polling for the medium generation task // Notify parent to start polling for the medium generation task
onMediumGenerationStarted?.(task_id); onMediumGenerationStarted?.(task_id);
// Return message so the copilot shows feedback; UI will display modal via BlogWriter // Poll once immediately to check for immediate failures (e.g., subscription errors)
return { try {
success: true, const initialStatus = await mediumBlogApi.pollMediumGeneration(task_id);
message: `✅ Outline confirmed. Medium generation started (Task: ${task_id}). You can monitor progress in the modal.`,
task_id, // Check if task already failed with subscription error
action_taken: 'outline_confirmed_medium_generation_started' if (initialStatus.status === 'failed' && (initialStatus.error_status === 429 || initialStatus.error_status === 402)) {
}; const errorData = initialStatus.error_data || {};
const errorMessage = errorData.message || errorData.error || initialStatus.error || 'Subscription limit exceeded';
// Return error to CopilotKit so it shows in chat
return {
success: false,
message: `❌ Medium generation failed: ${errorMessage}`,
error: errorMessage,
error_type: 'subscription_limit',
provider: errorData.provider || 'unknown',
suggestion: 'Please renew your subscription to continue generating content.',
action_taken: 'outline_confirmed_medium_generation_failed'
};
}
// Task started successfully, continue polling in background
return {
success: true,
message: `✅ Outline confirmed. Medium generation started (Task: ${task_id}). You can monitor progress in the modal.`,
task_id,
action_taken: 'outline_confirmed_medium_generation_started'
};
} catch (pollError: any) {
// Check if polling error is a subscription error (HTTP 429/402)
if (pollError?.response?.status === 429 || pollError?.response?.status === 402) {
const errorData = pollError.response?.data || {};
const errorMessage = errorData.message || errorData.error || 'Subscription limit exceeded';
return {
success: false,
message: `❌ Medium generation failed: ${errorMessage}`,
error: errorMessage,
error_type: 'subscription_limit',
provider: errorData.provider || 'unknown',
suggestion: 'Please renew your subscription to continue generating content.',
action_taken: 'outline_confirmed_medium_generation_failed'
};
}
// Other polling errors - still return success since task was started
// The polling will handle the error in the background
console.warn('Initial poll check failed, but task was started:', pollError);
return {
success: true,
message: `✅ Outline confirmed. Medium generation started (Task: ${task_id}). You can monitor progress in the modal.`,
task_id,
action_taken: 'outline_confirmed_medium_generation_started'
};
}
} }
return { return {

View File

@@ -1,6 +1,7 @@
import React, { forwardRef, useImperativeHandle } from 'react'; import React, { forwardRef, useImperativeHandle } from 'react';
import { useCopilotAction } from '@copilotkit/react-core'; import { useCopilotAction } from '@copilotkit/react-core';
import { blogWriterApi, BlogResearchResponse } from '../../services/blogWriterApi'; import { blogWriterApi, BlogResearchResponse } from '../../services/blogWriterApi';
import { blogWriterCache } from '../../services/blogWriterCache';
interface OutlineGeneratorProps { interface OutlineGeneratorProps {
research: BlogResearchResponse | null; research: BlogResearchResponse | null;
@@ -23,6 +24,22 @@ export const OutlineGenerator = forwardRef<any, OutlineGeneratorProps>(({
if (!research) { if (!research) {
return { success: false, message: 'No research yet. Please research a topic first.' }; return { success: false, message: 'No research yet. Please research a topic first.' };
} }
// Check cache first (shared utility)
const researchKeywords = research.original_keywords || research.keyword_analysis?.primary || [];
const cachedOutline = blogWriterCache.getCachedOutline(researchKeywords);
if (cachedOutline) {
console.log('[OutlineGenerator] Using cached outline', { sections: cachedOutline.outline.length });
// Return cached result - caller should handle setting outline state
return {
success: true,
cached: true,
outline: cachedOutline.outline,
title_options: cachedOutline.title_options
};
}
try { try {
onModalShow?.(); onModalShow?.();
const { task_id } = await blogWriterApi.startOutlineGeneration({ research }); const { task_id } = await blogWriterApi.startOutlineGeneration({ research });
@@ -44,6 +61,21 @@ export const OutlineGenerator = forwardRef<any, OutlineGeneratorProps>(({
return { success: false, message: 'No research yet. Please research a topic first.' }; return { success: false, message: 'No research yet. Please research a topic first.' };
} }
// Check cache first (shared utility)
const researchKeywords = research.original_keywords || research.keyword_analysis?.primary || [];
const cachedOutline = blogWriterCache.getCachedOutline(researchKeywords);
if (cachedOutline) {
console.log('[OutlineGenerator] Using cached outline from CopilotKit action', { sections: cachedOutline.outline.length });
return {
success: true,
message: `✅ Outline already exists! ${cachedOutline.outline.length} sections loaded from cache.`,
cached: true,
outline: cachedOutline.outline,
title_options: cachedOutline.title_options
};
}
try { try {
// Show progress modal immediately when user clicks "Create outline" // Show progress modal immediately when user clicks "Create outline"
onModalShow?.(); onModalShow?.();

View File

@@ -10,17 +10,80 @@ export interface Phase {
disabled: boolean; disabled: boolean;
} }
/**
 * Optional per-phase callbacks that PhaseNavigation can invoke from its
 * fallback action buttons. Every handler is optional; a missing handler
 * means no action is offered for that phase.
 */
export interface PhaseActionHandlers {
onResearchAction?: () => void; // Show research form
onOutlineAction?: () => void; // Generate outline
onContentAction?: () => void; // Confirm outline + generate content
onSEOAction?: () => void; // Run SEO analysis
onPublishAction?: () => void; // Generate SEO metadata or publish
}
interface PhaseNavigationProps { interface PhaseNavigationProps {
phases: Phase[]; phases: Phase[];
onPhaseClick: (phaseId: string) => void; onPhaseClick: (phaseId: string) => void;
currentPhase: string; currentPhase: string;
copilotKitAvailable?: boolean;
actionHandlers?: PhaseActionHandlers;
// State for determining which actions to show
hasResearch?: boolean;
hasOutline?: boolean;
outlineConfirmed?: boolean;
hasContent?: boolean;
contentConfirmed?: boolean;
hasSEOAnalysis?: boolean;
hasSEOMetadata?: boolean;
} }
export const PhaseNavigation: React.FC<PhaseNavigationProps> = ({ export const PhaseNavigation: React.FC<PhaseNavigationProps> = ({
phases, phases,
onPhaseClick, onPhaseClick,
currentPhase currentPhase,
copilotKitAvailable = true,
actionHandlers,
hasResearch = false,
hasOutline = false,
outlineConfirmed = false,
hasContent = false,
contentConfirmed = false,
hasSEOAnalysis = false,
hasSEOMetadata = false,
}) => { }) => {
// Resolve the fallback action (button label + callback) for a phase.
// Only relevant when CopilotKit is unavailable and handlers were supplied;
// in every other situation an empty label with a null handler is returned.
const getActionForPhase = (phaseId: string): { label: string; handler: (() => void) | null } => {
  const noAction: { label: string; handler: (() => void) | null } = { label: '', handler: null };

  // Chat is available (or no handlers were passed): nothing to offer.
  if (copilotKitAvailable || !actionHandlers) {
    return noAction;
  }

  // Research has not been run yet.
  if (phaseId === 'research' && !hasResearch) {
    return { label: 'Start Research', handler: actionHandlers.onResearchAction || null };
  }

  // Research exists but no outline has been generated.
  if (phaseId === 'outline' && hasResearch && !hasOutline) {
    return { label: 'Create Outline', handler: actionHandlers.onOutlineAction || null };
  }

  // Outline exists but has not been confirmed.
  if (phaseId === 'content' && hasOutline && !outlineConfirmed) {
    return { label: 'Confirm & Generate Content', handler: actionHandlers.onContentAction || null };
  }

  // Content is confirmed but SEO analysis has not been run.
  if (phaseId === 'seo' && hasContent && contentConfirmed && !hasSEOAnalysis) {
    return { label: 'Run SEO Analysis', handler: actionHandlers.onSEOAction || null };
  }

  // SEO analysis is done but metadata has not been generated.
  if (phaseId === 'publish' && hasSEOAnalysis && !hasSEOMetadata) {
    return { label: 'Generate SEO Metadata', handler: actionHandlers.onPublishAction || null };
  }

  return noAction;
};
return ( return (
<div style={{ <div style={{
display: 'flex', display: 'flex',
@@ -33,53 +96,103 @@ export const PhaseNavigation: React.FC<PhaseNavigationProps> = ({
const isCurrent = phase.current; const isCurrent = phase.current;
const isCompleted = phase.completed; const isCompleted = phase.completed;
const isDisabled = phase.disabled; const isDisabled = phase.disabled;
const action = getActionForPhase(phase.id);
// Show action button when:
// 1. CopilotKit is unavailable
// 2. Action handler exists
// 3. Phase is not disabled
// 4. Show for current phase OR next actionable phase (not completed)
// For research phase specifically, always show if no research exists
const isResearchPhase = phase.id === 'research' && !hasResearch;
const showAction = !copilotKitAvailable && action.handler && !isDisabled && (
isCurrent ||
(!isCompleted && !isDisabled) ||
isResearchPhase
);
return ( return (
<button <div key={phase.id} style={{ display: 'flex', alignItems: 'center', gap: '6px' }}>
key={phase.id} <button
onClick={() => !isDisabled && onPhaseClick(phase.id)} onClick={() => !isDisabled && onPhaseClick(phase.id)}
disabled={isDisabled} disabled={isDisabled}
style={{ style={{
display: 'flex', display: 'flex',
alignItems: 'center', alignItems: 'center',
gap: '6px', gap: '6px',
padding: '8px 12px', padding: '8px 12px',
borderRadius: '20px', borderRadius: '20px',
border: 'none', border: 'none',
fontSize: '14px', fontSize: '14px',
fontWeight: '500', fontWeight: '500',
cursor: isDisabled ? 'not-allowed' : 'pointer', cursor: isDisabled ? 'not-allowed' : 'pointer',
transition: 'all 0.2s ease', transition: 'all 0.2s ease',
backgroundColor: isCurrent backgroundColor: isCurrent
? '#1976d2' ? '#1976d2'
: isCompleted : isCompleted
? '#4caf50' ? '#4caf50'
: isDisabled : isDisabled
? '#f5f5f5' ? '#f5f5f5'
: '#e3f2fd', : '#e3f2fd',
color: isCurrent color: isCurrent
? 'white'
: isCompleted
? 'white' ? 'white'
: isDisabled : isCompleted
? '#999' ? 'white'
: '#1976d2', : isDisabled
opacity: isDisabled ? 0.6 : 1, ? '#999'
boxShadow: isCurrent ? '0 2px 4px rgba(25, 118, 210, 0.3)' : 'none', : '#1976d2',
transform: isCurrent ? 'translateY(-1px)' : 'none' opacity: isDisabled ? 0.6 : 1,
}} boxShadow: isCurrent ? '0 2px 4px rgba(25, 118, 210, 0.3)' : 'none',
title={phase.disabled ? `Complete ${phase.name} first` : phase.description} transform: isCurrent ? 'translateY(-1px)' : 'none'
> }}
<span style={{ fontSize: '16px' }}> title={phase.disabled ? `Complete ${phase.name} first` : phase.description}
{phase.icon} >
</span> <span style={{ fontSize: '16px' }}>
<span>{phase.name}</span> {phase.icon}
{isCompleted && !isCurrent && (
<span style={{ fontSize: '12px', marginLeft: '4px' }}>
</span> </span>
<span>{phase.name}</span>
{isCompleted && !isCurrent && (
<span style={{ fontSize: '12px', marginLeft: '4px' }}>
</span>
)}
</button>
{showAction && (
<button
onClick={(e) => {
e.stopPropagation();
action.handler?.();
}}
style={{
display: 'flex',
alignItems: 'center',
gap: '4px',
padding: '6px 12px',
borderRadius: '16px',
border: '1px solid #1976d2',
fontSize: '12px',
fontWeight: '600',
cursor: 'pointer',
backgroundColor: '#1976d2',
color: 'white',
transition: 'all 0.2s ease',
boxShadow: '0 2px 4px rgba(25, 118, 210, 0.2)',
}}
onMouseEnter={(e) => {
e.currentTarget.style.backgroundColor = '#1565c0';
e.currentTarget.style.transform = 'translateY(-1px)';
}}
onMouseLeave={(e) => {
e.currentTarget.style.backgroundColor = '#1976d2';
e.currentTarget.style.transform = 'none';
}}
title={`${action.label} (Chat unavailable - click to proceed)`}
>
<span style={{ fontSize: '12px' }}></span>
<span>{action.label}</span>
</button>
)} )}
</button> </div>
); );
})} })}
</div> </div>

View File

@@ -105,19 +105,27 @@ export const Publisher: React.FC<PublisherProps> = ({
try { try {
// Publish using same endpoint as WixTestPage // Publish using same endpoint as WixTestPage
// Note: Wix requires category/tag IDs (UUIDs), not names // Backend will lookup/create category and tag IDs from names if needed
// For now, skip categories/tags until we implement ID lookup/creation
const response = await apiClient.post('/api/wix/test/publish/real', { const response = await apiClient.post('/api/wix/test/publish/real', {
title: title, title: title,
content: md, // Use markdown, backend converts it content: md, // Use markdown, backend converts it
cover_image_url: coverImageUrl, cover_image_url: coverImageUrl,
// TODO: Lookup/create category IDs from metadata?.blog_categories // Pass category/tag names - backend will lookup existing or create new ones
// TODO: Lookup/create tag IDs from metadata?.blog_tags category_names: metadata?.blog_categories || [],
category_ids: undefined, tag_names: metadata?.blog_tags || [],
tag_ids: undefined,
publish: true, publish: true,
access_token: accessToken, access_token: accessToken,
member_id: undefined // Let backend derive from token member_id: undefined, // Let backend derive from token
seo_metadata: metadata ? {
seo_title: metadata.seo_title,
meta_description: metadata.meta_description,
focus_keyword: metadata.focus_keyword,
blog_tags: metadata.blog_tags || [], // Used for SEO keywords
social_hashtags: metadata.social_hashtags || [],
open_graph: metadata.open_graph || {},
twitter_card: metadata.twitter_card || {},
canonical_url: metadata.canonical_url
} : undefined
}); });
if (response.data.success) { if (response.data.success) {

View File

@@ -1,4 +1,4 @@
import React, { useEffect, useState } from 'react'; import React, { useState } from 'react';
import { BlogResearchResponse } from '../../../services/blogWriterApi'; import { BlogResearchResponse } from '../../../services/blogWriterApi';
interface ResearchSourcesProps { interface ResearchSourcesProps {
@@ -187,24 +187,6 @@ const KeywordChipGroup: React.FC<KeywordChipGroupProps> = ({
}; };
export const ResearchSources: React.FC<ResearchSourcesProps> = ({ research }) => { export const ResearchSources: React.FC<ResearchSourcesProps> = ({ research }) => {
const [showWebSearchHelp, setShowWebSearchHelp] = useState(false);
// Fix search widget overflow after render
useEffect(() => {
if (research.search_widget) {
const searchWidget = document.querySelector('[data-search-widget]');
if (searchWidget) {
const allElements = searchWidget.querySelectorAll('*');
allElements.forEach((el: any) => {
el.style.maxWidth = '100%';
el.style.overflow = 'hidden';
el.style.wordWrap = 'break-word';
el.style.whiteSpace = 'normal';
el.style.boxSizing = 'border-box';
});
}
}
}, [research.search_widget]);
const renderCredibilityScore = (score: number | undefined) => { const renderCredibilityScore = (score: number | undefined) => {
const safeScore = score ?? 0.8; // Default to 0.8 if undefined const safeScore = score ?? 0.8; // Default to 0.8 if undefined
@@ -454,135 +436,17 @@ export const ResearchSources: React.FC<ResearchSourcesProps> = ({ research }) =>
</div> </div>
)} )}
{/* Interactive Web Search - Moved from Header */} {/* Google Search Suggestions - Per Google Display Requirements */}
{research.search_widget && ( {research.search_widget && (
<div style={{ marginBottom: '20px', width: '100%', overflow: 'hidden' }}> <div style={{
<div style={{ display: 'flex', alignItems: 'center', gap: '8px', marginBottom: '12px', position: 'relative' }}> marginBottom: '24px',
<h4 style={{ margin: 0, color: '#555', fontSize: '16px' }}> width: '100%',
🔍 Explore More Research Topics position: 'relative'
</h4> }}>
{/* Help Icon for Web Search */} {/* Google Search Widget - Display exactly as provided without modifications */}
<span <div
onClick={() => setShowWebSearchHelp(!showWebSearchHelp)} dangerouslySetInnerHTML={{ __html: research.search_widget }}
style={{ />
fontSize: '14px',
color: '#9ca3af',
cursor: 'pointer',
padding: '4px',
borderRadius: '50%',
transition: 'all 0.2s ease',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
minWidth: '24px',
minHeight: '24px'
}}
onMouseEnter={(e) => {
e.currentTarget.style.color = '#6b7280';
e.currentTarget.style.backgroundColor = '#f3f4f6';
}}
onMouseLeave={(e) => {
e.currentTarget.style.color = '#9ca3af';
e.currentTarget.style.backgroundColor = 'transparent';
}}
>
</span>
{/* Help Tooltip for Web Search */}
{showWebSearchHelp && (
<div style={{
position: 'absolute',
top: '100%',
left: '0',
marginTop: '8px',
backgroundColor: '#1f2937',
color: '#f9fafb',
padding: '12px 16px',
borderRadius: '8px',
fontSize: '12px',
lineHeight: '1.5',
maxWidth: '300px',
boxShadow: '0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05)',
zIndex: 1000,
border: '1px solid #374151'
}}>
<div style={{ fontWeight: '600', marginBottom: '4px', color: '#f3f4f6' }}>
Research Enhancement
</div>
<div style={{ color: '#d1d5db' }}>
Click on any search suggestion below to explore additional research topics and gather more insights for your blog. These searches will open in a new tab to help you discover trending topics, expert opinions, and current statistics.
</div>
{/* Tooltip arrow */}
<div style={{
position: 'absolute',
bottom: '100%',
left: '20px',
width: 0,
height: 0,
borderLeft: '6px solid transparent',
borderRight: '6px solid transparent',
borderBottom: '6px solid #1f2937'
}} />
</div>
)}
</div>
<div style={{
border: '1px solid #e0e0e0',
borderRadius: '8px',
padding: '16px',
backgroundColor: '#fafafa',
maxHeight: '400px',
overflow: 'auto',
width: '100%',
maxWidth: '100%',
boxSizing: 'border-box',
overflowX: 'hidden',
position: 'relative'
}}
onClick={(e) => {
// Make all links open in new tabs
const target = e.target as HTMLElement;
if (target.tagName === 'A' || target.closest('a')) {
const link = target.tagName === 'A' ? target as HTMLAnchorElement : target.closest('a') as HTMLAnchorElement;
if (link && link.href) {
link.target = '_blank';
link.rel = 'noopener noreferrer';
}
}
}}>
<div
data-search-widget
dangerouslySetInnerHTML={{ __html: research.search_widget }}
style={{
fontSize: '14px',
width: '100%',
maxWidth: '100%',
overflow: 'hidden',
overflowX: 'hidden',
wordWrap: 'break-word',
overflowWrap: 'break-word',
whiteSpace: 'normal',
display: 'block',
position: 'relative'
}}
/>
{/* Custom CSS to make Google icon larger */}
<style>
{`
[data-search-widget] svg {
width: 24px !important;
height: 24px !important;
}
[data-search-widget] .logo-light,
[data-search-widget] .logo-dark {
width: 24px !important;
height: 24px !important;
}
`}
</style>
</div>
</div> </div>
)} )}

View File

@@ -27,7 +27,7 @@ import {
Avatar, Avatar,
CircularProgress CircularProgress
} from '@mui/material'; } from '@mui/material';
import { apiClient } from '../../api/client'; import { apiClient, triggerSubscriptionError } from '../../api/client';
import { import {
CheckCircle, CheckCircle,
Cancel, Cancel,
@@ -308,7 +308,28 @@ export const SEOAnalysisModal: React.FC<SEOAnalysisModalProps> = ({
onAnalysisComplete(convertedResult); onAnalysisComplete(convertedResult);
} }
} catch (err) { } catch (err: any) {
console.error('SEO analysis failed:', err);
// Check if this is a subscription error (429/402) and trigger global subscription modal
const status = err?.response?.status;
if (status === 429 || status === 402) {
console.log('SEOAnalysisModal: Detected subscription error, triggering global handler', {
status,
data: err?.response?.data
});
const handled = triggerSubscriptionError(err);
if (handled) {
console.log('SEOAnalysisModal: Global subscription error handler triggered successfully');
// Don't set local error - let the global modal handle it
setIsAnalyzing(false);
return;
} else {
console.warn('SEOAnalysisModal: Global subscription error handler did not handle the error');
}
}
// For non-subscription errors, show local error message
setError(err instanceof Error ? err.message : 'Analysis failed'); setError(err instanceof Error ? err.message : 'Analysis failed');
setIsAnalyzing(false); setIsAnalyzing(false);
} }

View File

@@ -36,7 +36,7 @@ import {
Tag as TagIcon, Tag as TagIcon,
Refresh as RefreshIcon Refresh as RefreshIcon
} from '@mui/icons-material'; } from '@mui/icons-material';
import { apiClient } from '../../api/client'; import { apiClient, triggerSubscriptionError } from '../../api/client';
// Import metadata display components // Import metadata display components
import { CoreMetadataTab } from './SEO/MetadataDisplay/CoreMetadataTab'; import { CoreMetadataTab } from './SEO/MetadataDisplay/CoreMetadataTab';
@@ -219,8 +219,28 @@ export const SEOMetadataModal: React.FC<SEOMetadataModalProps> = ({
setEditableMetadata(result); setEditableMetadata(result);
console.log('📊 Metadata result set:', result); console.log('📊 Metadata result set:', result);
} catch (err) { } catch (err: any) {
console.error('❌ SEO metadata generation failed:', err); console.error('❌ SEO metadata generation failed:', err);
// Check if this is a subscription error (429/402) and trigger global subscription modal
const status = err?.response?.status;
if (status === 429 || status === 402) {
console.log('SEOMetadataModal: Detected subscription error, triggering global handler', {
status,
data: err?.response?.data
});
const handled = triggerSubscriptionError(err);
if (handled) {
console.log('SEOMetadataModal: Global subscription error handler triggered successfully');
// Don't set local error - let the global modal handle it
setIsGenerating(false);
return;
} else {
console.warn('SEOMetadataModal: Global subscription error handler did not handle the error');
}
}
// For non-subscription errors, show local error message
setError(err instanceof Error ? err.message : 'Failed to generate SEO metadata'); setError(err instanceof Error ? err.message : 'Failed to generate SEO metadata');
} finally { } finally {
setIsGenerating(false); setIsGenerating(false);

View File

@@ -53,6 +53,21 @@ export const usePlatformConnections = () => {
const handleWixConnect = async () => { const handleWixConnect = async () => {
try { try {
// Store current page URL BEFORE redirecting (critical for proper redirect back)
// This ensures we can redirect back to the correct page (e.g., Blog Writer) after OAuth
const currentUrl = window.location.href;
try {
// Only store if not already set (allows WixConnectModal to override if needed)
if (!sessionStorage.getItem('wix_oauth_redirect')) {
sessionStorage.setItem('wix_oauth_redirect', currentUrl);
console.log('[Wix OAuth] Stored redirect URL:', currentUrl);
} else {
console.log('[Wix OAuth] Redirect URL already set, keeping existing:', sessionStorage.getItem('wix_oauth_redirect'));
}
} catch (e) {
console.warn('[Wix OAuth] Failed to store redirect URL:', e);
}
// Use the working Wix OAuth flow from WixTestPage // Use the working Wix OAuth flow from WixTestPage
const wixClient = createClient({ const wixClient = createClient({
auth: OAuthStrategy({ clientId: '75d88e36-1c76-4009-b769-15f4654556df' }) auth: OAuthStrategy({ clientId: '75d88e36-1c76-4009-b769-15f4654556df' })

View File

@@ -51,6 +51,7 @@ import {
} from '@mui/icons-material'; } from '@mui/icons-material';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { apiClient } from '../../api/client'; import { apiClient } from '../../api/client';
import { restoreNavigationState, saveCurrentPhaseForTool } from '../../utils/navigationState';
interface SubscriptionPlan { interface SubscriptionPlan {
id: number; id: number;
@@ -114,8 +115,15 @@ const PricingPage: React.FC = () => {
}; };
const handleSubscribe = async (planId: number) => { const handleSubscribe = async (planId: number) => {
console.log('[PricingPage] handleSubscribe called', { planId });
const plan = plans.find(p => p.id === planId); const plan = plans.find(p => p.id === planId);
if (!plan) return; if (!plan) {
console.error('[PricingPage] ❌ Plan not found for ID:', planId);
return;
}
console.log('[PricingPage] Selected plan:', { id: plan.id, name: plan.name, tier: plan.tier });
// Get user_id from localStorage (set by Clerk auth) // Get user_id from localStorage (set by Clerk auth)
const userId = localStorage.getItem('user_id'); const userId = localStorage.getItem('user_id');
@@ -123,18 +131,20 @@ const PricingPage: React.FC = () => {
// Check if user is signed in // Check if user is signed in
if (!userId || userId === 'anonymous' || userId === '') { if (!userId || userId === 'anonymous' || userId === '') {
// User not signed in, show sign-in prompt // User not signed in, show sign-in prompt
console.warn('PricingPage: User not signed in, showing prompt'); console.warn('[PricingPage] User not signed in, showing prompt');
setShowSignInPrompt(true); setShowSignInPrompt(true);
return; return;
} }
// For alpha testing, only allow Free and Basic plans (Pro features not ready) // For alpha testing, only allow Free and Basic plans (Pro features not ready)
if (plan.tier !== 'free' && plan.tier !== 'basic') { if (plan.tier !== 'free' && plan.tier !== 'basic') {
console.error('[PricingPage] Plan tier not available:', plan.tier);
setError('This plan is not available for alpha testing'); setError('This plan is not available for alpha testing');
return; return;
} }
if (plan.tier === 'free') { if (plan.tier === 'free') {
console.log('[PricingPage] Processing Free plan subscription directly');
// For free plan, just create subscription // For free plan, just create subscription
try { try {
setSubscribing(true); setSubscribing(true);
@@ -164,23 +174,38 @@ const PricingPage: React.FC = () => {
} }
} else { } else {
// For Basic plan, show payment modal // For Basic plan, show payment modal
console.log('[PricingPage] Opening payment modal for Basic plan', { planId, planName: plan.name });
setSelectedPlan(planId); // ✅ Set selected plan before opening modal
setPaymentModalOpen(true); setPaymentModalOpen(true);
} }
}; };
const handlePaymentConfirm = async () => { const handlePaymentConfirm = async () => {
if (!selectedPlan) return; console.log('[PricingPage] handlePaymentConfirm called', { selectedPlan, yearlyBilling });
if (!selectedPlan) {
console.error('[PricingPage] ❌ No selectedPlan set - cannot proceed with subscription');
setError('No plan selected. Please select a plan and try again.');
return;
}
try { try {
setSubscribing(true); setSubscribing(true);
const userId = localStorage.getItem('user_id') || 'anonymous'; const userId = localStorage.getItem('user_id') || 'anonymous';
console.log('[PricingPage] Making subscription API call:', {
url: `/api/subscription/subscribe/${userId}`,
plan_id: selectedPlan,
billing_cycle: yearlyBilling ? 'yearly' : 'monthly',
userId
});
const response = await apiClient.post(`/api/subscription/subscribe/${userId}`, { const response = await apiClient.post(`/api/subscription/subscribe/${userId}`, {
plan_id: selectedPlan, plan_id: selectedPlan,
billing_cycle: yearlyBilling ? 'yearly' : 'monthly' billing_cycle: yearlyBilling ? 'yearly' : 'monthly'
}); });
console.log('Subscription renewed successfully:', response.data); console.log('[PricingPage] ✅ Subscription renewed successfully:', response.data);
// Refresh subscription status immediately // Refresh subscription status immediately
window.dispatchEvent(new CustomEvent('subscription-updated')); window.dispatchEvent(new CustomEvent('subscription-updated'));
@@ -223,13 +248,26 @@ const PricingPage: React.FC = () => {
// If not complete, redirect to onboarding; otherwise to dashboard // If not complete, redirect to onboarding; otherwise to dashboard
const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true'; const onboardingComplete = localStorage.getItem('onboarding_complete') === 'true';
if (onboardingComplete) { if (onboardingComplete) {
// Try to go back to where the user was (e.g., blog writer) // Restore navigation state (path, phase, tool) if available
// If no history, go to dashboard const navState = restoreNavigationState();
const referrer = sessionStorage.getItem('subscription_referrer');
if (referrer && referrer !== '/pricing') { if (navState && navState.path && navState.path !== '/pricing') {
navigate(referrer); // Restore phase if applicable (e.g., Blog Writer)
if (navState.tool === 'blog-writer' && navState.phase) {
saveCurrentPhaseForTool('blog-writer', navState.phase);
console.log('[PricingPage] Restored Blog Writer phase:', navState.phase);
}
console.log('[PricingPage] Redirecting to saved navigation state:', navState);
navigate(navState.path);
} else { } else {
navigate('/dashboard'); // Fallback: try legacy referrer
const referrer = sessionStorage.getItem('subscription_referrer');
if (referrer && referrer !== '/pricing') {
navigate(referrer);
} else {
navigate('/dashboard');
}
} }
} else { } else {
navigate('/onboarding'); navigate('/onboarding');

View File

@@ -0,0 +1,216 @@
import React, { useEffect } from 'react';
import { useResearchWizard } from './hooks/useResearchWizard';
import { useResearchExecution } from './hooks/useResearchExecution';
import { StepKeyword } from './steps/StepKeyword';
import { StepOptions } from './steps/StepOptions';
import { StepProgress } from './steps/StepProgress';
import { StepResults } from './steps/StepResults';
import { ResearchWizardProps } from './types/research.types';
/** Ordered wizard steps and the label rendered under each step indicator. */
const WIZARD_STEPS: ReadonlyArray<{ id: number; label: string }> = [
  { id: 1, label: 'Setup' },
  { id: 2, label: 'Options' },
  { id: 3, label: 'Research' },
  { id: 4, label: 'Results' },
];

/**
 * Four-step research wizard (Setup → Options → Research → Results).
 *
 * Renders the header, progress bar, step indicators, the active step
 * component and, for the first two steps, a Back/Next footer.
 *
 * @param onComplete      Called with the research results once they are stored
 *                        in wizard state. Invoked at most once per distinct
 *                        results object, even if the callback identity changes
 *                        between renders.
 * @param onCancel        When provided, a "Cancel" button is shown in the header.
 * @param initialKeywords Forwarded to `useResearchWizard` as initial keywords.
 * @param initialIndustry Forwarded to `useResearchWizard` as initial industry.
 */
export const ResearchWizard: React.FC<ResearchWizardProps> = ({
  onComplete,
  onCancel,
  initialKeywords,
  initialIndustry,
}) => {
  const wizard = useResearchWizard(initialKeywords, initialIndustry);
  const execution = useResearchExecution();

  // Remembers the results object already reported through onComplete so the
  // callback is not re-fired when only the callback identity changes.
  const lastNotifiedResults = React.useRef<unknown>(null);

  // Copy finished execution results into wizard state and leave the
  // progress step (3) automatically.
  // NOTE(review): nothing in this component calls execution.executeResearch;
  // presumably a step component starts the research — confirm that it shares
  // this hook instance, otherwise this effect never observes a result.
  useEffect(() => {
    if (execution.result && !execution.isExecuting) {
      wizard.updateState({ results: execution.result });
      if (wizard.state.currentStep === 3) {
        wizard.nextStep();
      }
    }
    // Intentionally keyed on the execution outcome only; re-running on every
    // wizard state change would re-apply the same results.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [execution.result, execution.isExecuting]);

  // Notify the parent once per distinct results object.
  useEffect(() => {
    const results = wizard.state.results;
    if (!results) {
      // Results were cleared: allow the next results object to be reported.
      lastNotifiedResults.current = null;
      return;
    }
    if (!onComplete || lastNotifiedResults.current === results) {
      return;
    }
    lastNotifiedResults.current = results;
    onComplete(results);
  }, [wizard.state.results, onComplete]);

  const currentStep = wizard.state.currentStep;
  const canAdvance = wizard.canGoNext();

  const renderStep = () => {
    const stepProps = {
      state: wizard.state,
      onUpdate: wizard.updateState,
      onNext: wizard.nextStep,
      onBack: wizard.prevStep,
    };

    switch (currentStep) {
      case 1:
        return <StepKeyword {...stepProps} />;
      case 2:
        return <StepOptions {...stepProps} />;
      case 3:
        return <StepProgress {...stepProps} />;
      case 4:
        return <StepResults {...stepProps} />;
      default:
        // Unknown step numbers fall back to the first step.
        return <StepKeyword {...stepProps} />;
    }
  };

  return (
    <div style={{
      minHeight: '100vh',
      backgroundColor: '#f5f5f5',
      padding: '20px',
    }}>
      {/* Wizard Container */}
      <div style={{
        maxWidth: '1200px',
        margin: '0 auto',
        backgroundColor: 'white',
        borderRadius: '12px',
        boxShadow: '0 2px 8px rgba(0,0,0,0.1)',
        overflow: 'hidden',
      }}>
        {/* Header */}
        <div style={{
          backgroundColor: '#1976d2',
          color: 'white',
          padding: '24px',
          borderBottom: '1px solid #e0e0e0',
        }}>
          <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
            <div>
              <h1 style={{ margin: 0, fontSize: '24px' }}>Research Wizard</h1>
              <p style={{ margin: '8px 0 0 0', fontSize: '14px', opacity: 0.9 }}>
                Step {currentStep} of {wizard.maxSteps}
              </p>
            </div>
            {onCancel && (
              <button
                onClick={onCancel}
                style={{
                  padding: '8px 16px',
                  backgroundColor: 'rgba(255,255,255,0.2)',
                  color: 'white',
                  border: '1px solid rgba(255,255,255,0.3)',
                  borderRadius: '6px',
                  cursor: 'pointer',
                  fontSize: '14px',
                }}
              >
                Cancel
              </button>
            )}
          </div>
        </div>

        {/* Progress Bar */}
        <div style={{
          backgroundColor: '#f0f0f0',
          height: '4px',
          position: 'relative',
        }}>
          <div
            style={{
              backgroundColor: '#1976d2',
              height: '100%',
              width: `${(currentStep / wizard.maxSteps) * 100}%`,
              transition: 'width 0.3s ease',
            }}
          />
        </div>

        {/* Step Indicators */}
        <div style={{
          display: 'flex',
          justifyContent: 'space-around',
          padding: '20px 40px',
          borderBottom: '1px solid #e0e0e0',
        }}>
          {WIZARD_STEPS.map(({ id, label }) => {
            const reached = id <= currentStep;
            return (
              <div key={id} style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', position: 'relative' }}>
                <div style={{
                  width: '40px',
                  height: '40px',
                  borderRadius: '50%',
                  backgroundColor: reached ? '#1976d2' : '#e0e0e0',
                  color: reached ? 'white' : '#999',
                  display: 'flex',
                  alignItems: 'center',
                  justifyContent: 'center',
                  fontWeight: 'bold',
                  fontSize: '16px',
                  marginBottom: '8px',
                  transition: 'all 0.3s ease',
                }}>
                  {/* Steps already passed show a check mark instead of their number */}
                  {id < currentStep ? '✓' : id}
                </div>
                <span style={{
                  fontSize: '12px',
                  color: reached ? '#1976d2' : '#999',
                  fontWeight: id === currentStep ? '600' : 'normal',
                }}>
                  {label}
                </span>
              </div>
            );
          })}
        </div>

        {/* Content */}
        <div style={{ padding: '24px' }}>
          {renderStep()}
        </div>

        {/* Navigation Footer - only for the Setup and Options steps */}
        {currentStep <= 2 && (
          <div style={{
            padding: '20px 24px',
            borderTop: '1px solid #e0e0e0',
            display: 'flex',
            justifyContent: 'space-between',
            alignItems: 'center',
            backgroundColor: '#fafafa',
          }}>
            <button
              onClick={wizard.prevStep}
              disabled={wizard.isFirstStep}
              style={{
                padding: '10px 20px',
                backgroundColor: wizard.isFirstStep ? '#f0f0f0' : 'white',
                color: wizard.isFirstStep ? '#999' : '#333',
                border: wizard.isFirstStep ? '1px solid #e0e0e0' : '1px solid #ddd',
                borderRadius: '6px',
                cursor: wizard.isFirstStep ? 'not-allowed' : 'pointer',
                fontSize: '14px',
              }}
            >
              Back
            </button>
            <button
              onClick={wizard.nextStep}
              disabled={!canAdvance}
              style={{
                padding: '10px 24px',
                backgroundColor: canAdvance ? '#1976d2' : '#e0e0e0',
                color: canAdvance ? 'white' : '#999',
                border: 'none',
                borderRadius: '6px',
                cursor: canAdvance ? 'pointer' : 'not-allowed',
                fontSize: '14px',
                fontWeight: '600',
              }}
            >
              {wizard.isLastStep ? 'Finish' : 'Next →'}
            </button>
          </div>
        )}
      </div>
    </div>
  );
};
export default ResearchWizard;

View File

@@ -0,0 +1,82 @@
import { useCallback, useRef, useState } from 'react';
import { blogWriterApi, BlogResearchRequest, BlogResearchResponse } from '../../../services/blogWriterApi';
import { useResearchPolling } from '../../../hooks/usePolling';
import { researchCache } from '../../../services/researchCache';
import { WizardState } from '../types/research.types';
/**
 * Hook that runs a research request for the wizard: it consults the research
 * cache, otherwise starts a backend task and polls it to completion.
 *
 * Returned fields:
 * - `executeResearch(state)` resolves to `'cached'` on a cache hit, the
 *   backend task id when a task was started, or `null` when starting failed.
 * - `stopExecution()` stops polling and clears the executing/error flags.
 * - `isExecuting`, `error`, `progressMessages`, `currentStatus` expose progress.
 * - `result` is the cache hit of the latest `executeResearch` call when there
 *   was one, otherwise whatever the polling hook reports.
 */
export const useResearchExecution = () => {
  const [isExecuting, setIsExecuting] = useState(false);
  const [error, setError] = useState<string | null>(null);

  // Payload of the most recent cache hit; previously a hit returned 'cached'
  // but the cached data itself was dropped and never reached the caller.
  const [cachedResult, setCachedResult] =
    useState<ReturnType<typeof researchCache.getCachedResult> | null>(null);

  // Cache-key parts of the request currently being polled, so the result is
  // stored under the same key that executeResearch uses for lookups.
  const activeRequestRef =
    useRef<Pick<WizardState, 'keywords' | 'industry' | 'targetAudience'> | null>(null);

  const polling = useResearchPolling({
    onComplete: (result) => {
      const request = activeRequestRef.current;
      // Fall back to the response's own keywords and the 'General' defaults
      // when no request context was recorded.
      const keywords = request?.keywords ?? result?.keywords;
      if (result && keywords) {
        researchCache.cacheResult(
          keywords,
          request?.industry ?? 'General',
          request?.targetAudience ?? 'General',
          result
        );
      }
      setIsExecuting(false);
    },
    onError: (error) => {
      console.error('Research polling error:', error);
      setError(error);
      setIsExecuting(false);
    }
  });

  const executeResearch = useCallback(async (state: WizardState): Promise<string | null> => {
    setIsExecuting(true);
    setError(null);
    // Drop any hit from a previous run so it cannot mask this run's result.
    setCachedResult(null);

    try {
      // Check cache first
      const cached = researchCache.getCachedResult(
        state.keywords,
        state.industry,
        state.targetAudience
      );

      if (cached) {
        setCachedResult(cached);
        setIsExecuting(false);
        return 'cached';
      }

      activeRequestRef.current = {
        keywords: state.keywords,
        industry: state.industry,
        targetAudience: state.targetAudience,
      };

      const payload: BlogResearchRequest = {
        keywords: state.keywords,
        industry: state.industry,
        target_audience: state.targetAudience,
        research_mode: state.researchMode,
        config: state.config,
      };

      const { task_id } = await blogWriterApi.startResearch(payload);
      polling.startPolling(task_id);
      return task_id;
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : 'Unknown error';
      setError(errorMessage);
      setIsExecuting(false);
      return null;
    }
  }, [polling]);

  const stopExecution = useCallback(() => {
    polling.stopPolling();
    setIsExecuting(false);
    setError(null);
  }, [polling]);

  return {
    executeResearch,
    stopExecution,
    isExecuting,
    error,
    progressMessages: polling.progressMessages,
    currentStatus: polling.currentStatus,
    result: cachedResult ?? polling.result,
  };
};

View File

@@ -0,0 +1,116 @@
import { useState, useCallback, useEffect } from 'react';
import { WizardState, WizardStepProps } from '../types/research.types';
import { ResearchMode, ResearchConfig, BlogResearchResponse } from '../../../services/blogWriterApi';
// localStorage key under which the wizard state is read, written and removed.
const WIZARD_STATE_KEY = 'alwrity_research_wizard_state';
// Highest step number; nextStep() does not advance past it.
const MAX_STEPS = 4;

/**
 * Baseline wizard state. Used as the starting point for a fresh wizard and
 * for reset(), where keywords and industry may be overridden by the hook's
 * initial values.
 */
const defaultState: WizardState = {
currentStep: 1,
keywords: [],
industry: 'General',
targetAudience: 'General',
researchMode: 'basic' as ResearchMode,
config: {
mode: 'basic',
provider: 'google',
max_sources: 10,
include_statistics: true,
include_expert_quotes: true,
include_competitors: true,
include_trends: true,
},
// No research has been run yet.
results: null,
};
/**
 * State container for the research wizard.
 *
 * On first render the state is restored from localStorage when a usable
 * snapshot exists; otherwise it starts from `defaultState` overlaid with the
 * optional initial values. Once the user is past step 1, every state change
 * is written back to localStorage.
 *
 * @param initialKeywords Keywords used when no saved state exists, and on reset.
 * @param initialIndustry Industry used when no saved state exists, and on reset.
 * @returns The current state plus navigation/update helpers and step flags.
 */
export const useResearchWizard = (initialKeywords?: string[], initialIndustry?: string) => {
  const [state, setState] = useState<WizardState>(() => {
    // Storage access itself can throw (blocked storage, non-browser render),
    // so the read and the parse share one guard.
    try {
      const saved = localStorage.getItem(WIZARD_STATE_KEY);
      if (saved) {
        const parsed: unknown = JSON.parse(saved);
        // JSON.parse happily returns null, numbers, strings or arrays; only a
        // plain object can be a wizard snapshot.
        if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
          const restored = parsed as Partial<WizardState>;
          const step = Number(restored.currentStep);
          const savedConfig =
            restored.config !== null && typeof restored.config === 'object' ? restored.config : {};
          // Overlay the snapshot on the defaults so fields missing from an
          // older or truncated snapshot never surface as undefined.
          return {
            ...defaultState,
            ...restored,
            currentStep: Number.isInteger(step) && step >= 1 && step <= MAX_STEPS ? step : 1,
            keywords: Array.isArray(restored.keywords)
              ? restored.keywords.filter((k): k is string => typeof k === 'string')
              : initialKeywords || [],
            config: { ...defaultState.config, ...savedConfig },
          };
        }
      }
    } catch (e) {
      console.warn('Failed to parse saved wizard state', e);
    }

    // Use defaults or initial values
    return {
      ...defaultState,
      keywords: initialKeywords || [],
      industry: initialIndustry || defaultState.industry,
    };
  });

  // Persist state to localStorage once the user has moved past the first step.
  useEffect(() => {
    if (state.currentStep > 1) {
      try {
        localStorage.setItem(WIZARD_STATE_KEY, JSON.stringify(state));
      } catch (e) {
        // Quota errors (results can be large) must not break the wizard.
        console.warn('Failed to persist wizard state', e);
      }
    }
  }, [state]);

  const updateState = useCallback((updates: Partial<WizardState>) => {
    setState(prev => ({ ...prev, ...updates }));
  }, []);

  const nextStep = useCallback(() => {
    setState(prev => {
      if (prev.currentStep >= MAX_STEPS) return prev;
      return { ...prev, currentStep: prev.currentStep + 1 };
    });
  }, []);

  const prevStep = useCallback(() => {
    setState(prev => {
      if (prev.currentStep <= 1) return prev;
      return { ...prev, currentStep: prev.currentStep - 1 };
    });
  }, []);

  const reset = useCallback(() => {
    const resetState = {
      ...defaultState,
      keywords: initialKeywords || [],
      industry: initialIndustry || defaultState.industry,
    };
    setState(resetState);
    try {
      localStorage.removeItem(WIZARD_STATE_KEY);
    } catch (e) {
      console.warn('Failed to clear saved wizard state', e);
    }
  }, [initialKeywords, initialIndustry]);

  const clearResults = useCallback(() => {
    setState(prev => ({ ...prev, results: null }));
  }, []);

  const canGoNext = useCallback((): boolean => {
    switch (state.currentStep) {
      case 1:
        return state.keywords.length > 0 && state.keywords.every(k => k.trim().length > 0);
      case 2:
        return true; // Mode selection always allowed
      case 3:
        return false; // Progress can't be skipped
      case 4:
        return false; // Results can't be skipped
      default:
        return false;
    }
  }, [state.currentStep, state.keywords]);

  return {
    state,
    updateState,
    nextStep,
    prevStep,
    reset,
    clearResults,
    canGoNext,
    isFirstStep: state.currentStep === 1,
    isLastStep: state.currentStep === MAX_STEPS,
    maxSteps: MAX_STEPS,
  };
};
export type UseResearchWizardReturn = ReturnType<typeof useResearchWizard>;

View File

@@ -0,0 +1,5 @@
export { ResearchWizard } from './ResearchWizard';
export { useResearchWizard } from './hooks/useResearchWizard';
export { useResearchExecution } from './hooks/useResearchExecution';
export * from './types/research.types';

View File

@@ -0,0 +1,72 @@
/**
* Blog Writer Integration Adapter for Research Component
*
* This adapter provides a simple way to integrate the ResearchWizard
* into the BlogWriter's research phase.
*/
import React from 'react';
import { ResearchWizard } from '../ResearchWizard';
import { BlogResearchResponse } from '../../../services/blogWriterApi';
/** Props accepted by BlogWriterResearchAdapter; each one is forwarded to ResearchWizard. */
interface BlogWriterResearchAdapterProps {
/** Passed to ResearchWizard as its `onComplete` callback. */
onResearchComplete: (research: BlogResearchResponse) => void;
/** Optional cancel callback, passed through unchanged. */
onCancel?: () => void;
/** Optional starting keywords, passed through unchanged. */
initialKeywords?: string[];
/** Optional starting industry, passed through unchanged. */
initialIndustry?: string;
}
/**
 * Hosts the ResearchWizard inside a full-height, white, column-flex container
 * so BlogWriter can mount it as its research phase. `onResearchComplete` is
 * handed to the wizard as `onComplete`; the remaining props pass straight through.
 */
export const BlogWriterResearchAdapter: React.FC<BlogWriterResearchAdapterProps> = (props) => {
  const { onResearchComplete, onCancel, initialKeywords, initialIndustry } = props;

  const containerStyle: React.CSSProperties = {
    height: '100%',
    display: 'flex',
    flexDirection: 'column',
    backgroundColor: 'white',
  };

  return (
    <div style={containerStyle}>
      <ResearchWizard
        onComplete={onResearchComplete}
        onCancel={onCancel}
        initialKeywords={initialKeywords}
        initialIndustry={initialIndustry}
      />
    </div>
  );
};
export default BlogWriterResearchAdapter;
/**
* USAGE EXAMPLE:
*
* In BlogWriter.tsx, replace the research phase content with:
*
* {currentPhase === 'research' && !research && (
* <BlogWriterResearchAdapter
* onResearchComplete={(res) => {
* handleResearchComplete(res);
* // Optionally auto-advance to outline phase
* navigateToPhase('outline');
* }}
* onCancel={() => {
* // Return to the research phase (swap in dashboard navigation if that is the desired cancel target)
* navigateToPhase('research');
* }}
* initialKeywords={[]}
* initialIndustry="General"
* />
* )}
*
* Note: This maintains backward compatibility. The existing CopilotKit/manual
* research flow continues to work. This provides an alternative UI option.
*/

View File

@@ -0,0 +1,133 @@
import React, { useEffect } from 'react';
import { WizardStepProps } from '../types/research.types';
// Options rendered in the Industry <select>; each string is used as both the option key and its value.
const industries = [
'General',
'Technology',
'Business',
'Marketing',
'Finance',
'Healthcare',
'Education',
'Real Estate',
'Entertainment',
'Food & Beverage',
'Travel',
'Fashion',
'Sports',
'Science',
'Law',
'Other',
];
/** Splits raw comma-separated text into trimmed, non-empty keywords. */
const splitKeywordText = (text: string): string[] =>
  text.split(',').map(k => k.trim()).filter(Boolean);

/**
 * Wizard step that collects keywords, industry and target audience.
 *
 * The textarea keeps its own raw text in local state. Deriving the value from
 * the parsed keyword list on every keystroke would strip a just-typed trailing
 * comma or space, making it impossible to type a second keyword or a
 * multi-word keyword. The parsed list is still pushed to the wizard state on
 * every change.
 */
export const StepKeyword: React.FC<WizardStepProps> = ({ state, onUpdate }) => {
  const [keywordText, setKeywordText] = React.useState<string>(() => state.keywords.join(', '));

  // Re-sync the raw text only when the keywords were changed from outside this
  // input (e.g. a wizard reset); while the user is typing, the parsed text
  // already equals state.keywords and the raw text is left untouched.
  React.useEffect(() => {
    setKeywordText(current => {
      const typed = splitKeywordText(current);
      const inSync =
        typed.length === state.keywords.length &&
        typed.every((keyword, index) => keyword === state.keywords[index]);
      return inSync ? current : state.keywords.join(', ');
    });
  }, [state.keywords]);

  const handleKeywordsChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
    const value = e.target.value;
    setKeywordText(value);
    onUpdate({ keywords: splitKeywordText(value) });
  };

  const handleIndustryChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
    onUpdate({ industry: e.target.value });
  };

  const handleAudienceChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    onUpdate({ targetAudience: e.target.value });
  };

  return (
    <div style={{ padding: '24px', maxWidth: '800px', margin: '0 auto' }}>
      <h2 style={{ marginBottom: '8px', color: '#333' }}>🔍 Research Setup</h2>
      <p style={{ marginBottom: '24px', color: '#666', fontSize: '15px' }}>
        Enter your keywords, industry, and target audience to start research.
      </p>

      {/* Keywords Input */}
      <div style={{ marginBottom: '20px' }}>
        <label style={{ display: 'block', marginBottom: '8px', fontWeight: '600', color: '#555' }}>
          Keywords *
        </label>
        <textarea
          value={keywordText}
          onChange={handleKeywordsChange}
          placeholder="e.g., AI in marketing, automation tools, customer engagement"
          rows={4}
          style={{
            width: '100%',
            padding: '12px',
            border: '1px solid #ddd',
            borderRadius: '8px',
            fontSize: '14px',
            fontFamily: 'inherit',
            resize: 'vertical',
            boxSizing: 'border-box',
          }}
        />
        <p style={{ marginTop: '4px', fontSize: '12px', color: '#888' }}>
          Separate multiple keywords with commas
        </p>
      </div>

      {/* Industry Selection */}
      <div style={{ marginBottom: '20px' }}>
        <label style={{ display: 'block', marginBottom: '8px', fontWeight: '600', color: '#555' }}>
          Industry
        </label>
        <select
          value={state.industry}
          onChange={handleIndustryChange}
          style={{
            width: '100%',
            padding: '12px',
            border: '1px solid #ddd',
            borderRadius: '8px',
            fontSize: '14px',
            fontFamily: 'inherit',
            backgroundColor: 'white',
            cursor: 'pointer',
          }}
        >
          {industries.map(ind => (
            <option key={ind} value={ind}>{ind}</option>
          ))}
        </select>
      </div>

      {/* Target Audience */}
      <div style={{ marginBottom: '20px' }}>
        <label style={{ display: 'block', marginBottom: '8px', fontWeight: '600', color: '#555' }}>
          Target Audience
        </label>
        <input
          type="text"
          value={state.targetAudience}
          onChange={handleAudienceChange}
          placeholder="e.g., Digital marketers, Small business owners"
          style={{
            width: '100%',
            padding: '12px',
            border: '1px solid #ddd',
            borderRadius: '8px',
            fontSize: '14px',
            fontFamily: 'inherit',
            boxSizing: 'border-box',
          }}
        />
      </div>

      <div style={{
        padding: '12px',
        backgroundColor: '#f0f7ff',
        borderRadius: '8px',
        border: '1px solid #b3d9ff',
        fontSize: '13px',
        color: '#004085',
      }}>
        💡 <strong>Tip:</strong> Be specific with your keywords. The more precise your keywords, the better your research results.
      </div>
    </div>
  );
};

View File

@@ -0,0 +1,182 @@
import React from 'react';
import { WizardStepProps, ModeCardInfo } from '../types/research.types';
import { ResearchProvider } from '../../../services/blogWriterApi';
// One entry per selectable research mode; StepOptions renders each entry as a clickable card
// and uses `mode` both as the React key and as the value written to the wizard state.
const modeCards: ModeCardInfo[] = [
{
mode: 'basic',
title: 'Basic Research',
description: 'Quick keyword-focused analysis for fast results',
features: [
'Primary & secondary keywords',
'Current trends overview',
'Top 5 content angles',
'Key statistics',
],
icon: '⚡',
},
{
mode: 'comprehensive',
title: 'Comprehensive Research',
description: 'Deep analysis with full competitive intelligence',
features: [
'All basic features',
'Expert quotes & opinions',
'Competitor analysis',
'Market forecasts',
'Best practices & case studies',
'Content gaps identification',
],
icon: '📊',
},
{
mode: 'targeted',
title: 'Targeted Research',
description: 'Customize what you need most',
features: [
'Select specific components',
'Choose date ranges',
'Filter source types',
'Control depth & scope',
],
icon: '🎯',
},
];
/**
 * Wizard step for choosing the research mode and, for non-basic modes, the
 * research provider.
 *
 * Selecting a mode updates both `researchMode` and `config.mode` so the two
 * copies of the mode held in the wizard state cannot drift apart. Cards are
 * exposed as buttons (focusable, Enter/Space activates) so the step can be
 * operated from the keyboard.
 */
export const StepOptions: React.FC<WizardStepProps> = ({ state, onUpdate }) => {
  const handleModeChange = (mode: ModeCardInfo['mode']) => {
    onUpdate({ researchMode: mode, config: { ...state.config, mode } });
  };

  const handleProviderChange = (provider: ResearchProvider) => {
    onUpdate({ config: { ...state.config, provider } });
  };

  // Builds a keydown handler that triggers `action` on Enter or Space,
  // mirroring native button behaviour for the clickable card divs.
  const activateOnKey = (action: () => void) => (e: React.KeyboardEvent<HTMLDivElement>) => {
    if (e.key === 'Enter' || e.key === ' ') {
      e.preventDefault();
      action();
    }
  };

  // Google is treated as selected when no provider has been stored yet.
  const googleSelected = state.config.provider === 'google' || !state.config.provider;
  const exaSelected = state.config.provider === 'exa';

  return (
    <div style={{ padding: '24px', maxWidth: '1000px', margin: '0 auto' }}>
      <h2 style={{ marginBottom: '8px', color: '#333' }}>Choose Research Mode</h2>
      <p style={{ marginBottom: '24px', color: '#666', fontSize: '15px' }}>
        Select the type of research that best fits your needs.
      </p>

      <div style={{
        display: 'grid',
        gridTemplateColumns: 'repeat(auto-fit, minmax(280px, 1fr))',
        gap: '20px',
        marginBottom: '24px',
      }}>
        {modeCards.map(card => {
          const selected = state.researchMode === card.mode;
          return (
            <div
              key={card.mode}
              role="button"
              tabIndex={0}
              aria-pressed={selected}
              onClick={() => handleModeChange(card.mode)}
              onKeyDown={activateOnKey(() => handleModeChange(card.mode))}
              style={{
                border: selected ? '2px solid #1976d2' : '2px solid #e0e0e0',
                borderRadius: '12px',
                padding: '24px',
                cursor: 'pointer',
                transition: 'all 0.2s ease',
                backgroundColor: selected ? '#f0f7ff' : 'white',
              }}
              onMouseEnter={(e) => {
                // Hover highlight applies to unselected cards only.
                if (state.researchMode !== card.mode) {
                  e.currentTarget.style.borderColor = '#90caf9';
                  e.currentTarget.style.backgroundColor = '#fafafa';
                }
              }}
              onMouseLeave={(e) => {
                if (state.researchMode !== card.mode) {
                  e.currentTarget.style.borderColor = '#e0e0e0';
                  e.currentTarget.style.backgroundColor = 'white';
                }
              }}
            >
              <div style={{ display: 'flex', alignItems: 'center', marginBottom: '12px' }}>
                <span style={{ fontSize: '32px', marginRight: '12px' }}>{card.icon}</span>
                <h3 style={{ margin: 0, color: '#333', fontSize: '18px' }}>{card.title}</h3>
              </div>
              <p style={{ marginBottom: '16px', color: '#666', fontSize: '14px' }}>
                {card.description}
              </p>
              <ul style={{ margin: 0, paddingLeft: '20px', fontSize: '13px', color: '#555' }}>
                {card.features.map((feature, idx) => (
                  <li key={idx} style={{ marginBottom: '4px' }}>{feature}</li>
                ))}
              </ul>
              {selected && (
                <div style={{
                  marginTop: '16px',
                  padding: '8px',
                  backgroundColor: '#1976d2',
                  color: 'white',
                  borderRadius: '6px',
                  textAlign: 'center',
                  fontSize: '13px',
                  fontWeight: '600',
                }}>
                  Selected
                </div>
              )}
            </div>
          );
        })}
      </div>

      {/* Provider choice is only offered for non-basic modes. */}
      {state.researchMode !== 'basic' && (
        <div style={{ marginBottom: '24px', border: '1px solid #e0e0e0', borderRadius: '8px', padding: '15px', backgroundColor: '#f9f9f9' }}>
          <h3 style={{ marginTop: 0, marginBottom: '12px', color: '#555', fontSize: '16px' }}>
            🔍 Research Provider
          </h3>
          <div style={{ display: 'flex', gap: '12px' }}>
            <div
              role="button"
              tabIndex={0}
              aria-pressed={googleSelected}
              onClick={() => handleProviderChange('google')}
              onKeyDown={activateOnKey(() => handleProviderChange('google'))}
              style={{
                flex: 1,
                padding: '12px',
                border: '2px solid',
                borderColor: googleSelected ? '#1976d2' : '#ddd',
                backgroundColor: googleSelected ? '#e3f2fd' : 'white',
                borderRadius: '6px',
                cursor: 'pointer',
                transition: 'all 0.2s',
              }}
            >
              <div style={{ fontWeight: '600', marginBottom: '4px' }}>Google Search</div>
              <div style={{ fontSize: '11px', color: '#666' }}>
                Fast, broad coverage, trending topics
              </div>
            </div>
            <div
              role="button"
              tabIndex={0}
              aria-pressed={exaSelected}
              onClick={() => handleProviderChange('exa')}
              onKeyDown={activateOnKey(() => handleProviderChange('exa'))}
              style={{
                flex: 1,
                padding: '12px',
                border: '2px solid',
                borderColor: exaSelected ? '#7c3aed' : '#ddd',
                backgroundColor: exaSelected ? '#f3e8ff' : 'white',
                borderRadius: '6px',
                cursor: 'pointer',
                transition: 'all 0.2s',
              }}
            >
              <div style={{ fontWeight: '600', marginBottom: '4px' }}>Exa Neural</div>
              <div style={{ fontSize: '11px', color: '#666' }}>
                Deep research, rich citations, semantic search
              </div>
            </div>
          </div>
        </div>
      )}

      <div style={{
        padding: '12px',
        backgroundColor: '#fff3e0',
        borderRadius: '8px',
        border: '1px solid #ffb74d',
        fontSize: '13px',
        color: '#e65100',
      }}>
        <strong>Note:</strong> You can always run additional research if you need more information later.
      </div>
    </div>
  );
};

View File

@@ -0,0 +1,153 @@
import React, { useEffect } from 'react';
import { WizardStepProps } from '../types/research.types';
import { useResearchExecution } from '../hooks/useResearchExecution';
/**
 * Wizard step that starts the research run on mount, shows its progress
 * messages, and advances to the next step one second after a successful run.
 *
 * Behaviour notes:
 * - The unmount cleanup reads the latest `isExecuting` / `stopExecution`
 *   through a ref. The effect runs once, so values captured at mount time
 *   would always report "not executing" and polling would never be stopped.
 * - Auto-advance is suppressed after an error or a user cancel, so neither
 *   case silently moves on to the results step.
 */
export const StepProgress: React.FC<WizardStepProps> = ({ state, onNext }) => {
  const { executeResearch, stopExecution, isExecuting, error, progressMessages, currentStatus } = useResearchExecution();
  const [cancelled, setCancelled] = React.useState(false);

  // Always holds the values from the most recent render for the unmount cleanup.
  const latest = React.useRef({ isExecuting, stopExecution });
  latest.current = { isExecuting, stopExecution };

  useEffect(() => {
    // Start research when this step is reached
    const startResearch = async () => {
      const taskId = await executeResearch(state);
      if (taskId === 'cached') {
        // If cached, move to results immediately
        // The parent will handle this
        // NOTE(review): nothing in this component notifies the parent about a
        // cache hit or stores results via onUpdate - confirm the parent
        // wizard really covers this path, otherwise this step never advances.
      }
    };
    // executeResearch reports failures through `error`, so the promise is
    // intentionally not awaited here.
    void startResearch();

    return () => {
      const { isExecuting: stillRunning, stopExecution: stop } = latest.current;
      if (stillRunning) {
        stop();
      }
    };
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []); // Run once on mount

  const finishedSuccessfully =
    !isExecuting && progressMessages.length > 0 && !error && !cancelled;

  // Move to next step when research completes
  useEffect(() => {
    if (!finishedSuccessfully) {
      return undefined;
    }
    // Small delay to show final message
    const timer = setTimeout(() => {
      onNext();
    }, 1000);
    return () => clearTimeout(timer);
  }, [finishedSuccessfully, onNext]);

  const handleCancel = () => {
    setCancelled(true);
    stopExecution();
  };

  const getStatusIcon = () => {
    if (error) return '❌';
    if (cancelled) return '⏹️';
    if (!isExecuting && progressMessages.length > 0) return '✅';
    if (currentStatus === 'completed') return '✅';
    return '🔄';
  };

  const getStatusColor = () => {
    if (error) return '#f44336';
    if (cancelled) return '#757575';
    if (!isExecuting && progressMessages.length > 0) return '#4caf50';
    return '#1976d2';
  };

  const renderStatusBody = () => {
    if (error) {
      return (
        <>
          <h3 style={{ color: getStatusColor(), marginBottom: '8px' }}>Error</h3>
          <p style={{ color: '#666', fontSize: '14px' }}>{error}</p>
          <button
            onClick={() => window.location.reload()}
            style={{
              marginTop: '16px',
              padding: '8px 16px',
              backgroundColor: '#1976d2',
              color: 'white',
              border: 'none',
              borderRadius: '6px',
              cursor: 'pointer',
              fontSize: '14px',
            }}
          >
            Retry
          </button>
        </>
      );
    }
    if (cancelled) {
      return (
        <>
          <h3 style={{ color: getStatusColor(), marginBottom: '8px' }}>Cancelled</h3>
          <p style={{ color: '#666', fontSize: '14px' }}>Research was cancelled.</p>
        </>
      );
    }
    return (
      <>
        <h3 style={{ color: getStatusColor(), marginBottom: '8px' }}>
          {currentStatus === 'completed' ? 'Complete!' : 'In Progress'}
        </h3>
        <p style={{ color: '#666', fontSize: '14px' }}>
          {isExecuting ? 'Analyzing sources and generating insights...' : 'Finalizing results...'}
        </p>
      </>
    );
  };

  return (
    <div style={{ padding: '24px', maxWidth: '800px', margin: '0 auto' }}>
      <h2 style={{ marginBottom: '8px', color: '#333' }}>Researching...</h2>
      <p style={{ marginBottom: '24px', color: '#666', fontSize: '15px' }}>
        Gathering insights from Google Search grounding
      </p>

      {/* Status Display */}
      <div style={{
        backgroundColor: '#f5f5f5',
        borderRadius: '12px',
        padding: '24px',
        marginBottom: '24px',
        textAlign: 'center',
      }}>
        <div style={{ fontSize: '48px', marginBottom: '16px' }}>{getStatusIcon()}</div>
        {renderStatusBody()}
      </div>

      {/* Progress Messages */}
      {progressMessages.length > 0 && (
        <div style={{
          backgroundColor: 'white',
          border: '1px solid #e0e0e0',
          borderRadius: '8px',
          maxHeight: '300px',
          overflow: 'auto',
        }}>
          <div style={{ padding: '16px', borderBottom: '1px solid #e0e0e0' }}>
            <strong style={{ fontSize: '14px', color: '#333' }}>Progress Updates</strong>
          </div>
          {progressMessages.map((msg, idx) => (
            <div
              key={idx}
              style={{
                padding: '12px 16px',
                borderBottom: idx < progressMessages.length - 1 ? '1px solid #f0f0f0' : 'none',
                fontSize: '13px',
                color: '#555',
              }}
            >
              {idx === progressMessages.length - 1 && isExecuting && (
                <span style={{ marginRight: '8px' }}>🔄</span>
              )}
              {msg.message}
            </div>
          ))}
        </div>
      )}

      {/* Cancel Button */}
      {isExecuting && (
        <div style={{ marginTop: '24px', textAlign: 'center' }}>
          <button
            onClick={handleCancel}
            style={{
              padding: '8px 16px',
              backgroundColor: '#f44336',
              color: 'white',
              border: 'none',
              borderRadius: '6px',
              cursor: 'pointer',
              fontSize: '14px',
            }}
          >
            Cancel Research
          </button>
        </div>
      )}
    </div>
  );
};

View File

@@ -0,0 +1,103 @@
import React from 'react';
import { WizardStepProps } from '../types/research.types';
import { ResearchResults } from '../../BlogWriter/ResearchResults';
import { BlogResearchResponse } from '../../../services/blogWriterApi';
/**
 * Final wizard step: renders the stored research results with an
 * "Export JSON" download button, a button wired to `onBack`, and a short
 * "Next Steps" checklist. Shows a placeholder when no results are stored.
 */
export const StepResults: React.FC<WizardStepProps> = ({ state, onBack }) => {
  const { results, keywords } = state;

  if (!results) {
    return (
      <div style={{ padding: '24px', textAlign: 'center' }}>
        <p style={{ color: '#666' }}>No results available</p>
      </div>
    );
  }

  // Serialises the results to pretty-printed JSON and triggers a browser
  // download through a temporary object URL.
  const handleExport = () => {
    const serialized = JSON.stringify(results, null, 2);
    const objectUrl = URL.createObjectURL(new Blob([serialized], { type: 'application/json' }));
    const anchor = document.createElement('a');
    anchor.href = objectUrl;
    anchor.download = `research-${keywords.join('-')}-${Date.now()}.json`;
    anchor.click();
    URL.revokeObjectURL(objectUrl);
  };

  const headerRowStyle: React.CSSProperties = {
    display: 'flex',
    justifyContent: 'space-between',
    alignItems: 'center',
    marginBottom: '24px',
    flexWrap: 'wrap',
    gap: '16px',
  };

  const exportButtonStyle: React.CSSProperties = {
    padding: '8px 16px',
    backgroundColor: '#1976d2',
    color: 'white',
    border: 'none',
    borderRadius: '6px',
    cursor: 'pointer',
    fontSize: '14px',
    display: 'flex',
    alignItems: 'center',
    gap: '6px',
  };

  const backButtonStyle: React.CSSProperties = {
    padding: '8px 16px',
    backgroundColor: '#f5f5f5',
    color: '#333',
    border: '1px solid #ddd',
    borderRadius: '6px',
    cursor: 'pointer',
    fontSize: '14px',
  };

  const resultsPanelStyle: React.CSSProperties = {
    backgroundColor: 'white',
    borderRadius: '8px',
    border: '1px solid #e0e0e0',
    overflow: 'hidden',
  };

  const nextStepsStyle: React.CSSProperties = {
    marginTop: '24px',
    padding: '16px',
    backgroundColor: '#f0f7ff',
    borderRadius: '8px',
    border: '1px solid #b3d9ff',
  };

  return (
    <div style={{ maxWidth: '1200px', margin: '0 auto', padding: '24px' }}>
      <div style={headerRowStyle}>
        <h2 style={{ margin: 0, color: '#333' }}>Research Results</h2>
        <div style={{ display: 'flex', gap: '8px', flexWrap: 'wrap' }}>
          <button onClick={handleExport} style={exportButtonStyle}>
            📥 Export JSON
          </button>
          <button onClick={onBack} style={backButtonStyle}>
            Start New Research
          </button>
        </div>
      </div>

      {/* Results Display */}
      <div style={resultsPanelStyle}>
        <ResearchResults research={results} />
      </div>

      {/* Action Section */}
      <div style={nextStepsStyle}>
        <h4 style={{ marginBottom: '8px', color: '#004085' }}>Next Steps</h4>
        <ul style={{ margin: 0, paddingLeft: '20px', color: '#004085', fontSize: '14px' }}>
          <li>Review the research insights and sources</li>
          <li>Explore content angles and competitor analysis</li>
          <li>Use this research to create your blog outline</li>
          <li>Export the data for reference</li>
        </ul>
      </div>
    </div>
  );
};

View File

@@ -0,0 +1,34 @@
import { BlogResearchResponse, ResearchMode, ResearchConfig } from '../../../services/blogWriterApi';
/** Complete state of the research wizard; passed to every step via WizardStepProps.state. */
export interface WizardState {
/** Number of the step currently shown. */
currentStep: number;
/** Keywords to research. */
keywords: string[];
/** Industry the research is framed for. */
industry: string;
/** Audience the research is framed for. */
targetAudience: string;
/** Selected research mode. */
researchMode: ResearchMode;
/** Research configuration object. */
config: ResearchConfig;
/** Research response, or null when none is stored. */
results: BlogResearchResponse | null;
}
/** Props shared by the wizard step components. */
export interface WizardStepProps {
/** Current wizard state. */
state: WizardState;
/** Applies a partial update to the wizard state. */
onUpdate: (updates: Partial<WizardState>) => void;
/** Callback a step invokes to move forward. */
onNext: () => void;
/** Callback a step invokes to move back. */
onBack: () => void;
}
/** Props of the ResearchWizard component; every prop is optional. */
export interface ResearchWizardProps {
/** Receives the research response - presumably once research has finished; confirm in ResearchWizard. */
onComplete?: (results: BlogResearchResponse) => void;
/** Cancel callback. */
onCancel?: () => void;
/** Keywords to start the wizard with. */
initialKeywords?: string[];
/** Industry to start the wizard with. */
initialIndustry?: string;
}
/** Descriptive data for one research mode: the mode value plus its title, description, feature list and icon. */
export interface ModeCardInfo {
/** Research mode this entry describes. */
mode: ResearchMode;
/** Short heading. */
title: string;
/** One-line summary. */
description: string;
/** Feature bullet texts. */
features: string[];
/** Icon text (an emoji in the entries defined in this codebase's mode list - TODO confirm for other callers). */
icon: string;
}

View File

@@ -0,0 +1,17 @@
// Utility functions for research component
/** Joins keywords into one string, separated by a comma and a space. */
export const formatKeywords = (keywords: string[]): string => keywords.join(', ');
/**
 * Splits a comma-separated string into keywords, trimming whitespace around
 * each entry and dropping entries that are empty after trimming.
 */
export const parseKeywords = (keywordsString: string): string[] => {
  const parsed: string[] = [];
  for (const piece of keywordsString.split(',')) {
    const cleaned = piece.trim();
    if (cleaned) {
      parsed.push(cleaned);
    }
  }
  return parsed;
};
/** True when the list is non-empty and no keyword is blank after trimming. */
export const validateKeywords = (keywords: string[]): boolean => {
  if (keywords.length === 0) {
    return false;
  }
  return !keywords.some(keyword => keyword.trim().length === 0);
};

View File

@@ -41,12 +41,20 @@ const SubscriptionExpiredModal: React.FC<SubscriptionExpiredModalProps> = ({
}) => { }) => {
// Debug logging to verify modal state // Debug logging to verify modal state
React.useEffect(() => { React.useEffect(() => {
console.log('SubscriptionExpiredModal: State update', {
open,
errorData,
hasUsageInfo: !!errorData?.usage_info,
errorDataKeys: errorData ? Object.keys(errorData) : null
});
if (open) { if (open) {
console.log('SubscriptionExpiredModal: Modal opened', { console.log('SubscriptionExpiredModal: Modal should be visible now', {
open, open,
errorData, errorData,
hasUsageInfo: !!errorData?.usage_info hasUsageInfo: !!errorData?.usage_info
}); });
} else {
console.log('SubscriptionExpiredModal: Modal is closed');
} }
}, [open, errorData]); }, [open, errorData]);

View File

@@ -57,8 +57,18 @@ const WixCallbackPage: React.FC = () => {
return; return;
} }
} catch {} } catch {}
// Fallback redirect for same-tab flow and let onboarding hook mark Wix as connected // Fallback redirect for same-tab flow - check if we have a stored redirect URL
window.location.replace('/onboarding?step=5&wix_connected=true'); const redirectUrl = sessionStorage.getItem('wix_oauth_redirect');
console.log('[Wix Callback] Checking redirect URL:', redirectUrl);
if (redirectUrl) {
console.log('[Wix Callback] Redirecting to stored URL:', redirectUrl);
sessionStorage.removeItem('wix_oauth_redirect');
window.location.replace(redirectUrl);
} else {
// Default to onboarding if no redirect URL stored
console.warn('[Wix Callback] No redirect URL found, defaulting to onboarding');
window.location.replace('/onboarding?step=5&wix_connected=true');
}
} catch (e: any) { } catch (e: any) {
setError(e?.message || 'OAuth callback failed'); setError(e?.message || 'OAuth callback failed');
try { try {

View File

@@ -0,0 +1,74 @@
import React, { useState } from 'react';
import { Alert, AlertTitle, IconButton, Collapse } from '@mui/material';
import { Close as CloseIcon, Warning as WarningIcon } from '@mui/icons-material';
import { useCopilotKitHealth } from '../../hooks/useCopilotKitHealth';
/** Props for CopilotKitDegradedBanner; both are optional and have defaults. */
interface CopilotKitDegradedBannerProps {
/**
* Position of the banner
* @default 'top'
*/
position?: 'top' | 'bottom';
/**
* Whether the banner is dismissible
* @default true
*/
dismissible?: boolean;
}
/**
 * Fixed, full-width warning banner shown while CopilotKit is reported as
 * unavailable. Renders nothing while CopilotKit is available, while a health
 * check is in flight, or after the user dismissed the banner.
 */
export const CopilotKitDegradedBanner: React.FC<CopilotKitDegradedBannerProps> = (props) => {
  const { position = 'top', dismissible = true } = props;
  const health = useCopilotKitHealth();
  const [dismissed, setDismissed] = useState(false);

  // Don't show if CopilotKit is available, checking, or dismissed
  const hidden = health.isAvailable || health.isChecking || dismissed;
  if (hidden) {
    return null;
  }

  const closeAction = dismissible ? (
    <IconButton
      aria-label="close"
      color="inherit"
      size="small"
      onClick={() => setDismissed(true)}
    >
      <CloseIcon fontSize="inherit" />
    </IconButton>
  ) : null;

  const bannerSx = {
    position: 'fixed',
    [position]: 0,
    left: 0,
    right: 0,
    zIndex: 1300, // Above most content but below modals
    borderRadius: 0, // Full width banner
    boxShadow: 2,
  };

  return (
    <Collapse in={!dismissed}>
      <Alert severity="warning" icon={<WarningIcon />} action={closeAction} sx={bannerSx}>
        <AlertTitle>Chat Unavailable</AlertTitle>
        {health.errorMessage || 'CopilotKit service is currently unavailable. You can still use all features with manual controls.'}
      </Alert>
    </Collapse>
  );
};
export default CopilotKitDegradedBanner;

View File

@@ -0,0 +1,157 @@
import React, { createContext, useContext, useState, useCallback, ReactNode } from 'react';
/** Health snapshot tracked for the CopilotKit service. */
interface CopilotKitHealthState {
/** Whether CopilotKit is currently considered healthy. */
isHealthy: boolean;
/** True while a health check is in flight. */
isChecking: boolean;
/** Time the health flag was last set, or null if it has not been set since init/reset. */
lastChecked: Date | null;
/** Message describing the most recent failure, or null when healthy. */
errorMessage: string | null;
/** Number of failures recorded since the last healthy mark or reset. */
retryCount: number;
isAvailable: boolean; // Alias for isHealthy, for clearer semantics
}
/** Context value: the health snapshot plus the actions that change it. */
interface CopilotKitHealthContextType extends CopilotKitHealthState {
/** Runs a health check and updates the state with its outcome. */
checkHealth: () => Promise<void>;
/** Marks CopilotKit as unavailable, optionally recording a message. */
markUnhealthy: (errorMessage?: string) => void;
/** Marks CopilotKit as available and clears the recorded error. */
markHealthy: () => void;
/** Restores the initial health state. */
resetHealth: () => void;
}
// Context is undefined outside a provider so the hook below can detect misuse.
const CopilotKitHealthContext = createContext<CopilotKitHealthContextType | undefined>(undefined);

/**
 * Returns the CopilotKit health context value.
 *
 * @throws Error when called outside a CopilotKitHealthProvider.
 */
export const useCopilotKitHealthContext = () => {
  const health = useContext(CopilotKitHealthContext);
  if (health) {
    return health;
  }
  throw new Error('useCopilotKitHealthContext must be used within CopilotKitHealthProvider');
};
/** Props for CopilotKitHealthProvider. */
interface CopilotKitHealthProviderProps {
/** Subtree that gets access to the health context. */
children: ReactNode;
/** Health flag assumed before any check has run; the provider defaults it to true. */
initialHealthStatus?: boolean;
}
/** Upper bound, in milliseconds, for the status request made by checkHealth. */
const HEALTH_CHECK_TIMEOUT_MS = 3000;

/** Maps a value thrown by the health-check request to the message stored in state. */
const describeHealthCheckError = (error: unknown): string => {
  const details = (typeof error === 'object' && error !== null ? error : {}) as {
    name?: unknown;
    message?: unknown;
  };
  const message = typeof details.message === 'string' ? details.message : '';

  if (details.name === 'AbortError' || details.name === 'TimeoutError') {
    return 'CopilotKit health check timed out';
  }
  if (message.includes('CORS')) {
    return 'CopilotKit CORS error - service may be unavailable';
  }
  if (message.includes('certificate') || message.includes('SSL')) {
    return 'CopilotKit SSL certificate error';
  }
  if (message.includes('network') || message.includes('Failed to fetch')) {
    return 'CopilotKit network error - service may be down';
  }
  return message || 'Unknown error checking CopilotKit health';
};

/**
 * Provides CopilotKit health state and the actions that change it.
 *
 * Health changes in three ways: explicit `markHealthy` / `markUnhealthy`
 * calls, fatal `copilotkit-error` window events, and `checkHealth`, which
 * issues a GET against the CopilotKit cloud endpoint with a 3 second timeout.
 *
 * @param children Subtree that can read the context.
 * @param initialHealthStatus Health assumed before any check has run (default true).
 */
export const CopilotKitHealthProvider: React.FC<CopilotKitHealthProviderProps> = ({
  children,
  initialHealthStatus = true,
}) => {
  const [state, setState] = useState<CopilotKitHealthState>({
    isHealthy: initialHealthStatus,
    isChecking: false,
    lastChecked: null,
    errorMessage: null,
    retryCount: 0,
    isAvailable: initialHealthStatus,
  });

  const markHealthy = useCallback(() => {
    setState((prev) => ({
      ...prev,
      isHealthy: true,
      isAvailable: true,
      errorMessage: null,
      retryCount: 0,
      lastChecked: new Date(),
    }));
  }, []);

  const markUnhealthy = useCallback((errorMessage?: string) => {
    setState((prev) => ({
      ...prev,
      isHealthy: false,
      isAvailable: false,
      errorMessage: errorMessage || 'CopilotKit is unavailable',
      lastChecked: new Date(),
      retryCount: prev.retryCount + 1,
    }));
  }, []);

  // Listen for CopilotKit error events from App.tsx
  React.useEffect(() => {
    const handleCopilotKitError = (event: Event) => {
      const customEvent = event as CustomEvent;
      const { errorMessage, isFatal } = customEvent.detail || {};

      if (isFatal) {
        markUnhealthy(errorMessage || 'CopilotKit fatal error');
      } else {
        // For transient errors, just log but don't mark as unhealthy immediately
        // Let the health check determine if it's truly down
        console.warn('CopilotKit transient error:', errorMessage);
      }
    };

    window.addEventListener('copilotkit-error', handleCopilotKitError as EventListener);
    return () => {
      window.removeEventListener('copilotkit-error', handleCopilotKitError as EventListener);
    };
  }, [markUnhealthy]);

  const checkHealth = useCallback(async () => {
    setState((prev) => ({ ...prev, isChecking: true }));

    // AbortController + setTimeout instead of AbortSignal.timeout(): the
    // static helper is missing in older browsers, where the resulting
    // TypeError would be caught below and misreported as CopilotKit being down.
    const controller = new AbortController();
    const timeoutId = setTimeout(() => controller.abort(), HEALTH_CHECK_TIMEOUT_MS);

    try {
      // Try to check CopilotKit status endpoint
      // This is a lightweight check that doesn't require full CopilotKit initialization
      const response = await fetch('https://api.cloud.copilotkit.ai/ciu', {
        method: 'GET',
        headers: {
          'x-copilotcloud-public-api-key': process.env.REACT_APP_COPILOTKIT_PUBLIC_API_KEY || '',
        },
        // Use a short timeout to avoid blocking
        signal: controller.signal,
      });

      if (response.ok) {
        markHealthy();
      } else {
        markUnhealthy(`CopilotKit status check failed: ${response.status}`);
      }
    } catch (error: unknown) {
      markUnhealthy(describeHealthCheckError(error));
    } finally {
      clearTimeout(timeoutId);
      setState((prev) => ({ ...prev, isChecking: false }));
    }
  }, [markHealthy, markUnhealthy]);

  const resetHealth = useCallback(() => {
    setState({
      isHealthy: initialHealthStatus,
      isChecking: false,
      lastChecked: null,
      errorMessage: null,
      retryCount: 0,
      isAvailable: initialHealthStatus,
    });
  }, [initialHealthStatus]);

  // Memoised so consumers re-render only when the state or an action changes,
  // not on every render of the provider's parent.
  const value: CopilotKitHealthContextType = React.useMemo(
    () => ({
      ...state,
      checkHealth,
      markUnhealthy,
      markHealthy,
      resetHealth,
    }),
    [state, checkHealth, markUnhealthy, markHealthy, resetHealth],
  );

  return (
    <CopilotKitHealthContext.Provider value={value}>
      {children}
    </CopilotKitHealthContext.Provider>
  );
};

View File

@@ -1,6 +1,7 @@
import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback, useRef } from 'react'; import React, { createContext, useContext, useState, useEffect, ReactNode, useCallback, useRef } from 'react';
import { apiClient, setGlobalSubscriptionErrorHandler } from '../api/client'; import { apiClient, setGlobalSubscriptionErrorHandler } from '../api/client';
import SubscriptionExpiredModal from '../components/SubscriptionExpiredModal'; import SubscriptionExpiredModal from '../components/SubscriptionExpiredModal';
import { saveNavigationState, getCurrentPhaseForTool } from '../utils/navigationState';
export interface SubscriptionLimits { export interface SubscriptionLimits {
gemini_calls: number; gemini_calls: number;
@@ -221,11 +222,29 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
}, []); }, []);
const handleRenewSubscription = useCallback(() => { const handleRenewSubscription = useCallback(() => {
// Save current location so we can return after renewal // Save current location and phase so we can return after renewal
const currentPath = window.location.pathname; const currentPath = window.location.pathname;
sessionStorage.setItem('subscription_referrer', currentPath);
console.log('SubscriptionContext: Navigating to pricing page, saved referrer:', currentPath); // Detect tool from path
let tool: string | undefined;
if (currentPath.includes('/blog-writer') || currentPath.includes('/blogwriter')) {
tool = 'blog-writer';
}
// Get current phase for the tool if applicable
let phase: string | null = null;
if (tool) {
phase = getCurrentPhaseForTool(tool);
}
// Save navigation state (path, phase, tool)
saveNavigationState(currentPath, phase || undefined, tool);
console.log('SubscriptionContext: Navigating to pricing page, saved navigation state:', {
path: currentPath,
phase,
tool
});
window.location.href = '/pricing'; window.location.href = '/pricing';
}, []); }, []);
@@ -258,13 +277,30 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
errorData = errorData[0] || {}; errorData = errorData[0] || {};
} }
// Check for usage_info in various possible locations // CRITICAL: FastAPI wraps HTTPException detail in a 'detail' field
// If errorData has a 'detail' field, extract it (this is the actual error data)
if (errorData.detail && typeof errorData.detail === 'object') {
console.log('SubscriptionContext: Found FastAPI detail wrapper, extracting detail field');
errorData = errorData.detail;
}
// Check for usage_info in various possible locations (now that we've unwrapped FastAPI detail)
const usageInfo = errorData.usage_info || const usageInfo = errorData.usage_info ||
(errorData.current_calls !== undefined ? errorData : null) || (errorData.current_calls !== undefined ? errorData : null) ||
(errorData.requested_tokens !== undefined ? errorData : null) ||
(errorData.current_tokens !== undefined ? errorData : null) ||
null; null;
// Usage limit error: 429 status with usage info OR 429 status without explicit expiration // Usage limit error: 429 status with usage info OR provider OR message indicating token/call limits
const isUsageLimitError = status === 429 && (usageInfo || errorData.provider || errorData.message); const hasUsageIndicators = usageInfo ||
errorData.provider ||
errorData.message?.includes('limit') ||
errorData.error?.includes('limit') ||
errorData.requested_tokens !== undefined ||
errorData.current_tokens !== undefined ||
errorData.current_calls !== undefined;
const isUsageLimitError = status === 429 && hasUsageIndicators;
const isSubscriptionExpired = status === 402 || (status === 429 && !isUsageLimitError); const isSubscriptionExpired = status === 402 || (status === 429 && !isUsageLimitError);
console.log('SubscriptionContext: Error analysis', { console.log('SubscriptionContext: Error analysis', {
@@ -280,16 +316,30 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
// For usage limit errors (429 with usage_info), always show modal - even for active subscriptions // For usage limit errors (429 with usage_info), always show modal - even for active subscriptions
// Ignore grace window and cooldown for usage limit errors (user needs to know immediately) // Ignore grace window and cooldown for usage limit errors (user needs to know immediately)
if (isUsageLimitError) { if (isUsageLimitError) {
// Build usage_info from various possible locations
const finalUsageInfo = usageInfo ||
(errorData.requested_tokens !== undefined ? {
provider: errorData.provider,
current_tokens: errorData.current_tokens,
requested_tokens: errorData.requested_tokens,
limit: errorData.limit,
type: 'tokens',
...errorData
} : null) ||
errorData;
const modalData = { const modalData = {
provider: errorData.provider || usageInfo?.provider || 'unknown', provider: errorData.provider || usageInfo?.provider || 'unknown',
usage_info: usageInfo || errorData, usage_info: finalUsageInfo || errorData,
message: errorData.message || errorData.error || 'You have reached your usage limit.' message: errorData.message || errorData.error || 'You have reached your usage limit.'
}; };
console.log('SubscriptionContext: Usage limit exceeded, showing modal (ignoring grace window/cooldown)', { console.log('SubscriptionContext: Usage limit exceeded, showing modal (ignoring grace window/cooldown)', {
modalData, modalData,
errorData: Object.keys(errorData), errorData: Object.keys(errorData),
usageInfo: usageInfo ? Object.keys(usageInfo) : null usageInfo: usageInfo ? Object.keys(usageInfo) : null,
currentShowModal: showModal,
currentModalErrorData: modalErrorData
}); });
// Set flag to mark this as a usage limit modal (should never be auto-closed) // Set flag to mark this as a usage limit modal (should never be auto-closed)
@@ -298,7 +348,17 @@ export const SubscriptionProvider: React.FC<SubscriptionProviderProps> = ({ chil
setShowModal(true); setShowModal(true);
setLastModalShowTime(now); setLastModalShowTime(now);
console.log('SubscriptionContext: Modal state updated - showModal should be true, isUsageLimitModal = true'); console.log('SubscriptionContext: Modal state updated - showModal should be true, isUsageLimitModal = true', {
showModal: true,
isUsageLimitModal: true,
modalErrorData: modalData
});
// Force a re-render check
setTimeout(() => {
console.log('SubscriptionContext: State check after timeout - showModal:', showModal, 'modalErrorData:', modalErrorData);
}, 100);
return true; return true;
} }

View File

@@ -116,10 +116,10 @@ export const useBlogWriterState = () => {
} }
} }
// Save to localStorage for persistence // Save to localStorage for persistence (using shared cache utility)
try { try {
localStorage.setItem('blog_outline', JSON.stringify(result.outline)); const { blogWriterCache } = require('../services/blogWriterCache');
localStorage.setItem('blog_title_options', JSON.stringify(result.title_options || [])); blogWriterCache.cacheOutline(result.outline, result.title_options);
localStorage.setItem('blog_selected_title', result.title_options?.[0] || ''); localStorage.setItem('blog_selected_title', result.title_options?.[0] || '');
console.log('Saved outline data to localStorage'); console.log('Saved outline data to localStorage');
} catch (error) { } catch (error) {

View File

@@ -0,0 +1,161 @@
import { useEffect, useRef, useCallback } from 'react';
import { useCopilotKitHealthContext } from '../contexts/CopilotKitHealthContext';
/**
 * Options accepted by `useCopilotKitHealth`. Every field is optional; the hook
 * substitutes the documented default when a field is omitted.
 */
interface UseCopilotKitHealthOptions {
/**
 * Delay before the first health check after the hook is enabled (milliseconds).
 * @default 1000
 */
initialDelay?: number;
/**
 * Delay before the next health check while the last known status is healthy (milliseconds).
 * @default 60000 (1 minute)
 */
healthyInterval?: number;
/**
 * Retry delays used while unhealthy (milliseconds). The hook indexes this list
 * by the current retry count and keeps using the last entry once the count
 * runs past the end of the list.
 * @default [5000, 10000, 30000, 60000]
 */
unhealthyIntervals?: number[];
/**
 * Enable automatic health checking. When false the hook schedules no checks
 * and registers no window listeners.
 * @default true
 */
enabled?: boolean;
}
/**
 * Hook to monitor CopilotKit health status with automatic polling.
 *
 * Behaviour:
 * - Runs one health check `initialDelay` ms after the hook becomes enabled.
 * - After every completed check, schedules the next one: `healthyInterval` ms
 *   when healthy, otherwise the `unhealthyIntervals` entry selected by the
 *   current retry count (clamped to the last entry).
 * - Listens for window `error` / `unhandledrejection` events that mention
 *   CopilotKit (or certificate/CORS failures for rejections) and marks the
 *   service unhealthy when one is seen.
 *
 * All timers and both window listeners are released when the hook is disabled
 * or the component unmounts.
 *
 * @param options - Polling configuration; see `UseCopilotKitHealthOptions`.
 * @returns The health fields and actions exposed by the CopilotKit health context.
 */
export const useCopilotKitHealth = (options: UseCopilotKitHealthOptions = {}) => {
  const {
    initialDelay = 1000,
    healthyInterval = 60000, // 1 minute
    unhealthyIntervals = [5000, 10000, 30000, 60000], // 5s, 10s, 30s, 60s
    enabled = true,
  } = options;

  const {
    isHealthy,
    isChecking,
    lastChecked,
    errorMessage,
    retryCount,
    isAvailable,
    checkHealth,
    markUnhealthy,
  } = useCopilotKitHealthContext();

  // Handle of the currently scheduled follow-up check, if any.
  const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);

  const cancelScheduledCheck = useCallback(() => {
    if (timeoutRef.current !== null) {
      clearTimeout(timeoutRef.current);
      timeoutRef.current = null;
    }
  }, []);

  // Resolve the next delay to a plain number during render. The scheduling
  // effect depends on this number rather than on the `unhealthyIntervals`
  // array, whose identity changes on every render when the default (or an
  // inline literal) is used and would otherwise restart the timer each render.
  const backoffIndex = Math.min(retryCount, unhealthyIntervals.length - 1);
  const nextDelay = isHealthy
    ? healthyInterval
    // An empty list yields `undefined`; fall back instead of polling with a 0 ms delay.
    : (unhealthyIntervals[backoffIndex] ?? healthyInterval);

  // Initial health check shortly after mount / enable.
  useEffect(() => {
    if (!enabled) return undefined;

    const initialTimeout = setTimeout(() => {
      checkHealth();
    }, initialDelay);

    return () => {
      clearTimeout(initialTimeout);
    };
  }, [enabled, initialDelay, checkHealth]);

  // Schedule the next check whenever a check finishes or the status changes.
  // Nothing is scheduled while a check is in flight.
  useEffect(() => {
    if (!enabled || isChecking) return undefined;

    timeoutRef.current = setTimeout(() => {
      timeoutRef.current = null;
      checkHealth();
    }, nextDelay);

    return cancelScheduledCheck;
  }, [enabled, isChecking, isHealthy, retryCount, nextDelay, checkHealth, cancelScheduledCheck]);

  // Also handle CopilotKit runtime errors surfaced through window events.
  useEffect(() => {
    if (!enabled) return undefined;

    const mentionsCopilotKit = (text: string | undefined): boolean =>
      !!text && (text.includes('copilotkit') || text.includes('CopilotKit'));

    const handleError = (event: ErrorEvent) => {
      if (mentionsCopilotKit(event.message) || event.filename?.includes('copilotkit')) {
        markUnhealthy(`Runtime error: ${event.message}`);
      }
    };

    const handleRejection = (event: PromiseRejectionEvent) => {
      // Rejection reasons are usually Error objects; plain strings are accepted too.
      const reason: unknown = event.reason;
      const text =
        typeof reason === 'string'
          ? reason
          : reason instanceof Error
            ? reason.message
            : undefined;
      if (text === undefined) return;

      if (
        mentionsCopilotKit(text) ||
        text.includes('ERR_CERT_COMMON_NAME_INVALID') ||
        text.includes('CORS')
      ) {
        markUnhealthy(`Unhandled promise rejection: ${text}`);
      }
    };

    window.addEventListener('error', handleError);
    window.addEventListener('unhandledrejection', handleRejection);

    return () => {
      // Both listeners are named so that both can be removed; an anonymous
      // listener would accumulate on every effect re-run.
      window.removeEventListener('error', handleError);
      window.removeEventListener('unhandledrejection', handleRejection);
    };
  }, [enabled, markUnhealthy]);

  return {
    isHealthy,
    isAvailable,
    isChecking,
    lastChecked,
    errorMessage,
    retryCount,
    checkHealth,
    markUnhealthy,
  };
};

View File

@@ -170,8 +170,8 @@ export const usePhaseNavigation = (
// User is NOT in SEO phase - can progress to publish // User is NOT in SEO phase - can progress to publish
// This handles cases where user navigates away and comes back // This handles cases where user navigates away and comes back
// Only auto-progress if user is already in a different phase (not actively in SEO) // Only auto-progress if user is already in a different phase (not actively in SEO)
if (currentPhase !== 'publish') { if (currentPhase !== 'publish') {
setCurrentPhase('publish'); setCurrentPhase('publish');
} }
} }
} }

View File

@@ -153,7 +153,7 @@ export function usePolling(
attemptsRef.current++; attemptsRef.current++;
} catch (err) { } catch (err) {
const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred'; const errorMessage = err instanceof Error ? err.message : 'Unknown error occurred';
console.error('Polling error:', errorMessage); console.error('Polling error:', errorMessage, err);
// Check if this is an axios error with subscription limit status // Check if this is an axios error with subscription limit status
// This is a fallback in case the interceptor doesn't catch it // This is a fallback in case the interceptor doesn't catch it
@@ -161,15 +161,17 @@ export function usePolling(
if (axiosError?.response?.status === 429 || axiosError?.response?.status === 402) { if (axiosError?.response?.status === 429 || axiosError?.response?.status === 402) {
console.log('usePolling: Detected subscription error in axios error response', { console.log('usePolling: Detected subscription error in axios error response', {
status: axiosError.response.status, status: axiosError.response.status,
data: axiosError.response.data data: axiosError.response.data,
errorDataKeys: axiosError.response.data ? Object.keys(axiosError.response.data) : null
}); });
// Trigger subscription error handler (modal will show) // Trigger subscription error handler (modal will show)
// Note: The interceptor may have already called this, but we call it again to be safe
const handled = triggerSubscriptionError(axiosError); const handled = triggerSubscriptionError(axiosError);
console.log('usePolling: triggerSubscriptionError returned', handled); console.log('usePolling: triggerSubscriptionError returned', handled);
if (handled) { if (handled) {
console.log('usePolling: Subscription error handled, stopping polling'); console.log('usePolling: Subscription error handled, stopping polling - modal should be visible');
const errorMsg = axiosError.response?.data?.message || const errorMsg = axiosError.response?.data?.message ||
axiosError.response?.data?.error || axiosError.response?.data?.error ||
'Subscription limit exceeded'; 'Subscription limit exceeded';

View File

@@ -0,0 +1,239 @@
import React, { useState } from 'react';
import { ResearchWizard } from '../components/Research';
import { BlogResearchResponse } from '../services/blogWriterApi';
/** Builds one quick-fill preset: a display name, a comma-separated keyword string, and an industry. */
const makePreset = (name: string, keywords: string, industry: string) => ({
  name,
  keywords,
  industry,
});

/** Presets offered in the left-hand panel of the research test page. */
const samplePresets = [
  makePreset(
    'AI Marketing Tools',
    'AI in marketing, automation tools, customer engagement',
    'Technology'
  ),
  makePreset(
    'Small Business SEO',
    'local SEO, small business, Google My Business',
    'Marketing'
  ),
  makePreset(
    'Content Strategy',
    'content planning, editorial calendar, content creation',
    'Marketing'
  ),
];
/** Card chrome shared by the presets panel and the debug panel. */
const panelCardStyle: React.CSSProperties = {
  backgroundColor: 'white',
  borderRadius: '8px',
  padding: '20px',
  boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
};

/** Idle and hover colours of a preset button. */
const presetButtonColors = {
  idle: { background: '#f0f7ff', border: '#b3d9ff' },
  hover: { background: '#e3f2fd', border: '#90caf9' },
};

interface StatCardProps {
  label: string;
  value: number;
  background: string;
  borderColor: string;
  accent: string;
}

/** One coloured tile in the footer statistics grid. */
const StatCard: React.FC<StatCardProps> = ({ label, value, background, borderColor, accent }) => (
  <div
    style={{
      backgroundColor: background,
      padding: '16px',
      borderRadius: '8px',
      border: `1px solid ${borderColor}`,
    }}
  >
    <div style={{ fontSize: '12px', color: accent, fontWeight: '600', marginBottom: '4px' }}>
      {label}
    </div>
    <div style={{ fontSize: '28px', fontWeight: 'bold', color: accent }}>
      {value}
    </div>
  </div>
);

/**
 * Manual test page for the research wizard.
 *
 * Left column: preset buttons that seed the wizard with keywords/industry, a
 * reset button, and a toggleable panel showing the raw result as JSON.
 * Right column: the `ResearchWizard` itself.
 * Footer (only once a result exists): counts of sources, suggested angles and
 * search queries.
 */
export const ResearchTest: React.FC = () => {
  const [results, setResults] = useState<BlogResearchResponse | null>(null);
  const [showDebug, setShowDebug] = useState(false);
  const [presetKeywords, setPresetKeywords] = useState<string[] | undefined>();
  const [presetIndustry, setPresetIndustry] = useState<string | undefined>();

  const handleComplete = (researchResults: BlogResearchResponse) => {
    setResults(researchResults);
  };

  // Presets store keywords as one comma-separated string; the wizard takes a list.
  const handlePresetClick = (preset: (typeof samplePresets)[number]) => {
    setPresetKeywords(preset.keywords.split(',').map(k => k.trim()));
    setPresetIndustry(preset.industry);
    setResults(null);
  };

  const handleReset = () => {
    setPresetKeywords(undefined);
    setPresetIndustry(undefined);
    setResults(null);
  };

  return (
    <div style={{ minHeight: '100vh', backgroundColor: '#f5f5f5' }}>
      {/* Header */}
      <div
        style={{
          backgroundColor: '#1976d2',
          color: 'white',
          padding: '20px',
          marginBottom: '20px',
        }}
      >
        <div style={{ maxWidth: '1400px', margin: '0 auto' }}>
          <h1 style={{ margin: 0, fontSize: '28px' }}>🔬 Research Component Test Page</h1>
          <p style={{ margin: '8px 0 0 0', fontSize: '14px', opacity: 0.9 }}>
            Test the modular research wizard component
          </p>
        </div>
      </div>

      <div style={{ maxWidth: '1400px', margin: '0 auto', padding: '0 20px', display: 'flex', gap: '20px', flexWrap: 'wrap' }}>
        {/* Left Panel - Controls */}
        <div style={{ flex: '1 1 300px', minWidth: '300px' }}>
          <div style={{ ...panelCardStyle, marginBottom: '20px' }}>
            <h3 style={{ margin: '0 0 16px 0', color: '#333', fontSize: '18px' }}>
              🎯 Quick Presets
            </h3>
            <div style={{ display: 'flex', flexDirection: 'column', gap: '8px' }}>
              {samplePresets.map((preset) => (
                <button
                  // Preset names are unique, so they make a stable key (unlike the array index).
                  key={preset.name}
                  onClick={() => handlePresetClick(preset)}
                  style={{
                    padding: '12px',
                    backgroundColor: presetButtonColors.idle.background,
                    border: `1px solid ${presetButtonColors.idle.border}`,
                    borderRadius: '6px',
                    cursor: 'pointer',
                    textAlign: 'left',
                    fontSize: '14px',
                    transition: 'all 0.2s ease',
                  }}
                  onMouseEnter={(e) => {
                    e.currentTarget.style.backgroundColor = presetButtonColors.hover.background;
                    e.currentTarget.style.borderColor = presetButtonColors.hover.border;
                  }}
                  onMouseLeave={(e) => {
                    e.currentTarget.style.backgroundColor = presetButtonColors.idle.background;
                    e.currentTarget.style.borderColor = presetButtonColors.idle.border;
                  }}
                >
                  <div style={{ fontWeight: '600', color: '#1976d2', marginBottom: '4px' }}>
                    {preset.name}
                  </div>
                  <div style={{ fontSize: '12px', color: '#666' }}>
                    {preset.keywords}
                  </div>
                </button>
              ))}
            </div>
            <button
              onClick={handleReset}
              style={{
                marginTop: '12px',
                padding: '8px 16px',
                backgroundColor: '#f5f5f5',
                border: '1px solid #ddd',
                borderRadius: '6px',
                cursor: 'pointer',
                fontSize: '13px',
                width: '100%',
              }}
            >
              Reset Test
            </button>
          </div>

          {/* Debug Panel */}
          <div style={panelCardStyle}>
            <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '12px' }}>
              <h3 style={{ margin: 0, color: '#333', fontSize: '18px' }}>
                🐛 Debug Panel
              </h3>
              <label style={{ cursor: 'pointer', fontSize: '14px' }}>
                <input
                  type="checkbox"
                  checked={showDebug}
                  onChange={(e) => setShowDebug(e.target.checked)}
                  style={{ marginRight: '6px' }}
                />
                Show Debug
              </label>
            </div>
            {showDebug && (
              <div
                style={{
                  backgroundColor: '#f5f5f5',
                  borderRadius: '4px',
                  padding: '12px',
                  fontSize: '12px',
                  fontFamily: 'monospace',
                  maxHeight: '400px',
                  overflow: 'auto',
                }}
              >
                <pre style={{ margin: 0, whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>
                  {JSON.stringify(results, null, 2)}
                </pre>
              </div>
            )}
          </div>
        </div>

        {/* Main Content - Wizard */}
        <div style={{ flex: '2 1 800px' }}>
          <ResearchWizard
            initialKeywords={presetKeywords}
            initialIndustry={presetIndustry}
            onComplete={handleComplete}
          />
        </div>
      </div>

      {/* Footer Stats */}
      {results && (
        <div
          style={{
            backgroundColor: 'white',
            borderTop: '2px solid #e0e0e0',
            padding: '20px',
            marginTop: '40px',
          }}
        >
          <div style={{ maxWidth: '1400px', margin: '0 auto' }}>
            <h3 style={{ margin: '0 0 16px 0', color: '#333', fontSize: '18px' }}>
              📊 Research Statistics
            </h3>
            <div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fit, minmax(200px, 1fr))', gap: '16px' }}>
              {/* Every count tolerates a missing array so a partial response cannot crash the page. */}
              <StatCard
                label="Sources Found"
                value={results.sources?.length ?? 0}
                background="#e3f2fd"
                borderColor="#90caf9"
                accent="#1976d2"
              />
              <StatCard
                label="Content Angles"
                value={results.suggested_angles?.length ?? 0}
                background="#f3e5f5"
                borderColor="#ce93d8"
                accent="#7b1fa2"
              />
              <StatCard
                label="Search Queries"
                value={results.search_queries?.length ?? 0}
                background="#e8f5e8"
                borderColor="#81c784"
                accent="#2e7d32"
              />
            </div>
          </div>
        </div>
      )}
    </div>
  );
};
export default ResearchTest;

View File

@@ -17,6 +17,23 @@ export interface ResearchSource {
source_type?: string; source_type?: string;
} }
/** Allowed values for `ResearchConfig.mode` and `BlogResearchRequest.research_mode`. */
export type ResearchMode = 'basic' | 'comprehensive' | 'targeted';
/** Allowed values for `ResearchConfig.provider`. */
export type ResearchProvider = 'google' | 'exa';
/** Allowed entries of `ResearchConfig.source_types`. */
export type SourceType = 'web' | 'academic' | 'news' | 'industry' | 'expert';
/** Allowed values for `ResearchConfig.date_range`. */
export type DateRange = 'last_week' | 'last_month' | 'last_3_months' | 'last_6_months' | 'last_year' | 'all_time';
/**
 * Optional research settings carried on `BlogResearchRequest.config`.
 * Every field may be omitted.
 * NOTE(review): how omitted fields are defaulted is not visible in this file —
 * confirm against the backend request model.
 */
export interface ResearchConfig {
mode?: ResearchMode;
provider?: ResearchProvider;
date_range?: DateRange;
source_types?: SourceType[];
// presumably an upper bound on the number of sources returned — TODO confirm
max_sources?: number;
include_statistics?: boolean;
include_expert_quotes?: boolean;
include_competitors?: boolean;
include_trends?: boolean;
}
export interface BlogResearchRequest { export interface BlogResearchRequest {
keywords: string[]; keywords: string[];
topic?: string; topic?: string;
@@ -25,6 +42,8 @@ export interface BlogResearchRequest {
tone?: string; tone?: string;
word_count_target?: number; word_count_target?: number;
persona?: PersonaInfo; persona?: PersonaInfo;
research_mode?: ResearchMode;
config?: ResearchConfig;
} }
export interface GroundingChunk { export interface GroundingChunk {

View File

@@ -0,0 +1,158 @@
/**
* Blog Writer Cache Service
*
* Provides persistent caching for outline and content to survive page refreshes
* and avoid unnecessary API calls. Shared by both CopilotKit and manual flows.
*/
/**
 * Wrapper shape for a cached outline.
 * NOTE(review): not referenced by `BlogWriterCacheService` in this file, which
 * stores the bare outline array and title options under separate keys.
 */
interface CachedOutlineEntry {
outline: any[];
title_options?: string[];
research_keywords: string[];
created_at: string;
}
/**
 * Wrapper shape for cached section content.
 * NOTE(review): not referenced by `BlogWriterCacheService` in this file, which
 * stores the bare `sections` map.
 */
interface CachedContentEntry {
sections: Record<string, string>;
outline_ids: string[];
research_keywords: string[];
created_at: string;
}
/**
 * localStorage-backed cache for the blog writer's outline and section content.
 *
 * Storage layout:
 * - `blog_outline`: JSON array of outline sections
 * - `blog_title_options`: JSON array of title options (absent when none were cached)
 * - `blog_content_<sorted ids joined by "|">`: JSON map of section id -> content
 *
 * Every method is a no-op (or returns null) when `window` is undefined, and
 * storage/parse failures are logged and swallowed so caching stays best-effort.
 */
class BlogWriterCacheService {
  private readonly OUTLINE_CACHE_KEY = 'blog_outline';
  private readonly TITLE_OPTIONS_CACHE_KEY = 'blog_title_options';
  private readonly CONTENT_CACHE_PREFIX = 'blog_content_';

  /**
   * Get the cached outline, if any.
   *
   * @param researchKeywords - Currently NOT used for matching: any stored
   *   non-empty outline is returned regardless of the keywords passed in.
   * @returns The stored outline plus its title options (undefined when none
   *   are stored), or null when nothing usable is cached.
   */
  getCachedOutline(researchKeywords: string[]): { outline: any[]; title_options?: string[] } | null {
    try {
      if (typeof window === 'undefined') return null;

      const savedOutline = localStorage.getItem(this.OUTLINE_CACHE_KEY);
      if (!savedOutline) {
        return null;
      }

      const parsedOutline: unknown = JSON.parse(savedOutline);
      if (!Array.isArray(parsedOutline) || parsedOutline.length === 0) {
        return null;
      }

      // Title options are optional; anything that is not an array is ignored.
      const savedTitleOptions = localStorage.getItem(this.TITLE_OPTIONS_CACHE_KEY);
      const parsedTitleOptions: unknown = savedTitleOptions ? JSON.parse(savedTitleOptions) : undefined;
      const titleOptions = Array.isArray(parsedTitleOptions) ? parsedTitleOptions : undefined;

      console.log(`Cache hit for outline (${parsedOutline.length} sections)`);
      return {
        outline: parsedOutline,
        title_options: titleOptions
      };
    } catch (error) {
      console.error('Error retrieving cached outline:', error);
      return null;
    }
  }

  /**
   * Cache an outline and its title options.
   *
   * When `titleOptions` is omitted, any previously stored title options are
   * removed so a later read cannot pair this outline with titles that were
   * generated for an earlier one.
   */
  cacheOutline(outline: any[], titleOptions?: string[]): void {
    try {
      if (typeof window === 'undefined') return;

      localStorage.setItem(this.OUTLINE_CACHE_KEY, JSON.stringify(outline));
      if (titleOptions) {
        localStorage.setItem(this.TITLE_OPTIONS_CACHE_KEY, JSON.stringify(titleOptions));
      } else {
        localStorage.removeItem(this.TITLE_OPTIONS_CACHE_KEY);
      }
      console.log(`Cached outline (${outline.length} sections)`);
    } catch (error) {
      console.error('Error caching outline:', error);
    }
  }

  /**
   * Build the storage key for a set of outline section IDs. IDs are sorted
   * first so the key does not depend on the order they are passed in.
   */
  private generateContentCacheKey(outlineIds: string[]): string {
    const sortedIds = [...outlineIds].sort().join('|');
    return `${this.CONTENT_CACHE_PREFIX}${sortedIds}`;
  }

  /**
   * True when the keys of `sections` are exactly the given outline IDs
   * (compared as strings, duplicates in `outlineIds` collapsed).
   */
  private sectionsMatchOutline(sections: Record<string, unknown>, outlineIds: string[]): boolean {
    const sectionIds = Object.keys(sections);
    const expectedIds = new Set(outlineIds.map(id => String(id)));
    return expectedIds.size === sectionIds.length &&
      sectionIds.every(id => expectedIds.has(id));
  }

  /**
   * Get cached content for the given outline sections.
   *
   * @returns The stored id -> content map, or null when nothing is cached or
   *   the stored sections do not correspond exactly to `outlineIds`.
   */
  getCachedContent(outlineIds: string[]): Record<string, string> | null {
    try {
      if (typeof window === 'undefined') return null;

      const cachedContent = localStorage.getItem(this.generateContentCacheKey(outlineIds));
      if (!cachedContent) {
        return null;
      }

      const parsedSections: unknown = JSON.parse(cachedContent);
      if (
        !parsedSections ||
        typeof parsedSections !== 'object' ||
        Array.isArray(parsedSections) ||
        Object.keys(parsedSections).length === 0
      ) {
        return null;
      }

      const sections = parsedSections as Record<string, string>;
      if (!this.sectionsMatchOutline(sections, outlineIds)) {
        console.log('Cached content does not match outline structure');
        return null;
      }

      console.log(`Cache hit for content (${Object.keys(sections).length} sections)`);
      return sections;
    } catch (error) {
      console.error('Error retrieving cached content:', error);
      return null;
    }
  }

  /**
   * Cache content sections under a key derived from `outlineIds`.
   * Empty or missing `sections` are not stored.
   */
  cacheContent(sections: Record<string, string>, outlineIds: string[]): void {
    try {
      if (typeof window === 'undefined') return;
      if (!sections || Object.keys(sections).length === 0) return;

      localStorage.setItem(this.generateContentCacheKey(outlineIds), JSON.stringify(sections));
      console.log(`Cached content (${Object.keys(sections).length} sections)`);
    } catch (error) {
      console.error('Error caching content:', error);
    }
  }

  /**
   * Check whether in-memory `sections` already cover exactly the given outline
   * IDs (helper for the manual flow). Does not touch storage.
   */
  contentExistsInState(sections: Record<string, string>, outlineIds: string[]): boolean {
    if (!sections || Object.keys(sections).length === 0) {
      return false;
    }
    return this.sectionsMatchOutline(sections, outlineIds);
  }
}
// Single module-level instance, created at import time. The same object is
// exposed as both the named export and the default export.
export const blogWriterCache = new BlogWriterCacheService();
export default blogWriterCache;

View File

@@ -0,0 +1,161 @@
/**
* Global Navigation State Utility
*
* Manages navigation state preservation across subscription renewals and redirects.
* Supports:
* - Page path preservation
* - Phase state (for tools with phases like Blog Writer)
* - Tool-specific context (extensible for future tools)
*/
/**
 * Snapshot of where the user was, serialized to sessionStorage by
 * `saveNavigationState` and read back by `restoreNavigationState` /
 * `peekNavigationState`.
 */
export interface NavigationState {
path: string;
phase?: string; // Phase ID for tools with phases (e.g., 'research', 'outline', 'content')
tool?: string; // Tool identifier (e.g., 'blog-writer', 'other-tool'); auto-detected from `path` on save when not given
context?: Record<string, any>; // Tool-specific context data
timestamp: number; // `Date.now()` value captured when this state was saved
}
// sessionStorage key holding the single saved NavigationState (JSON-encoded).
const NAVIGATION_STATE_KEY = 'subscription_navigation_state';
/**
 * Persist the user's current location to sessionStorage so it can be restored
 * after a detour through the pricing/subscription pages.
 *
 * Failures (e.g. storage unavailable) are logged and swallowed.
 *
 * @param path - Current page path (e.g., '/blog-writer')
 * @param phase - Current phase ID (optional, for tools with phases)
 * @param tool - Tool identifier (optional; derived from `path` when falsy)
 * @param context - Additional tool-specific context (optional)
 */
export const saveNavigationState = (
  path: string,
  phase?: string,
  tool?: string,
  context?: Record<string, any>
): void => {
  try {
    const snapshot: NavigationState = {
      path,
      phase,
      // Fall back to path-based detection when the caller gave no tool.
      tool: tool || detectToolFromPath(path),
      context,
      timestamp: Date.now()
    };

    const serialized = JSON.stringify(snapshot);
    sessionStorage.setItem(NAVIGATION_STATE_KEY, serialized);
    console.log('[NavigationState] Saved navigation state:', snapshot);
  } catch (saveError) {
    console.error('[NavigationState] Failed to save navigation state:', saveError);
  }
};
/**
 * Restore navigation state after returning from pricing/subscription pages.
 *
 * The stored entry is one-time use: it is removed from sessionStorage as soon
 * as it has been read, whether or not it turns out to be valid, so a malformed
 * entry cannot linger and be re-reported on every call.
 *
 * @returns The saved NavigationState, or null when nothing is stored, the
 *   stored value is malformed (unparseable, or missing `path`/`timestamp`), or
 *   sessionStorage is unavailable.
 */
export const restoreNavigationState = (): NavigationState | null => {
  try {
    const stored = sessionStorage.getItem(NAVIGATION_STATE_KEY);
    if (!stored) {
      return null;
    }

    // Clear before parsing/validating. Doing it here (rather than in the catch
    // block) also means the error path never touches sessionStorage again, so
    // it cannot throw out of this function.
    sessionStorage.removeItem(NAVIGATION_STATE_KEY);

    const state: NavigationState | null = JSON.parse(stored);

    // Validate state (must be an object with a path and a timestamp)
    if (!state || !state.path || !state.timestamp) {
      console.warn('[NavigationState] Invalid navigation state:', state);
      return null;
    }

    console.log('[NavigationState] Restored navigation state:', state);
    return state;
  } catch (error) {
    console.error('[NavigationState] Failed to restore navigation state:', error);
    return null;
  }
};
/**
 * Read the saved navigation state while leaving it in sessionStorage.
 *
 * @returns The parsed stored value, or null when nothing is stored or reading/parsing fails.
 */
export const peekNavigationState = (): NavigationState | null => {
  try {
    const raw = sessionStorage.getItem(NAVIGATION_STATE_KEY);
    return raw ? JSON.parse(raw) : null;
  } catch (peekError) {
    console.error('[NavigationState] Failed to peek navigation state:', peekError);
    return null;
  }
};
/**
 * Drop any saved navigation state from sessionStorage.
 * Failures are logged and swallowed.
 */
export const clearNavigationState = (): void => {
  try {
    sessionStorage.removeItem(NAVIGATION_STATE_KEY);
    console.log('[NavigationState] Cleared navigation state');
  } catch (clearError) {
    console.error('[NavigationState] Failed to clear navigation state:', clearError);
  }
};
/**
 * Map a page path to a tool identifier.
 *
 * @returns 'blog-writer' when the path contains '/blog-writer' or
 *   '/blogwriter'; undefined for every other path.
 */
const detectToolFromPath = (path: string): string | undefined => {
  const blogWriterMarkers = ['/blog-writer', '/blogwriter'];
  const isBlogWriter = blogWriterMarkers.some(marker => path.includes(marker));

  // Add more tool detection logic as needed
  // if (path.includes('/other-tool')) {
  //   return 'other-tool';
  // }

  return isBlogWriter ? 'blog-writer' : undefined;
};
/**
 * Look up the phase a tool last stored in localStorage.
 *
 * Only 'blog-writer' is supported (key `blogwriter_current_phase`); every
 * other tool yields null.
 *
 * @returns The stored phase, or null when the tool is unsupported, nothing
 *   (or an empty string) is stored, or storage access fails.
 */
export const getCurrentPhaseForTool = (tool: string): string | null => {
  try {
    // Add more tool-specific phase retrieval as needed
    if (tool !== 'blog-writer') {
      return null;
    }
    const storedPhase = localStorage.getItem('blogwriter_current_phase');
    return storedPhase || null;
  } catch (readError) {
    console.error(`[NavigationState] Failed to get phase for tool ${tool}:`, readError);
    return null;
  }
};
/**
 * Record a tool's current phase in localStorage.
 *
 * Only 'blog-writer' is supported (key `blogwriter_current_phase`); calls for
 * any other tool do nothing. Storage failures are logged and swallowed.
 */
export const saveCurrentPhaseForTool = (tool: string, phase: string): void => {
  try {
    // Add more tool-specific phase saving as needed
    if (tool !== 'blog-writer') {
      return;
    }
    localStorage.setItem('blogwriter_current_phase', phase);
    console.log(`[NavigationState] Saved phase '${phase}' for ${tool}`);
  } catch (writeError) {
    console.error(`[NavigationState] Failed to save phase for tool ${tool}:`, writeError);
  }
};