Research component integration, Copilotkit implementation, SEO copilotkit implementation, Wix SEO metadata complete, Wix SEO metadata review
This commit is contained in:
@@ -8,6 +8,7 @@ import time
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -25,8 +26,20 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call."""
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
req: Medium blog generation request
|
||||
task_id: Task ID for progress updates
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
@@ -156,7 +169,7 @@ class MediumBlogGenerator:
|
||||
- Use language that resonates with {audience}
|
||||
- Maintain consistent voice that reflects this persona's expertise
|
||||
"""
|
||||
|
||||
|
||||
prompt = (
|
||||
f"Write blog content for the following sections. Each section should be {req.globalTargetWords or 1000} words total, distributed across all sections.\n\n"
|
||||
f"Blog Title: {req.title}\n\n"
|
||||
@@ -176,11 +189,20 @@ class MediumBlogGenerator:
|
||||
f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=system,
|
||||
)
|
||||
try:
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=system,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) to preserve error details
|
||||
raise
|
||||
except Exception as llm_error:
|
||||
# Wrap other errors
|
||||
logger.error(f"AI generation failed: {llm_error}")
|
||||
raise Exception(f"AI generation failed: {str(llm_error)}")
|
||||
|
||||
# Check for errors in AI response
|
||||
if not ai_resp or ai_resp.get("error"):
|
||||
|
||||
@@ -105,13 +105,20 @@ class BlogWriterService:
|
||||
return await self.research_service.research_with_progress(request, task_id, user_id)
|
||||
|
||||
# Outline Methods
|
||||
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
|
||||
"""Generate AI-powered outline from research data."""
|
||||
return await self.outline_service.generate_outline(request)
|
||||
async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
|
||||
"""Generate AI-powered outline from research data.
|
||||
|
||||
Args:
|
||||
request: Outline generation request with research data
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
|
||||
return await self.outline_service.generate_outline(request, user_id)
|
||||
|
||||
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
|
||||
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
|
||||
"""Generate outline with real-time progress updates."""
|
||||
return await self.outline_service.generate_outline_with_progress(request, task_id)
|
||||
return await self.outline_service.generate_outline_with_progress(request, task_id, user_id)
|
||||
|
||||
async def refine_outline(self, request: BlogOutlineRefineRequest) -> BlogOutlineResponse:
|
||||
"""Refine outline with HITL operations."""
|
||||
@@ -334,9 +341,17 @@ class BlogWriterService:
|
||||
# TODO: Move to content module
|
||||
return BlogPublishResponse(success=True, platform=request.platform, url="https://example.com/post")
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call."""
|
||||
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id)
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
req: Medium blog generation request
|
||||
task_id: Task ID for progress updates
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for medium blog generation (subscription checks and usage tracking)")
|
||||
return await self.medium_blog_generator.generate_medium_blog_with_progress(req, task_id, user_id)
|
||||
|
||||
async def analyze_flow_basic(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Analyze flow metrics for entire blog using single AI call (cost-effective)."""
|
||||
|
||||
@@ -42,10 +42,20 @@ class OutlineGenerator:
|
||||
self.response_processor = ResponseProcessor()
|
||||
self.parallel_processor = ParallelProcessor(self.source_mapper, self.grounding_engine)
|
||||
|
||||
async def generate(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
|
||||
async def generate(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
|
||||
"""
|
||||
Generate AI-powered outline using research results
|
||||
Generate AI-powered outline using research results.
|
||||
|
||||
Args:
|
||||
request: Outline generation request with research data
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
|
||||
|
||||
# Extract research insights
|
||||
research = request.research
|
||||
primary_keywords = research.keyword_analysis.get('primary', [])
|
||||
@@ -68,15 +78,15 @@ class OutlineGenerator:
|
||||
# Define schema with proper property ordering (critical for Gemini API)
|
||||
outline_schema = self.prompt_builder.get_outline_schema()
|
||||
|
||||
# Generate outline using structured JSON response with retry logic
|
||||
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema)
|
||||
# Generate outline using structured JSON response with retry logic (user_id required)
|
||||
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id)
|
||||
|
||||
# Convert to BlogOutlineSection objects
|
||||
outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
|
||||
|
||||
# Run parallel processing for speed optimization
|
||||
# Run parallel processing for speed optimization (user_id required)
|
||||
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing_async(
|
||||
outline_sections, research
|
||||
outline_sections, research, user_id
|
||||
)
|
||||
|
||||
# Enhance sections with grounding insights
|
||||
@@ -85,9 +95,9 @@ class OutlineGenerator:
|
||||
mapped_sections, research.grounding_metadata, grounding_insights
|
||||
)
|
||||
|
||||
# Optimize outline for better flow, SEO, and engagement
|
||||
# Optimize outline for better flow, SEO, and engagement (user_id required)
|
||||
logger.info("Optimizing outline for better flow and engagement...")
|
||||
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization")
|
||||
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
|
||||
|
||||
# Rebalance word counts for optimal distribution
|
||||
target_words = request.word_count or 1500
|
||||
@@ -118,10 +128,21 @@ class OutlineGenerator:
|
||||
research_coverage=research_coverage
|
||||
)
|
||||
|
||||
async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
|
||||
async def generate_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
|
||||
"""
|
||||
Outline generation method with progress updates for real-time feedback.
|
||||
|
||||
Args:
|
||||
request: Outline generation request with research data
|
||||
task_id: Task ID for progress updates
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
|
||||
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
|
||||
# Extract research insights
|
||||
@@ -150,17 +171,17 @@ class OutlineGenerator:
|
||||
|
||||
await task_manager.update_progress(task_id, "🔄 Making AI request to generate structured outline...")
|
||||
|
||||
# Generate outline using structured JSON response with retry logic
|
||||
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, task_id)
|
||||
# Generate outline using structured JSON response with retry logic (user_id required for subscription checks)
|
||||
outline_data = await self.response_processor.generate_with_retry(outline_prompt, outline_schema, user_id, task_id)
|
||||
|
||||
await task_manager.update_progress(task_id, "📝 Processing outline structure and validating sections...")
|
||||
|
||||
# Convert to BlogOutlineSection objects
|
||||
outline_sections = self.response_processor.convert_to_sections(outline_data, sources)
|
||||
|
||||
# Run parallel processing for speed optimization
|
||||
# Run parallel processing for speed optimization (user_id required for subscription checks)
|
||||
mapped_sections, grounding_insights = await self.parallel_processor.run_parallel_processing(
|
||||
outline_sections, research, task_id
|
||||
outline_sections, research, user_id, task_id
|
||||
)
|
||||
|
||||
# Enhance sections with grounding insights (depends on both previous tasks)
|
||||
@@ -169,9 +190,9 @@ class OutlineGenerator:
|
||||
mapped_sections, research.grounding_metadata, grounding_insights
|
||||
)
|
||||
|
||||
# Optimize outline for better flow, SEO, and engagement
|
||||
# Optimize outline for better flow, SEO, and engagement (user_id required for subscription checks)
|
||||
await task_manager.update_progress(task_id, "🎯 Optimizing outline for better flow and engagement...")
|
||||
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization")
|
||||
optimized_sections = await self.outline_optimizer.optimize(grounding_enhanced_sections, "comprehensive optimization", user_id)
|
||||
|
||||
# Rebalance word counts for optimal distribution
|
||||
await task_manager.update_progress(task_id, "⚖️ Rebalancing word count distribution...")
|
||||
|
||||
@@ -13,8 +13,23 @@ from models.blog_models import BlogOutlineSection
|
||||
class OutlineOptimizer:
|
||||
"""Optimizes outlines for better flow, SEO, and engagement."""
|
||||
|
||||
async def optimize(self, outline: List[BlogOutlineSection], focus: str = "general optimization") -> List[BlogOutlineSection]:
|
||||
"""Optimize entire outline for better flow, SEO, and engagement."""
|
||||
async def optimize(self, outline: List[BlogOutlineSection], focus: str, user_id: str) -> List[BlogOutlineSection]:
|
||||
"""Optimize entire outline for better flow, SEO, and engagement.
|
||||
|
||||
Args:
|
||||
outline: List of outline sections to optimize
|
||||
focus: Optimization focus (e.g., "general optimization")
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
List of optimized outline sections
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline optimization (subscription checks and usage tracking)")
|
||||
|
||||
outline_text = "\n".join([f"{i+1}. {s.heading}" for i, s in enumerate(outline)])
|
||||
|
||||
optimization_prompt = f"""Optimize this blog outline for better flow, engagement, and SEO:
|
||||
@@ -67,7 +82,8 @@ Return JSON format:
|
||||
optimized_data = llm_text_gen(
|
||||
prompt=optimization_prompt,
|
||||
json_struct=optimization_schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Handle the new schema format with "outline" wrapper
|
||||
|
||||
@@ -29,11 +29,21 @@ class OutlineService:
|
||||
self.outline_optimizer = OutlineOptimizer()
|
||||
self.section_enhancer = SectionEnhancer()
|
||||
|
||||
async def generate_outline(self, request: BlogOutlineRequest) -> BlogOutlineResponse:
|
||||
async def generate_outline(self, request: BlogOutlineRequest, user_id: str) -> BlogOutlineResponse:
|
||||
"""
|
||||
Stage 2: Content Planning with AI-generated outline using research results
|
||||
Uses Gemini with research data to create comprehensive, SEO-optimized outline
|
||||
Stage 2: Content Planning with AI-generated outline using research results.
|
||||
Uses Gemini with research data to create comprehensive, SEO-optimized outline.
|
||||
|
||||
Args:
|
||||
request: Outline generation request with research data
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
|
||||
|
||||
# Extract cache parameters - use original user keywords for consistent caching
|
||||
keywords = request.research.original_keywords or request.research.keyword_analysis.get('primary', [])
|
||||
industry = getattr(request.persona, 'industry', 'general') if request.persona else 'general'
|
||||
@@ -56,9 +66,9 @@ class OutlineService:
|
||||
logger.info(f"Using cached outline for keywords: {keywords}")
|
||||
return BlogOutlineResponse(**cached_result)
|
||||
|
||||
# Generate new outline if not cached
|
||||
# Generate new outline if not cached (user_id required)
|
||||
logger.info(f"Generating new outline for keywords: {keywords}")
|
||||
result = await self.outline_generator.generate(request)
|
||||
result = await self.outline_generator.generate(request, user_id)
|
||||
|
||||
# Cache the result
|
||||
persistent_outline_cache.cache_outline(
|
||||
@@ -73,7 +83,7 @@ class OutlineService:
|
||||
|
||||
return result
|
||||
|
||||
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str) -> BlogOutlineResponse:
|
||||
async def generate_outline_with_progress(self, request: BlogOutlineRequest, task_id: str, user_id: str) -> BlogOutlineResponse:
|
||||
"""
|
||||
Outline generation method with progress updates for real-time feedback.
|
||||
"""
|
||||
@@ -104,7 +114,7 @@ class OutlineService:
|
||||
|
||||
# Generate new outline if not cached
|
||||
logger.info(f"Generating new outline for keywords: {keywords} (with progress updates)")
|
||||
result = await self.outline_generator.generate_with_progress(request, task_id)
|
||||
result = await self.outline_generator.generate_with_progress(request, task_id, user_id)
|
||||
|
||||
# Cache the result
|
||||
persistent_outline_cache.cache_outline(
|
||||
|
||||
@@ -17,18 +17,25 @@ class ParallelProcessor:
|
||||
self.source_mapper = source_mapper
|
||||
self.grounding_engine = grounding_engine
|
||||
|
||||
async def run_parallel_processing(self, outline_sections, research, task_id: str = None) -> Tuple[Any, Any]:
|
||||
async def run_parallel_processing(self, outline_sections, research, user_id: str, task_id: str = None) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Run source mapping and grounding insights extraction in parallel.
|
||||
|
||||
Args:
|
||||
outline_sections: List of outline sections to process
|
||||
research: Research data object
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
task_id: Optional task ID for progress updates
|
||||
|
||||
Returns:
|
||||
Tuple of (mapped_sections, grounding_insights)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
|
||||
|
||||
if task_id:
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
await task_manager.update_progress(task_id, "⚡ Running parallel processing for maximum speed...")
|
||||
@@ -37,7 +44,7 @@ class ParallelProcessor:
|
||||
|
||||
# Run these tasks in parallel to save time
|
||||
source_mapping_task = asyncio.create_task(
|
||||
self._run_source_mapping(outline_sections, research, task_id)
|
||||
self._run_source_mapping(outline_sections, research, task_id, user_id)
|
||||
)
|
||||
|
||||
grounding_insights_task = asyncio.create_task(
|
||||
@@ -52,22 +59,29 @@ class ParallelProcessor:
|
||||
|
||||
return mapped_sections, grounding_insights
|
||||
|
||||
async def run_parallel_processing_async(self, outline_sections, research) -> Tuple[Any, Any]:
|
||||
async def run_parallel_processing_async(self, outline_sections, research, user_id: str) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Run parallel processing without progress updates (for non-progress methods).
|
||||
|
||||
Args:
|
||||
outline_sections: List of outline sections to process
|
||||
research: Research data object
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
Tuple of (mapped_sections, grounding_insights)
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for parallel processing (subscription checks and usage tracking)")
|
||||
|
||||
logger.info("Running parallel processing for maximum speed...")
|
||||
|
||||
# Run these tasks in parallel to save time
|
||||
source_mapping_task = asyncio.create_task(
|
||||
self._run_source_mapping_async(outline_sections, research)
|
||||
self._run_source_mapping_async(outline_sections, research, user_id)
|
||||
)
|
||||
|
||||
grounding_insights_task = asyncio.create_task(
|
||||
@@ -82,12 +96,12 @@ class ParallelProcessor:
|
||||
|
||||
return mapped_sections, grounding_insights
|
||||
|
||||
async def _run_source_mapping(self, outline_sections, research, task_id):
|
||||
async def _run_source_mapping(self, outline_sections, research, task_id, user_id: str):
|
||||
"""Run source mapping in parallel."""
|
||||
if task_id:
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
await task_manager.update_progress(task_id, "🔗 Applying intelligent source-to-section mapping...")
|
||||
return self.source_mapper.map_sources_to_sections(outline_sections, research)
|
||||
return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
|
||||
|
||||
async def _run_grounding_insights_extraction(self, research, task_id):
|
||||
"""Run grounding insights extraction in parallel."""
|
||||
@@ -96,10 +110,10 @@ class ParallelProcessor:
|
||||
await task_manager.update_progress(task_id, "🧠 Extracting grounding metadata insights...")
|
||||
return self.grounding_engine.extract_contextual_insights(research.grounding_metadata)
|
||||
|
||||
async def _run_source_mapping_async(self, outline_sections, research):
|
||||
async def _run_source_mapping_async(self, outline_sections, research, user_id: str):
|
||||
"""Run source mapping in parallel (async version without progress updates)."""
|
||||
logger.info("Applying intelligent source-to-section mapping...")
|
||||
return self.source_mapper.map_sources_to_sections(outline_sections, research)
|
||||
return self.source_mapper.map_sources_to_sections(outline_sections, research, user_id)
|
||||
|
||||
async def _run_grounding_insights_extraction_async(self, research):
|
||||
"""Run grounding insights extraction in parallel (async version without progress updates)."""
|
||||
|
||||
@@ -18,8 +18,21 @@ class ResponseProcessor:
|
||||
"""Initialize the response processor."""
|
||||
pass
|
||||
|
||||
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], task_id: str = None) -> Dict[str, Any]:
|
||||
"""Generate outline with retry logic for API failures."""
|
||||
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], user_id: str, task_id: str = None) -> Dict[str, Any]:
|
||||
"""Generate outline with retry logic for API failures.
|
||||
|
||||
Args:
|
||||
prompt: The prompt for outline generation
|
||||
schema: JSON schema for structured response
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
task_id: Optional task ID for progress updates
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for outline generation (subscription checks and usage tracking)")
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
|
||||
@@ -34,7 +47,8 @@ class ResponseProcessor:
|
||||
outline_data = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Log response for debugging
|
||||
|
||||
@@ -12,8 +12,23 @@ from models.blog_models import BlogOutlineSection
|
||||
class SectionEnhancer:
|
||||
"""Enhances individual outline sections using AI."""
|
||||
|
||||
async def enhance(self, section: BlogOutlineSection, focus: str = "general improvement") -> BlogOutlineSection:
|
||||
"""Enhance a section using AI with research context."""
|
||||
async def enhance(self, section: BlogOutlineSection, focus: str, user_id: str) -> BlogOutlineSection:
|
||||
"""Enhance a section using AI with research context.
|
||||
|
||||
Args:
|
||||
section: Outline section to enhance
|
||||
focus: Enhancement focus (e.g., "general improvement")
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
Enhanced outline section
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for section enhancement (subscription checks and usage tracking)")
|
||||
|
||||
enhancement_prompt = f"""
|
||||
Enhance the following blog section to make it more engaging, comprehensive, and valuable:
|
||||
|
||||
@@ -61,7 +76,8 @@ class SectionEnhancer:
|
||||
enhanced_data = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
|
||||
@@ -52,7 +52,8 @@ class SourceToSectionMapper:
|
||||
def map_sources_to_sections(
|
||||
self,
|
||||
sections: List[BlogOutlineSection],
|
||||
research_data: BlogResearchResponse
|
||||
research_data: BlogResearchResponse,
|
||||
user_id: str
|
||||
) -> List[BlogOutlineSection]:
|
||||
"""
|
||||
Map research sources to outline sections using intelligent algorithms.
|
||||
@@ -60,10 +61,17 @@ class SourceToSectionMapper:
|
||||
Args:
|
||||
sections: List of outline sections to map sources to
|
||||
research_data: Research data containing sources and metadata
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
List of outline sections with intelligently mapped sources
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for source mapping (subscription checks and usage tracking)")
|
||||
|
||||
if not sections or not research_data.sources:
|
||||
logger.warning("No sections or sources to map")
|
||||
return sections
|
||||
@@ -73,8 +81,8 @@ class SourceToSectionMapper:
|
||||
# Step 1: Algorithmic mapping
|
||||
mapping_results = self._algorithmic_source_mapping(sections, research_data)
|
||||
|
||||
# Step 2: AI validation and improvement (single prompt)
|
||||
validated_mapping = self._ai_validate_mapping(mapping_results, research_data)
|
||||
# Step 2: AI validation and improvement (single prompt, user_id required for subscription checks)
|
||||
validated_mapping = self._ai_validate_mapping(mapping_results, research_data, user_id)
|
||||
|
||||
# Step 3: Apply validated mapping to sections
|
||||
mapped_sections = self._apply_mapping_to_sections(sections, validated_mapping)
|
||||
@@ -261,7 +269,8 @@ class SourceToSectionMapper:
|
||||
def _ai_validate_mapping(
|
||||
self,
|
||||
mapping_results: Dict[str, List[Tuple[ResearchSource, float]]],
|
||||
research_data: BlogResearchResponse
|
||||
research_data: BlogResearchResponse,
|
||||
user_id: str
|
||||
) -> Dict[str, List[Tuple[ResearchSource, float]]]:
|
||||
"""
|
||||
Use AI to validate and improve the algorithmic mapping results.
|
||||
@@ -269,18 +278,25 @@ class SourceToSectionMapper:
|
||||
Args:
|
||||
mapping_results: Algorithmic mapping results
|
||||
research_data: Research data for context
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
AI-validated and improved mapping results
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for AI validation (subscription checks and usage tracking)")
|
||||
|
||||
try:
|
||||
logger.info("Starting AI validation of source-to-section mapping...")
|
||||
|
||||
# Build AI validation prompt
|
||||
validation_prompt = self._build_validation_prompt(mapping_results, research_data)
|
||||
|
||||
# Get AI validation response
|
||||
validation_response = self._get_ai_validation_response(validation_prompt)
|
||||
# Get AI validation response (user_id required for subscription checks)
|
||||
validation_response = self._get_ai_validation_response(validation_prompt, user_id)
|
||||
|
||||
# Parse and apply AI validation results
|
||||
validated_mapping = self._parse_validation_response(validation_response, mapping_results, research_data)
|
||||
@@ -548,23 +564,31 @@ Analyze the mapping and provide your recommendations.
|
||||
|
||||
return prompt
|
||||
|
||||
def _get_ai_validation_response(self, prompt: str) -> str:
|
||||
def _get_ai_validation_response(self, prompt: str, user_id: str) -> str:
|
||||
"""
|
||||
Get AI validation response using LLM provider.
|
||||
|
||||
Args:
|
||||
prompt: Validation prompt
|
||||
user_id: User ID (required for subscription checks and usage tracking)
|
||||
|
||||
Returns:
|
||||
AI validation response
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is not provided
|
||||
"""
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required for AI validation response (subscription checks and usage tracking)")
|
||||
|
||||
try:
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=None,
|
||||
system_prompt=None
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -13,11 +13,17 @@ from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
from .content_angle_generator import ContentAngleGenerator
|
||||
from .data_filter import ResearchDataFilter
|
||||
from .base_provider import ResearchProvider as BaseResearchProvider
|
||||
from .google_provider import GoogleResearchProvider
|
||||
from .exa_provider import ExaResearchProvider
|
||||
|
||||
__all__ = [
|
||||
'ResearchService',
|
||||
'KeywordAnalyzer',
|
||||
'CompetitorAnalyzer',
|
||||
'ContentAngleGenerator',
|
||||
'ResearchDataFilter'
|
||||
'ResearchDataFilter',
|
||||
'BaseResearchProvider',
|
||||
'GoogleResearchProvider',
|
||||
'ExaResearchProvider',
|
||||
]
|
||||
|
||||
37
backend/services/blog_writer/research/base_provider.py
Normal file
37
backend/services/blog_writer/research/base_provider.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Base Research Provider Interface
|
||||
|
||||
Abstract base class for research provider implementations.
|
||||
Ensures consistency across different research providers (Google, Exa, etc.)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class ResearchProvider(ABC):
|
||||
"""Abstract base class for research providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
prompt: str,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: Any, # ResearchConfig
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute research and return raw results."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_enum(self):
|
||||
"""Return APIProvider enum for subscription tracking."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Estimate token usage for pre-flight validation."""
|
||||
pass
|
||||
|
||||
188
backend/services/blog_writer/research/exa_provider.py
Normal file
188
backend/services/blog_writer/research/exa_provider.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Exa Research Provider
|
||||
|
||||
Neural search implementation using Exa API for high-quality, citation-rich research.
|
||||
"""
|
||||
|
||||
from exa_py import Exa
|
||||
import os
|
||||
from loguru import logger
|
||||
from models.subscription_models import APIProvider
|
||||
from .base_provider import ResearchProvider as BaseProvider
|
||||
|
||||
|
||||
class ExaResearchProvider(BaseProvider):
|
||||
"""Exa neural search provider."""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("EXA_API_KEY")
|
||||
if not self.api_key:
|
||||
raise RuntimeError("EXA_API_KEY not configured")
|
||||
self.exa = Exa(self.api_key)
|
||||
logger.info("✅ Exa Research Provider initialized")
|
||||
|
||||
async def search(self, prompt, topic, industry, target_audience, config, user_id):
|
||||
"""Execute Exa neural search and return standardized results."""
|
||||
# Build Exa query
|
||||
query = f"{topic} {industry} {target_audience}"
|
||||
|
||||
# Map source types to Exa categories
|
||||
category = self._map_source_type_to_category(config.source_types)
|
||||
|
||||
logger.info(f"[Exa Research] Executing search: {query}")
|
||||
|
||||
# Execute Exa search
|
||||
results = self.exa.search_and_contents(
|
||||
query,
|
||||
type="auto",
|
||||
category=category,
|
||||
num_results=min(config.max_sources, 25),
|
||||
contents={
|
||||
'text': {'max_characters': 1000},
|
||||
'summary': {'query': f"Key insights about {topic}"},
|
||||
'highlights': {
|
||||
'num_sentences': 2,
|
||||
'highlights_per_url': 3
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Transform to standardized format
|
||||
sources = self._transform_sources(results.results)
|
||||
content = self._aggregate_content(results.results)
|
||||
search_type = getattr(results, 'resolvedSearchType', 'neural') if hasattr(results, 'resolvedSearchType') else 'neural'
|
||||
|
||||
# Get cost if available
|
||||
cost = 0.005 # Default Exa cost for 1-25 results
|
||||
if hasattr(results, 'costDollars'):
|
||||
if hasattr(results.costDollars, 'total'):
|
||||
cost = results.costDollars.total
|
||||
|
||||
logger.info(f"[Exa Research] Search completed: {len(sources)} sources, type: {search_type}")
|
||||
|
||||
return {
|
||||
'sources': sources,
|
||||
'content': content,
|
||||
'search_type': search_type,
|
||||
'provider': 'exa',
|
||||
'search_queries': [query],
|
||||
'cost': {'total': cost}
|
||||
}
|
||||
|
||||
def get_provider_enum(self):
|
||||
"""Return EXA provider enum for subscription tracking."""
|
||||
return APIProvider.EXA
|
||||
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Estimate token usage for Exa (not token-based)."""
|
||||
return 0 # Exa is per-search, not token-based
|
||||
|
||||
def _map_source_type_to_category(self, source_types):
|
||||
"""Map SourceType enum to Exa category parameter."""
|
||||
if not source_types:
|
||||
return None
|
||||
|
||||
category_map = {
|
||||
'research paper': 'research paper',
|
||||
'news': 'news',
|
||||
'web': 'personal site',
|
||||
'industry': 'company',
|
||||
'expert': 'linkedin profile'
|
||||
}
|
||||
|
||||
for st in source_types:
|
||||
if st.value in category_map:
|
||||
return category_map[st.value]
|
||||
|
||||
return None
|
||||
|
||||
def _transform_sources(self, results):
|
||||
"""Transform Exa results to ResearchSource format."""
|
||||
sources = []
|
||||
for idx, result in enumerate(results):
|
||||
source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')
|
||||
|
||||
sources.append({
|
||||
'title': result.title if hasattr(result, 'title') else '',
|
||||
'url': result.url if hasattr(result, 'url') else '',
|
||||
'excerpt': self._get_excerpt(result),
|
||||
'credibility_score': 0.85, # Exa results are high quality
|
||||
'published_at': result.publishedDate if hasattr(result, 'publishedDate') else None,
|
||||
'index': idx,
|
||||
'source_type': source_type,
|
||||
'content': result.text if hasattr(result, 'text') else '',
|
||||
'highlights': result.highlights if hasattr(result, 'highlights') else [],
|
||||
'summary': result.summary if hasattr(result, 'summary') else ''
|
||||
})
|
||||
|
||||
return sources
|
||||
|
||||
def _get_excerpt(self, result):
|
||||
"""Extract excerpt from Exa result."""
|
||||
if hasattr(result, 'text') and result.text:
|
||||
return result.text[:500]
|
||||
elif hasattr(result, 'summary') and result.summary:
|
||||
return result.summary
|
||||
return ''
|
||||
|
||||
def _determine_source_type(self, url):
|
||||
"""Determine source type from URL."""
|
||||
if not url:
|
||||
return 'web'
|
||||
|
||||
url_lower = url.lower()
|
||||
if 'arxiv.org' in url_lower or 'research' in url_lower:
|
||||
return 'academic'
|
||||
elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com']):
|
||||
return 'news'
|
||||
elif 'linkedin.com' in url_lower:
|
||||
return 'expert'
|
||||
else:
|
||||
return 'web'
|
||||
|
||||
def _aggregate_content(self, results):
|
||||
"""Aggregate content from Exa results for LLM analysis."""
|
||||
content_parts = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if hasattr(result, 'summary') and result.summary:
|
||||
content_parts.append(f"Source {idx + 1}: {result.summary}")
|
||||
elif hasattr(result, 'text') and result.text:
|
||||
content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
def track_exa_usage(self, user_id: str, cost: float):
|
||||
"""Track Exa API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
# Update exa_calls and exa_cost via SQL UPDATE
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = total_calls + 1,
|
||||
total_cost = total_cost + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[Exa] Tracked usage: user={user_id}, cost=${cost}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Exa] Failed to track usage: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
40
backend/services/blog_writer/research/google_provider.py
Normal file
40
backend/services/blog_writer/research/google_provider.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Google Research Provider
|
||||
|
||||
Wrapper for Gemini native Google Search grounding to match base provider interface.
|
||||
"""
|
||||
|
||||
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
|
||||
from models.subscription_models import APIProvider
|
||||
from .base_provider import ResearchProvider as BaseProvider
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GoogleResearchProvider(BaseProvider):
|
||||
"""Google research provider using Gemini native grounding."""
|
||||
|
||||
def __init__(self):
|
||||
self.gemini = GeminiGroundedProvider()
|
||||
|
||||
async def search(self, prompt, topic, industry, target_audience, config, user_id):
|
||||
"""Call Gemini grounding with pre-flight validation."""
|
||||
logger.info(f"[Google Research] Executing search for topic: {topic}")
|
||||
|
||||
result = await self.gemini.generate_grounded_content(
|
||||
prompt=prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_provider_enum(self):
|
||||
"""Return GEMINI provider enum for subscription tracking."""
|
||||
return APIProvider.GEMINI
|
||||
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Estimate token usage for Google grounding."""
|
||||
return 1200 # Conservative estimate
|
||||
|
||||
@@ -16,6 +16,9 @@ from models.blog_models import (
|
||||
GroundingChunk,
|
||||
GroundingSupport,
|
||||
Citation,
|
||||
ResearchConfig,
|
||||
ResearchMode,
|
||||
ResearchProvider,
|
||||
)
|
||||
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
|
||||
from fastapi import HTTPException
|
||||
@@ -24,6 +27,7 @@ from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
from .content_angle_generator import ContentAngleGenerator
|
||||
from .data_filter import ResearchDataFilter
|
||||
from .research_strategies import get_strategy_for_mode
|
||||
|
||||
|
||||
class ResearchService:
|
||||
@@ -44,7 +48,6 @@ class ResearchService:
|
||||
Includes intelligent caching for exact keyword matches.
|
||||
"""
|
||||
try:
|
||||
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
|
||||
from services.cache.research_cache import research_cache
|
||||
|
||||
topic = request.topic or ", ".join(request.keywords)
|
||||
@@ -79,62 +82,104 @@ class ResearchService:
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
blog_writer_logger.log_operation_start("gemini_api_call", api_name="gemini_grounded", operation="research")
|
||||
gemini = GeminiGroundedProvider()
|
||||
blog_writer_logger.log_operation_start("research_api_call", api_name="research", operation="research")
|
||||
|
||||
# Single comprehensive research prompt - Gemini handles Google Search automatically
|
||||
research_prompt = f"""
|
||||
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including:
|
||||
|
||||
1. Current trends and insights (2024-2025)
|
||||
2. Key statistics and data points with sources
|
||||
3. Industry expert opinions and quotes
|
||||
4. Recent developments and news
|
||||
5. Market analysis and forecasts
|
||||
6. Best practices and case studies
|
||||
7. Keyword analysis: primary, secondary, and long-tail opportunities
|
||||
8. Competitor analysis: top players and content gaps
|
||||
9. Content angle suggestions: 5 compelling angles for blog posts
|
||||
|
||||
Focus on factual, up-to-date information from credible sources.
|
||||
Include specific data points, percentages, and recent developments.
|
||||
Structure your response with clear sections for each analysis area.
|
||||
"""
|
||||
# Determine research mode and get appropriate strategy
|
||||
research_mode = request.research_mode or ResearchMode.BASIC
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
|
||||
strategy = get_strategy_for_mode(research_mode)
|
||||
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
import time
|
||||
api_start_time = time.time()
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
|
||||
|
||||
# Log API call performance
|
||||
blog_writer_logger.log_api_call(
|
||||
"gemini_grounded",
|
||||
"generate_grounded_content",
|
||||
api_duration_ms,
|
||||
token_usage=gemini_result.get("token_usage", {}),
|
||||
content_length=len(gemini_result.get("content", ""))
|
||||
)
|
||||
# Build research prompt based on strategy
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Extract sources from grounding metadata
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
# Route to appropriate provider
|
||||
if config.provider == ResearchProvider.EXA:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
try:
|
||||
exa_provider = ExaResearchProvider()
|
||||
raw_result = await exa_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
|
||||
# Track usage
|
||||
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
|
||||
exa_provider.track_exa_usage(user_id, cost)
|
||||
|
||||
# Log API call performance
|
||||
blog_writer_logger.log_api_call(
|
||||
"exa_search",
|
||||
"search_and_contents",
|
||||
api_duration_ms,
|
||||
token_usage={},
|
||||
content_length=len(raw_result.get('content', ''))
|
||||
)
|
||||
|
||||
# Extract content for downstream analysis
|
||||
content = raw_result.get('content', '')
|
||||
sources = raw_result.get('sources', [])
|
||||
search_widget = "" # Exa doesn't provide search widgets
|
||||
search_queries = raw_result.get('search_queries', [])
|
||||
grounding_metadata = None # Exa doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "EXA_API_KEY not configured" in str(e):
|
||||
logger.warning("Exa not configured, falling back to Google")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
raw_result = None
|
||||
else:
|
||||
raise
|
||||
|
||||
if config.provider != ResearchProvider.EXA:
|
||||
# Google research (existing flow) or fallback from Exa
|
||||
from .google_provider import GoogleResearchProvider
|
||||
import time
|
||||
|
||||
api_start_time = time.time()
|
||||
google_provider = GoogleResearchProvider()
|
||||
gemini_result = await google_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
api_duration_ms = (time.time() - api_start_time) * 1000
|
||||
|
||||
# Log API call performance
|
||||
blog_writer_logger.log_api_call(
|
||||
"gemini_grounded",
|
||||
"generate_grounded_content",
|
||||
api_duration_ms,
|
||||
token_usage=gemini_result.get("token_usage", {}),
|
||||
content_length=len(gemini_result.get("content", ""))
|
||||
)
|
||||
|
||||
# Extract sources and content
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
content = gemini_result.get("content", "")
|
||||
search_widget = gemini_result.get("search_widget", "") or ""
|
||||
search_queries = gemini_result.get("search_queries", []) or []
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
|
||||
# Extract grounding metadata for detailed UI display
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
|
||||
# Extract search widget and queries for UI display
|
||||
search_widget = gemini_result.get("search_widget", "") or ""
|
||||
search_queries = gemini_result.get("search_queries", []) or []
|
||||
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
# Continue with common analysis (same for both providers)
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
@@ -261,7 +306,6 @@ class ResearchService:
|
||||
Research method with progress updates for real-time feedback.
|
||||
"""
|
||||
try:
|
||||
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
|
||||
from services.cache.research_cache import research_cache
|
||||
from services.cache.persistent_research_cache import persistent_research_cache
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
@@ -293,66 +337,100 @@ class ResearchService:
|
||||
logger.info(f"Returning cached research result for keywords: {request.keywords}")
|
||||
return BlogResearchResponse(**cached_result)
|
||||
|
||||
# User ID validation (validation logic is now in Google Grounding provider)
|
||||
# User ID validation
|
||||
if not user_id:
|
||||
await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Cache miss - proceed with API call
|
||||
await task_manager.update_progress(task_id, "🌐 Cache miss - connecting to Google Search grounding...")
|
||||
logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
|
||||
gemini = GeminiGroundedProvider()
|
||||
|
||||
# Single comprehensive research prompt - Gemini handles Google Search automatically
|
||||
research_prompt = f"""
|
||||
Research the topic "{topic}" in the {industry} industry for {target_audience} audience. Provide a comprehensive analysis including:
|
||||
|
||||
1. Current trends and insights (2024-2025)
|
||||
2. Key statistics and data points with sources
|
||||
3. Industry expert opinions and quotes
|
||||
4. Recent developments and news
|
||||
5. Market analysis and forecasts
|
||||
6. Best practices and case studies
|
||||
7. Keyword analysis: primary, secondary, and long-tail opportunities
|
||||
8. Competitor analysis: top players and content gaps
|
||||
9. Content angle suggestions: 5 compelling angles for blog posts
|
||||
|
||||
Focus on factual, up-to-date information from credible sources.
|
||||
Include specific data points, percentages, and recent developments.
|
||||
Structure your response with clear sections for each analysis area.
|
||||
"""
|
||||
# Determine research mode and get appropriate strategy
|
||||
research_mode = request.research_mode or ResearchMode.BASIC
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
|
||||
strategy = get_strategy_for_mode(research_mode)
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
# Single Gemini call with native Google Search grounding - no fallbacks
|
||||
# Validation is handled inside generate_grounded_content when validate_subsequent_operations=True
|
||||
try:
|
||||
gemini_result = await gemini.generate_grounded_content(
|
||||
prompt=research_prompt,
|
||||
content_type="research",
|
||||
max_tokens=2000,
|
||||
user_id=user_id,
|
||||
validate_subsequent_operations=True # Validates Google Grounding + 3 LLM calls
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
# Re-raise HTTPException so it can be properly handled by task manager
|
||||
logger.error(f"Subscription limit exceeded for research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise # Re-raise HTTPException to preserve status code and error details
|
||||
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources from grounding metadata
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
# Build research prompt based on strategy
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Extract grounding metadata for detailed UI display
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
|
||||
# Extract search widget and queries for UI display
|
||||
search_widget = gemini_result.get("search_widget", "") or ""
|
||||
search_queries = gemini_result.get("search_queries", []) or []
|
||||
# Route to appropriate provider
|
||||
if config.provider == ResearchProvider.EXA:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
except HTTPException as http_error:
|
||||
logger.error(f"Subscription limit exceeded for Exa research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
try:
|
||||
exa_provider = ExaResearchProvider()
|
||||
raw_result = await exa_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
|
||||
# Track usage
|
||||
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
|
||||
exa_provider.track_exa_usage(user_id, cost)
|
||||
|
||||
# Extract content for downstream analysis
|
||||
content = raw_result.get('content', '')
|
||||
sources = raw_result.get('sources', [])
|
||||
search_widget = "" # Exa doesn't provide search widgets
|
||||
search_queries = raw_result.get('search_queries', [])
|
||||
grounding_metadata = None # Exa doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "EXA_API_KEY not configured" in str(e):
|
||||
logger.warning("Exa not configured, falling back to Google")
|
||||
await task_manager.update_progress(task_id, "⚠️ Exa not configured, falling back to Google Search")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
else:
|
||||
raise
|
||||
|
||||
if config.provider != ResearchProvider.EXA:
|
||||
# Google research (existing flow)
|
||||
from .google_provider import GoogleResearchProvider
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Google Search grounding...")
|
||||
google_provider = GoogleResearchProvider()
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
try:
|
||||
gemini_result = await google_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources and content
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
content = gemini_result.get("content", "")
|
||||
search_widget = gemini_result.get("search_widget", "") or ""
|
||||
search_queries = gemini_result.get("search_queries", []) or []
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
|
||||
# Continue with common analysis (same for both providers)
|
||||
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
|
||||
# Parse the comprehensive response for different analysis components
|
||||
content = gemini_result.get("content", "")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
234
backend/services/blog_writer/research/research_strategies.py
Normal file
234
backend/services/blog_writer/research/research_strategies.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Research Strategy Pattern Implementation
|
||||
|
||||
Different strategies for executing research based on depth and focus.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogResearchRequest, ResearchMode, ResearchConfig
|
||||
from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
from .content_angle_generator import ContentAngleGenerator
|
||||
|
||||
|
||||
class ResearchStrategy(ABC):
|
||||
"""Base class for research strategies."""
|
||||
|
||||
def __init__(self):
|
||||
self.keyword_analyzer = KeywordAnalyzer()
|
||||
self.competitor_analyzer = CompetitorAnalyzer()
|
||||
self.content_angle_generator = ContentAngleGenerator()
|
||||
|
||||
@abstractmethod
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build the research prompt for the strategy."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mode(self) -> ResearchMode:
|
||||
"""Return the research mode this strategy handles."""
|
||||
pass
|
||||
|
||||
|
||||
class BasicResearchStrategy(ResearchStrategy):
|
||||
"""Basic research strategy - keyword focused, minimal analysis."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.BASIC
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build basic research prompt focused on keywords and quick insights."""
|
||||
prompt = f"""You are a professional blog content strategist researching for a {industry} blog targeting {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"
|
||||
|
||||
Provide analysis in this EXACT format:
|
||||
|
||||
## CURRENT TRENDS (2024-2025)
|
||||
- [Trend 1 with specific data and source URL]
|
||||
- [Trend 2 with specific data and source URL]
|
||||
- [Trend 3 with specific data and source URL]
|
||||
|
||||
## KEY STATISTICS
|
||||
- [Statistic 1: specific number/percentage with source URL]
|
||||
- [Statistic 2: specific number/percentage with source URL]
|
||||
- [Statistic 3: specific number/percentage with source URL]
|
||||
- [Statistic 4: specific number/percentage with source URL]
|
||||
- [Statistic 5: specific number/percentage with source URL]
|
||||
|
||||
## PRIMARY KEYWORDS
|
||||
1. "{topic}" (main keyword)
|
||||
2. [Variation 1]
|
||||
3. [Variation 2]
|
||||
|
||||
## SECONDARY KEYWORDS
|
||||
[5 related keywords for blog content]
|
||||
|
||||
## CONTENT ANGLES (Top 5)
|
||||
1. [Angle 1: specific unique approach]
|
||||
2. [Angle 2: specific unique approach]
|
||||
3. [Angle 3: specific unique approach]
|
||||
4. [Angle 4: specific unique approach]
|
||||
5. [Angle 5: specific unique approach]
|
||||
|
||||
REQUIREMENTS:
|
||||
- Cite EVERY claim with authoritative source URLs
|
||||
- Use 2024-2025 data when available
|
||||
- Include specific numbers, dates, examples
|
||||
- Focus on actionable blog insights for {target_audience}"""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class ComprehensiveResearchStrategy(ResearchStrategy):
|
||||
"""Comprehensive research strategy - full analysis with all components."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.COMPREHENSIVE
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build comprehensive research prompt with all analysis components."""
|
||||
date_filter = f"\nDate Focus: {config.date_range.value.replace('_', ' ')}" if config.date_range else ""
|
||||
source_filter = f"\nPriority Sources: {', '.join([s.value for s in config.source_types])}" if config.source_types else ""
|
||||
|
||||
prompt = f"""You are a senior blog content strategist conducting comprehensive research for a {industry} blog targeting {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"{date_filter}{source_filter}
|
||||
|
||||
Provide COMPLETE analysis in this EXACT format:
|
||||
|
||||
## TRENDS AND INSIGHTS (2024-2025)
|
||||
[5-7 trends with specific data, numbers, and source URLs]
|
||||
|
||||
## KEY STATISTICS
|
||||
[7-10 statistics with exact numbers, percentages, dates, and source URLs]
|
||||
|
||||
## EXPERT OPINIONS
|
||||
[4-5 expert quotes with full attribution and source URLs]
|
||||
|
||||
## RECENT DEVELOPMENTS
|
||||
[5-7 recent news/developments with dates and source URLs]
|
||||
|
||||
## MARKET ANALYSIS
|
||||
[3-5 market insights with data points and source URLs]
|
||||
|
||||
## BEST PRACTICES & CASE STUDIES
|
||||
[3-5 examples with specific outcomes/metrics and source URLs]
|
||||
|
||||
## KEYWORD ANALYSIS
|
||||
Primary Keywords: [3 main variations]
|
||||
Secondary Keywords: [7-10 related keywords]
|
||||
Long-Tail Opportunities: [5-7 specific search phrases]
|
||||
|
||||
## COMPETITOR ANALYSIS
|
||||
Top Competitors: [5 competitors with brief descriptions]
|
||||
Content Gaps: [5 topics competitors are missing]
|
||||
Competitive Advantages: [5 unique angles we can own]
|
||||
|
||||
## CONTENT ANGLES (Exactly 5)
|
||||
1. [Unique angle with reasoning and target benefit]
|
||||
2. [Unique angle with reasoning and target benefit]
|
||||
3. [Unique angle with reasoning and target benefit]
|
||||
4. [Unique angle with reasoning and target benefit]
|
||||
5. [Unique angle with reasoning and target benefit]
|
||||
|
||||
VERIFICATION REQUIREMENTS:
|
||||
- Minimum 2 authoritative sources per major claim
|
||||
- Prioritize: Industry publications > Research papers > News > Blogs
|
||||
- 2024-2025 data strongly preferred
|
||||
- All numbers must include context (timeframe, sample size, methodology)
|
||||
- Every recommendation must be actionable for {target_audience}"""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class TargetedResearchStrategy(ResearchStrategy):
|
||||
"""Targeted research strategy - focused on specific aspects."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.TARGETED
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build targeted research prompt based on config preferences."""
|
||||
sections = []
|
||||
|
||||
if config.include_trends:
|
||||
sections.append("""## CURRENT TRENDS
|
||||
[3-5 trends with data and source URLs]""")
|
||||
|
||||
if config.include_statistics:
|
||||
sections.append("""## KEY STATISTICS
|
||||
[5-7 statistics with numbers and source URLs]""")
|
||||
|
||||
if config.include_expert_quotes:
|
||||
sections.append("""## EXPERT OPINIONS
|
||||
[3-4 expert quotes with attribution and source URLs]""")
|
||||
|
||||
if config.include_competitors:
|
||||
sections.append("""## COMPETITOR ANALYSIS
|
||||
Top Competitors: [3-5]
|
||||
Content Gaps: [3-5]""")
|
||||
|
||||
# Always include keywords and angles
|
||||
sections.append("""## KEYWORD ANALYSIS
|
||||
Primary: [2-3 variations]
|
||||
Secondary: [5-7 keywords]
|
||||
Long-Tail: [3-5 phrases]""")
|
||||
|
||||
sections.append("""## CONTENT ANGLES (3-5)
|
||||
[Unique blog angles with reasoning]""")
|
||||
|
||||
sections_str = "\n\n".join(sections)
|
||||
|
||||
prompt = f"""You are a blog content strategist conducting targeted research for a {industry} blog targeting {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"
|
||||
|
||||
Provide focused analysis in this EXACT format:
|
||||
|
||||
{sections_str}
|
||||
|
||||
REQUIREMENTS:
|
||||
- Cite all claims with authoritative source URLs
|
||||
- Include specific numbers, dates, examples
|
||||
- Focus on actionable insights for {target_audience}
|
||||
- Use 2024-2025 data when available"""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
def get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy:
|
||||
"""Factory function to get the appropriate strategy for a mode."""
|
||||
strategy_map = {
|
||||
ResearchMode.BASIC: BasicResearchStrategy,
|
||||
ResearchMode.COMPREHENSIVE: ComprehensiveResearchStrategy,
|
||||
ResearchMode.TARGETED: TargetedResearchStrategy,
|
||||
}
|
||||
|
||||
strategy_class = strategy_map.get(mode, BasicResearchStrategy)
|
||||
return strategy_class()
|
||||
|
||||
@@ -2,4 +2,14 @@
|
||||
Wix integration modular services package.
|
||||
"""
|
||||
|
||||
from services.integrations.wix.seo import build_seo_data
|
||||
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
|
||||
from services.integrations.wix.blog_publisher import create_blog_post
|
||||
|
||||
__all__ = [
|
||||
'build_seo_data',
|
||||
'markdown_to_html',
|
||||
'convert_via_wix_api',
|
||||
'create_blog_post',
|
||||
]
|
||||
|
||||
|
||||
@@ -20,6 +20,40 @@ class WixBlogService:
|
||||
return h
|
||||
|
||||
def create_draft_post(self, access_token: str, payload: Dict[str, Any], extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
|
||||
# Log the exact payload being sent for debugging
|
||||
import json
|
||||
logger.warning(f"📤 Sending to Wix Blog API:")
|
||||
logger.warning(f" Endpoint: {self.base_url}/blog/v3/draft-posts")
|
||||
logger.warning(f" Payload top-level keys: {list(payload.keys())}")
|
||||
if 'draftPost' in payload:
|
||||
dp = payload['draftPost']
|
||||
logger.warning(f" draftPost keys: {list(dp.keys())}")
|
||||
if 'richContent' in dp:
|
||||
rc = dp['richContent']
|
||||
logger.warning(f" richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
|
||||
if isinstance(rc, dict) and 'nodes' in rc:
|
||||
nodes = rc['nodes']
|
||||
logger.warning(f" richContent.nodes count: {len(nodes) if isinstance(nodes, list) else 'N/A'}")
|
||||
# Inspect first LIST_ITEM node if any
|
||||
for i, node in enumerate(nodes[:10]):
|
||||
if isinstance(node, dict) and node.get('type') == 'LIST_ITEM':
|
||||
logger.warning(f" Found LIST_ITEM at index {i}:")
|
||||
logger.warning(f" Keys: {list(node.keys())}")
|
||||
logger.warning(f" Has listItemData: {'listItemData' in node}")
|
||||
if 'listItemData' in node:
|
||||
logger.warning(f" listItemData type: {type(node['listItemData'])}, value: {node['listItemData']}")
|
||||
if 'nodes' in node:
|
||||
nested = node['nodes']
|
||||
logger.warning(f" Nested nodes count: {len(nested) if isinstance(nested, list) else 'N/A'}")
|
||||
for j, n_node in enumerate(nested[:3]):
|
||||
if isinstance(n_node, dict):
|
||||
logger.warning(f" Nested node {j}: type={n_node.get('type')}, keys={list(n_node.keys())}")
|
||||
if n_node.get('type') == 'PARAGRAPH' and 'paragraphData' in n_node:
|
||||
logger.warning(f" paragraphData type: {type(n_node['paragraphData'])}, value: {n_node['paragraphData']}")
|
||||
break # Only inspect first LIST_ITEM
|
||||
|
||||
logger.warning(f" Full Payload JSON (first 8000 chars):\n{json.dumps(payload, indent=2, ensure_ascii=False)[:8000]}...")
|
||||
|
||||
response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=self.headers(access_token, extra_headers), json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
716
backend/services/integrations/wix/blog_publisher.py
Normal file
716
backend/services/integrations/wix/blog_publisher.py
Normal file
@@ -0,0 +1,716 @@
|
||||
"""
|
||||
Blog Post Publisher for Wix
|
||||
|
||||
Handles blog post creation, validation, and publishing to Wix.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import requests
|
||||
import jwt
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
from services.integrations.wix.blog import WixBlogService
|
||||
from services.integrations.wix.content import convert_content_to_ricos
|
||||
from services.integrations.wix.ricos_converter import convert_via_wix_api
|
||||
from services.integrations.wix.seo import build_seo_data
|
||||
|
||||
|
||||
def validate_ricos_content(ricos_content: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize Ricos document structure.
|
||||
|
||||
Args:
|
||||
ricos_content: Ricos document dict
|
||||
|
||||
Returns:
|
||||
Validated and normalized Ricos document
|
||||
"""
|
||||
# Validate Ricos document structure before using
|
||||
if not ricos_content or not isinstance(ricos_content, dict):
|
||||
logger.error("Invalid Ricos content - not a dict")
|
||||
raise ValueError("Failed to convert content to valid Ricos format")
|
||||
|
||||
if 'type' not in ricos_content:
|
||||
ricos_content['type'] = 'DOCUMENT'
|
||||
logger.debug("Added missing richContent type 'DOCUMENT'")
|
||||
if ricos_content.get('type') != 'DOCUMENT':
|
||||
logger.warning(f"richContent type expected 'DOCUMENT', got {ricos_content.get('type')}, correcting")
|
||||
ricos_content['type'] = 'DOCUMENT'
|
||||
|
||||
if 'id' not in ricos_content or not isinstance(ricos_content.get('id'), str):
|
||||
ricos_content['id'] = str(uuid.uuid4())
|
||||
logger.debug("Added missing richContent id")
|
||||
|
||||
if 'nodes' not in ricos_content:
|
||||
logger.warning("Ricos document missing 'nodes' field, adding empty nodes array")
|
||||
ricos_content['nodes'] = []
|
||||
|
||||
logger.debug(f"Ricos document structure: nodes={len(ricos_content.get('nodes', []))}")
|
||||
|
||||
# Validate richContent is a proper object with nodes array
|
||||
# Per Wix API: richContent must be a RichContent object with nodes array
|
||||
if not isinstance(ricos_content, dict):
|
||||
raise ValueError(f"richContent must be a dict object, got {type(ricos_content)}")
|
||||
|
||||
# Ensure nodes array exists and is valid
|
||||
if 'nodes' not in ricos_content:
|
||||
logger.warning("richContent missing 'nodes', adding empty array")
|
||||
ricos_content['nodes'] = []
|
||||
|
||||
if not isinstance(ricos_content['nodes'], list):
|
||||
raise ValueError(f"richContent.nodes must be a list, got {type(ricos_content['nodes'])}")
|
||||
|
||||
# Recursive function to validate and fix nodes at any depth
|
||||
def validate_node_recursive(node: Dict[str, Any], path: str = "root") -> None:
|
||||
"""
|
||||
Recursively validate a node and all its nested children, ensuring:
|
||||
1. All required data fields exist for each node type
|
||||
2. All 'nodes' arrays are proper lists
|
||||
3. No None values in critical fields
|
||||
"""
|
||||
if not isinstance(node, dict):
|
||||
logger.error(f"{path}: Node is not a dict: {type(node)}")
|
||||
return
|
||||
|
||||
# Ensure type and id exist
|
||||
if 'type' not in node:
|
||||
logger.error(f"{path}: Missing 'type' field - REQUIRED")
|
||||
node['type'] = 'PARAGRAPH' # Default fallback
|
||||
if 'id' not in node:
|
||||
node['id'] = str(uuid.uuid4())
|
||||
logger.debug(f"{path}: Added missing 'id'")
|
||||
|
||||
node_type = node.get('type')
|
||||
|
||||
# CRITICAL: Per Wix API schema, data fields like paragraphData, bulletedListData, etc.
|
||||
# are OPTIONAL and should be OMITTED entirely when empty, not included as {}
|
||||
# Only validate fields that have required properties
|
||||
|
||||
# Special handling: Remove listItemData if it exists (not in Wix API schema)
|
||||
if node_type == 'LIST_ITEM' and 'listItemData' in node:
|
||||
logger.debug(f"{path}: Removing incorrect listItemData field from LIST_ITEM")
|
||||
del node['listItemData']
|
||||
|
||||
# Only validate HEADING nodes - they require headingData with level property
|
||||
if node_type == 'HEADING':
|
||||
if 'headingData' not in node or not isinstance(node.get('headingData'), dict):
|
||||
logger.warning(f"{path} (HEADING): Missing headingData, adding default level 1")
|
||||
node['headingData'] = {'level': 1}
|
||||
elif 'level' not in node['headingData']:
|
||||
logger.warning(f"{path} (HEADING): Missing level in headingData, adding default")
|
||||
node['headingData']['level'] = 1
|
||||
|
||||
# TEXT nodes must have textData
|
||||
if node_type == 'TEXT':
|
||||
if 'textData' not in node or not isinstance(node.get('textData'), dict):
|
||||
logger.error(f"{path} (TEXT): Missing/invalid textData - node will be problematic")
|
||||
node['textData'] = {'text': '', 'decorations': []}
|
||||
|
||||
# LINK and IMAGE nodes must have their data fields
|
||||
if node_type == 'LINK' and ('linkData' not in node or not isinstance(node.get('linkData'), dict)):
|
||||
logger.error(f"{path} (LINK): Missing/invalid linkData - node will be problematic")
|
||||
if node_type == 'IMAGE' and ('imageData' not in node or not isinstance(node.get('imageData'), dict)):
|
||||
logger.error(f"{path} (IMAGE): Missing/invalid imageData - node will be problematic")
|
||||
|
||||
# Remove None values from any data fields that exist (Wix API rejects None)
|
||||
for data_field in ['headingData', 'paragraphData', 'blockquoteData', 'bulletedListData',
|
||||
'orderedListData', 'textData', 'linkData', 'imageData']:
|
||||
if data_field in node and isinstance(node[data_field], dict):
|
||||
data_value = node[data_field]
|
||||
keys_to_remove = [k for k, v in data_value.items() if v is None]
|
||||
if keys_to_remove:
|
||||
logger.debug(f"{path} ({node_type}): Removing None values from {data_field}: {keys_to_remove}")
|
||||
for key in keys_to_remove:
|
||||
del data_value[key]
|
||||
|
||||
# Ensure 'nodes' field exists for container nodes
|
||||
container_types = ['HEADING', 'PARAGRAPH', 'BLOCKQUOTE', 'LIST_ITEM', 'LINK',
|
||||
'BULLETED_LIST', 'ORDERED_LIST']
|
||||
if node_type in container_types:
|
||||
if 'nodes' not in node:
|
||||
logger.warning(f"{path} ({node_type}): Missing 'nodes' field, adding empty array")
|
||||
node['nodes'] = []
|
||||
elif not isinstance(node['nodes'], list):
|
||||
logger.error(f"{path} ({node_type}): Invalid 'nodes' field (not a list), fixing")
|
||||
node['nodes'] = []
|
||||
|
||||
# Recursively validate all nested nodes
|
||||
for nested_idx, nested_node in enumerate(node['nodes']):
|
||||
nested_path = f"{path}.nodes[{nested_idx}]"
|
||||
validate_node_recursive(nested_node, nested_path)
|
||||
|
||||
# Validate all top-level nodes recursively
|
||||
for idx, node in enumerate(ricos_content['nodes']):
|
||||
validate_node_recursive(node, f"nodes[{idx}]")
|
||||
|
||||
# Ensure documentStyle exists and is a dict (required by Wix API when provided)
|
||||
if 'metadata' not in ricos_content or not isinstance(ricos_content.get('metadata'), dict):
|
||||
ricos_content['metadata'] = {'version': 1, 'id': str(uuid.uuid4())}
|
||||
logger.debug("Added default metadata to richContent")
|
||||
else:
|
||||
ricos_content['metadata'].setdefault('version', 1)
|
||||
ricos_content['metadata'].setdefault('id', str(uuid.uuid4()))
|
||||
|
||||
if 'documentStyle' not in ricos_content or not isinstance(ricos_content.get('documentStyle'), dict):
|
||||
ricos_content['documentStyle'] = {
|
||||
'paragraph': {
|
||||
'decorations': [],
|
||||
'nodeStyle': {},
|
||||
'lineHeight': '1.5'
|
||||
}
|
||||
}
|
||||
logger.debug("Added default documentStyle to richContent")
|
||||
|
||||
logger.debug(f"✅ Validated richContent: {len(ricos_content['nodes'])} nodes, has_metadata={bool(ricos_content.get('metadata'))}, has_documentStyle={bool(ricos_content.get('documentStyle'))}")
|
||||
|
||||
return ricos_content
|
||||
|
||||
|
||||
def validate_payload_no_none(obj, path=""):
|
||||
"""Recursively validate that no None values exist in the payload"""
|
||||
if obj is None:
|
||||
raise ValueError(f"Found None value at path: {path}")
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
validate_payload_no_none(value, f"{path}.{key}" if path else key)
|
||||
elif isinstance(obj, list):
|
||||
for idx, item in enumerate(obj):
|
||||
validate_payload_no_none(item, f"{path}[{idx}]" if path else f"[{idx}]")
|
||||
|
||||
|
||||
def create_blog_post(
|
||||
blog_service: WixBlogService,
|
||||
access_token: str,
|
||||
title: str,
|
||||
content: str,
|
||||
member_id: str,
|
||||
cover_image_url: str = None,
|
||||
category_ids: List[str] = None,
|
||||
tag_ids: List[str] = None,
|
||||
publish: bool = True,
|
||||
seo_metadata: Dict[str, Any] = None,
|
||||
import_image_func = None,
|
||||
lookup_categories_func = None,
|
||||
lookup_tags_func = None,
|
||||
base_url: str = 'https://www.wixapis.com'
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create and optionally publish a blog post on Wix
|
||||
|
||||
Args:
|
||||
blog_service: WixBlogService instance
|
||||
access_token: Valid access token
|
||||
title: Blog post title
|
||||
content: Blog post content (markdown)
|
||||
member_id: Required for third-party apps - the member ID of the post author
|
||||
cover_image_url: Optional cover image URL
|
||||
category_ids: Optional list of category IDs or names
|
||||
tag_ids: Optional list of tag IDs or names
|
||||
publish: Whether to publish immediately or save as draft
|
||||
seo_metadata: Optional SEO metadata dict
|
||||
import_image_func: Function to import images (optional)
|
||||
lookup_categories_func: Function to lookup/create categories (optional)
|
||||
lookup_tags_func: Function to lookup/create tags (optional)
|
||||
base_url: Wix API base URL
|
||||
|
||||
Returns:
|
||||
Created blog post information
|
||||
"""
|
||||
if not member_id:
|
||||
raise ValueError("memberId is required for third-party apps creating blog posts")
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# Build valid Ricos rich content
|
||||
# Ensure content is not empty
|
||||
if not content or not content.strip():
|
||||
content = "This is a post from ALwrity."
|
||||
logger.warning("⚠️ Content was empty, using default text")
|
||||
|
||||
# Try Wix API first (more reliable), fall back to custom parser
|
||||
ricos_content = None
|
||||
try:
|
||||
logger.warning("🔄 Attempting to convert markdown to Ricos via Wix API...")
|
||||
ricos_content = convert_via_wix_api(content, access_token, base_url)
|
||||
logger.warning(f"✅ Wix API conversion successful. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Wix Ricos API conversion failed: {e}. Falling back to custom parser...")
|
||||
# Fall back to custom parser
|
||||
ricos_content = convert_content_to_ricos(content, None)
|
||||
logger.warning(f"✅ Custom parser conversion complete. Ricos document has {len(ricos_content.get('nodes', []))} nodes")
|
||||
|
||||
# Validate Ricos content
|
||||
ricos_content = validate_ricos_content(ricos_content)
|
||||
|
||||
# Minimal payload per Wix docs: title, memberId, and richContent
|
||||
# CRITICAL: Only include fields that have valid values (no None, no empty strings for required fields)
|
||||
blog_data = {
|
||||
'draftPost': {
|
||||
'title': str(title).strip() if title else "Untitled",
|
||||
'memberId': str(member_id).strip(), # Required for third-party apps (validated above)
|
||||
'richContent': ricos_content, # Must be a valid Ricos document object
|
||||
},
|
||||
'publish': bool(publish),
|
||||
'fieldsets': ['URL'] # Simplified fieldsets
|
||||
}
|
||||
|
||||
# Add excerpt only if content exists and is not empty (avoid None or empty strings)
|
||||
excerpt = (content or '').strip()[:200] if content else None
|
||||
if excerpt and len(excerpt) > 0:
|
||||
blog_data['draftPost']['excerpt'] = str(excerpt)
|
||||
|
||||
# Add cover image if provided
|
||||
if cover_image_url and import_image_func:
|
||||
try:
|
||||
media_id = import_image_func(access_token, cover_image_url, f'Cover: {title}')
|
||||
# Ensure media_id is a string and not None
|
||||
if media_id and isinstance(media_id, str):
|
||||
blog_data['draftPost']['media'] = {
|
||||
'wixMedia': {
|
||||
'image': {'id': str(media_id).strip()}
|
||||
},
|
||||
'displayed': True,
|
||||
'custom': True
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Invalid media_id type or value: {type(media_id)}, skipping media")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import cover image: {e}")
|
||||
|
||||
# Handle categories - can be either IDs (list of strings) or names (for lookup)
|
||||
category_ids_to_use = None
|
||||
if category_ids:
|
||||
# Check if these are IDs (UUIDs) or names
|
||||
if isinstance(category_ids, list) and len(category_ids) > 0:
|
||||
# Assume IDs if first item looks like UUID (has hyphens and is long)
|
||||
first_item = str(category_ids[0])
|
||||
if '-' in first_item and len(first_item) > 30:
|
||||
category_ids_to_use = category_ids
|
||||
elif lookup_categories_func:
|
||||
# These are names, need to lookup/create
|
||||
extra_headers = {}
|
||||
if 'wix-site-id' in headers:
|
||||
extra_headers['wix-site-id'] = headers['wix-site-id']
|
||||
category_ids_to_use = lookup_categories_func(
|
||||
access_token, category_ids, extra_headers if extra_headers else None
|
||||
)
|
||||
|
||||
# Handle tags - can be either IDs (list of strings) or names (for lookup)
|
||||
tag_ids_to_use = None
|
||||
if tag_ids:
|
||||
# Check if these are IDs (UUIDs) or names
|
||||
if isinstance(tag_ids, list) and len(tag_ids) > 0:
|
||||
# Assume IDs if first item looks like UUID (has hyphens and is long)
|
||||
first_item = str(tag_ids[0])
|
||||
if '-' in first_item and len(first_item) > 30:
|
||||
tag_ids_to_use = tag_ids
|
||||
elif lookup_tags_func:
|
||||
# These are names, need to lookup/create
|
||||
extra_headers = {}
|
||||
if 'wix-site-id' in headers:
|
||||
extra_headers['wix-site-id'] = headers['wix-site-id']
|
||||
tag_ids_to_use = lookup_tags_func(
|
||||
access_token, tag_ids, extra_headers if extra_headers else None
|
||||
)
|
||||
|
||||
# Add categories if we have IDs (must be non-empty list of strings)
|
||||
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
|
||||
if category_ids_to_use and isinstance(category_ids_to_use, list) and len(category_ids_to_use) > 0:
|
||||
# Filter out None, empty strings, and ensure all are valid UUID strings
|
||||
valid_category_ids = [str(cid).strip() for cid in category_ids_to_use if cid and str(cid).strip()]
|
||||
if valid_category_ids:
|
||||
blog_data['draftPost']['categoryIds'] = valid_category_ids
|
||||
logger.debug(f"Added {len(valid_category_ids)} category IDs")
|
||||
else:
|
||||
logger.warning("All category IDs were invalid, not including categoryIds in payload")
|
||||
|
||||
# Add tags if we have IDs (must be non-empty list of strings)
|
||||
# CRITICAL: Wix API rejects empty arrays or arrays with None/empty strings
|
||||
if tag_ids_to_use and isinstance(tag_ids_to_use, list) and len(tag_ids_to_use) > 0:
|
||||
# Filter out None, empty strings, and ensure all are valid UUID strings
|
||||
valid_tag_ids = [str(tid).strip() for tid in tag_ids_to_use if tid and str(tid).strip()]
|
||||
if valid_tag_ids:
|
||||
blog_data['draftPost']['tagIds'] = valid_tag_ids
|
||||
logger.debug(f"Added {len(valid_tag_ids)} tag IDs")
|
||||
else:
|
||||
logger.warning("All tag IDs were invalid, not including tagIds in payload")
|
||||
|
||||
# Build SEO data from metadata if provided
|
||||
# TESTING: Skip SEO data temporarily to confirm richContent fix
|
||||
test_skip_seo = True
|
||||
if test_skip_seo:
|
||||
logger.warning("🧪 TESTING: Skipping SEO data to isolate richContent vs seoData issue")
|
||||
seo_data = None
|
||||
elif seo_metadata:
|
||||
logger.warning(f"📊 Building SEO data from metadata. Keys: {list(seo_metadata.keys())}")
|
||||
seo_data = build_seo_data(seo_metadata, title)
|
||||
if seo_data:
|
||||
# Log detailed SEO structure
|
||||
logger.warning(f"📋 SEO data built: {len(seo_data.get('tags', []))} tags, {len(seo_data.get('settings', {}).get('keywords', []))} keywords")
|
||||
|
||||
# Log each SEO tag for debugging (key ones only to avoid too much output)
|
||||
if seo_data.get('tags'):
|
||||
for idx, tag in enumerate(seo_data['tags'][:3]): # First 3 tags only
|
||||
tag_type = tag.get('type')
|
||||
if tag_type == 'title':
|
||||
logger.warning(f" SEO tag {idx+1}: type={tag_type}, children={str(tag.get('children', ''))[:50]}...")
|
||||
else:
|
||||
props = tag.get('props', {})
|
||||
content_preview = str(props.get('content', props.get('href', props.get('name', ''))))[:50]
|
||||
logger.warning(f" SEO tag {idx+1}: type={tag_type}, props={list(props.keys())}, content={content_preview}...")
|
||||
if len(seo_data['tags']) > 3:
|
||||
logger.warning(f" ... and {len(seo_data['tags']) - 3} more SEO tags")
|
||||
|
||||
blog_data['draftPost']['seoData'] = seo_data
|
||||
logger.warning(f"✅ Added seoData to blog post with {len(seo_data.get('tags', []))} tags")
|
||||
else:
|
||||
logger.warning("⚠️ SEO data was empty after building - check build_seo_data function")
|
||||
|
||||
# Add SEO slug if provided (separate field from seoData)
|
||||
if seo_metadata and seo_metadata.get('url_slug'):
|
||||
blog_data['draftPost']['seoSlug'] = str(seo_metadata.get('url_slug')).strip()
|
||||
logger.warning(f"✅ Added SEO slug: {blog_data['draftPost']['seoSlug']}")
|
||||
|
||||
if test_skip_seo:
|
||||
logger.warning("⚠️ SEO data skipped for testing - will add back once richContent is confirmed working")
|
||||
elif not seo_metadata:
|
||||
logger.warning("⚠️ No SEO metadata provided to create_blog_post")
|
||||
|
||||
# Log the payload structure for debugging (without sensitive data)
|
||||
logger.warning(f"📝 Creating blog post with title: '{title}'")
|
||||
logger.warning(f"📋 Draft post fields: {list(blog_data['draftPost'].keys())}")
|
||||
|
||||
# Detailed SEO logging
|
||||
if 'seoData' in blog_data['draftPost']:
|
||||
seo_data_debug = blog_data['draftPost']['seoData']
|
||||
logger.warning(f"📊 SEO data in payload: {len(seo_data_debug.get('tags', []))} tags, {len(seo_data_debug.get('settings', {}).get('keywords', []))} keywords")
|
||||
|
||||
# Log sample SEO tags (first 2 only to avoid too much output)
|
||||
if seo_data_debug.get('tags'):
|
||||
logger.warning("📋 SEO Tags sample:")
|
||||
for i, tag in enumerate(seo_data_debug['tags'][:2]): # First 2 tags
|
||||
logger.warning(f" Tag {i+1}: type={tag.get('type')}, custom={tag.get('custom')}, disabled={tag.get('disabled')}")
|
||||
if len(seo_data_debug['tags']) > 2:
|
||||
logger.warning(f" ... and {len(seo_data_debug['tags']) - 2} more tags")
|
||||
|
||||
if seo_data_debug.get('settings', {}).get('keywords'):
|
||||
keywords_list = [k.get('term') for k in seo_data_debug['settings']['keywords'][:3]]
|
||||
logger.warning(f"🔑 Keywords: {keywords_list}")
|
||||
|
||||
# Log FULL seoData structure for debugging
|
||||
import json
|
||||
try:
|
||||
seo_json = json.dumps(seo_data_debug, indent=2, ensure_ascii=False)
|
||||
logger.warning(f"📄 FULL seoData JSON:\n{seo_json[:2000]}...") # First 2000 chars
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to serialize seoData: {e}")
|
||||
else:
|
||||
logger.warning("⚠️ No seoData in draft post payload!")
|
||||
|
||||
try:
|
||||
# Add wix-site-id header if we can extract it from token
|
||||
extra_headers = {}
|
||||
try:
|
||||
token_str = str(access_token)
|
||||
if token_str and token_str.startswith('OauthNG.JWS.'):
|
||||
jwt_part = token_str[12:]
|
||||
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
|
||||
data_payload = payload.get('data', {})
|
||||
if isinstance(data_payload, str):
|
||||
try:
|
||||
data_payload = json.loads(data_payload)
|
||||
except:
|
||||
pass
|
||||
instance_data = data_payload.get('instance', {})
|
||||
meta_site_id = instance_data.get('metaSiteId')
|
||||
if isinstance(meta_site_id, str) and meta_site_id:
|
||||
extra_headers['wix-site-id'] = meta_site_id
|
||||
headers['wix-site-id'] = meta_site_id
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract site ID from token: {e}")
|
||||
|
||||
# Make the API call
|
||||
logger.warning(f"🚀 Calling Wix API: POST /blog/v3/draft-posts")
|
||||
logger.warning(f"📦 Payload: title='{blog_data['draftPost'].get('title')}', has_seoData={'seoData' in blog_data['draftPost']}, has_richContent={'richContent' in blog_data['draftPost']}")
|
||||
|
||||
# Validate payload structure before sending
|
||||
draft_post = blog_data.get('draftPost', {})
|
||||
if not isinstance(draft_post, dict):
|
||||
raise ValueError("draftPost must be a dict object")
|
||||
|
||||
# Validate richContent structure
|
||||
if 'richContent' in draft_post:
|
||||
rc = draft_post['richContent']
|
||||
if not isinstance(rc, dict):
|
||||
raise ValueError(f"richContent must be a dict, got {type(rc)}")
|
||||
if 'nodes' not in rc:
|
||||
raise ValueError("richContent missing 'nodes' field")
|
||||
if not isinstance(rc['nodes'], list):
|
||||
raise ValueError(f"richContent.nodes must be a list, got {type(rc['nodes'])}")
|
||||
logger.debug(f"✅ richContent validation passed: {len(rc.get('nodes', []))} nodes")
|
||||
|
||||
# Validate seoData structure if present
|
||||
if 'seoData' in draft_post:
|
||||
seo = draft_post['seoData']
|
||||
if not isinstance(seo, dict):
|
||||
raise ValueError(f"seoData must be a dict, got {type(seo)}")
|
||||
if 'tags' in seo and not isinstance(seo['tags'], list):
|
||||
raise ValueError(f"seoData.tags must be a list, got {type(seo.get('tags'))}")
|
||||
if 'settings' in seo and not isinstance(seo['settings'], dict):
|
||||
raise ValueError(f"seoData.settings must be a dict, got {type(seo.get('settings'))}")
|
||||
logger.debug(f"✅ seoData validation passed: {len(seo.get('tags', []))} tags")
|
||||
|
||||
# Final validation: Ensure no None values in any nested objects
|
||||
# Wix API rejects None values and expects proper types
|
||||
try:
|
||||
validate_payload_no_none(blog_data, "blog_data")
|
||||
logger.debug("✅ Payload validation passed: No None values found")
|
||||
except ValueError as e:
|
||||
logger.error(f"❌ Payload validation failed: {e}")
|
||||
raise
|
||||
|
||||
# Log full payload structure for debugging (sanitized)
|
||||
logger.warning(f"📦 Full payload structure validation:")
|
||||
logger.warning(f" - draftPost type: {type(draft_post)}")
|
||||
logger.warning(f" - draftPost keys: {list(draft_post.keys())}")
|
||||
logger.warning(f" - richContent type: {type(draft_post.get('richContent'))}")
|
||||
if 'richContent' in draft_post:
|
||||
rc = draft_post['richContent']
|
||||
logger.warning(f" - richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
|
||||
logger.warning(f" - richContent.nodes type: {type(rc.get('nodes'))}, count: {len(rc.get('nodes', []))}")
|
||||
logger.warning(f" - richContent.metadata type: {type(rc.get('metadata'))}")
|
||||
logger.warning(f" - richContent.documentStyle type: {type(rc.get('documentStyle'))}")
|
||||
logger.warning(f" - seoData type: {type(draft_post.get('seoData'))}")
|
||||
if 'seoData' in draft_post:
|
||||
seo = draft_post['seoData']
|
||||
logger.warning(f" - seoData keys: {list(seo.keys()) if isinstance(seo, dict) else 'N/A'}")
|
||||
logger.warning(f" - seoData.tags type: {type(seo.get('tags'))}, count: {len(seo.get('tags', []))}")
|
||||
logger.warning(f" - seoData.settings type: {type(seo.get('settings'))}")
|
||||
if 'categoryIds' in draft_post:
|
||||
logger.warning(f" - categoryIds type: {type(draft_post.get('categoryIds'))}, count: {len(draft_post.get('categoryIds', []))}")
|
||||
if 'tagIds' in draft_post:
|
||||
logger.warning(f" - tagIds type: {type(draft_post.get('tagIds'))}, count: {len(draft_post.get('tagIds', []))}")
|
||||
|
||||
# Log a sample of the payload JSON to see exact structure (first 2000 chars)
|
||||
try:
|
||||
import json
|
||||
payload_json = json.dumps(blog_data, indent=2, ensure_ascii=False)
|
||||
logger.warning(f"📄 Payload JSON preview (first 3000 chars):\n{payload_json[:3000]}...")
|
||||
|
||||
# Also log a deep structure inspection of richContent.nodes (first few nodes)
|
||||
if 'richContent' in blog_data['draftPost']:
|
||||
nodes = blog_data['draftPost']['richContent'].get('nodes', [])
|
||||
if nodes:
|
||||
logger.warning(f"🔍 Inspecting first 5 richContent.nodes:")
|
||||
for i, node in enumerate(nodes[:5]):
|
||||
logger.warning(f" Node {i+1}: type={node.get('type')}, keys={list(node.keys())}")
|
||||
# Check for any None values in node
|
||||
for key, value in node.items():
|
||||
if value is None:
|
||||
logger.error(f" ⚠️ Node {i+1}.{key} is None!")
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
if v is None:
|
||||
logger.error(f" ⚠️ Node {i+1}.{key}.{k} is None!")
|
||||
# Deep check: if it's a list-type node, inspect list items
|
||||
if node.get('type') in ['BULLETED_LIST', 'ORDERED_LIST']:
|
||||
list_items = node.get('nodes', [])
|
||||
if list_items:
|
||||
logger.warning(f" List has {len(list_items)} items, checking first LIST_ITEM:")
|
||||
first_item = list_items[0]
|
||||
logger.warning(f" LIST_ITEM keys: {list(first_item.keys())}")
|
||||
# Verify listItemData is NOT present (correct per Wix API spec)
|
||||
if 'listItemData' in first_item:
|
||||
logger.error(f" ❌ LIST_ITEM incorrectly has listItemData!")
|
||||
else:
|
||||
logger.debug(f" ✅ LIST_ITEM correctly has no listItemData")
|
||||
# Check nested PARAGRAPH nodes
|
||||
nested_nodes = first_item.get('nodes', [])
|
||||
if nested_nodes:
|
||||
logger.warning(f" LIST_ITEM has {len(nested_nodes)} nested nodes")
|
||||
for n_idx, n_node in enumerate(nested_nodes[:2]):
|
||||
logger.warning(f" Nested node {n_idx+1}: type={n_node.get('type')}, keys={list(n_node.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not serialize payload for logging: {e}")
|
||||
|
||||
# Note: All node validation is done by validate_ricos_content() which runs earlier
|
||||
# The recursive validation ensures all required data fields are present at any depth
|
||||
|
||||
# Final deep validation: Serialize and deserialize to catch any JSON-serialization issues
|
||||
# This will raise an error if there are any objects that can't be serialized
|
||||
try:
|
||||
import json
|
||||
test_json = json.dumps(blog_data, ensure_ascii=False)
|
||||
test_parsed = json.loads(test_json)
|
||||
logger.debug("✅ Payload JSON serialization test passed")
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"❌ Payload JSON serialization failed: {e}")
|
||||
raise ValueError(f"Payload contains non-serializable data: {e}")
|
||||
|
||||
# Final check: Ensure documentStyle and metadata are valid objects (not None, not empty strings)
|
||||
rc = blog_data['draftPost']['richContent']
|
||||
if 'documentStyle' in rc:
|
||||
doc_style = rc['documentStyle']
|
||||
if doc_style is None or doc_style == "":
|
||||
logger.warning("⚠️ documentStyle is None or empty string, removing it")
|
||||
del rc['documentStyle']
|
||||
elif not isinstance(doc_style, dict):
|
||||
logger.warning(f"⚠️ documentStyle is not a dict ({type(doc_style)}), removing it")
|
||||
del rc['documentStyle']
|
||||
|
||||
if 'metadata' in rc:
|
||||
metadata = rc['metadata']
|
||||
if metadata is None or metadata == "":
|
||||
logger.warning("⚠️ metadata is None or empty string, removing it")
|
||||
del rc['metadata']
|
||||
elif not isinstance(metadata, dict):
|
||||
logger.warning(f"⚠️ metadata is not a dict ({type(metadata)}), removing it")
|
||||
del rc['metadata']
|
||||
|
||||
# Check for any None values in critical nested structures
|
||||
def check_none_in_dict(d, path=""):
|
||||
"""Recursively check for None values that shouldn't be there"""
|
||||
issues = []
|
||||
if isinstance(d, dict):
|
||||
for key, value in d.items():
|
||||
current_path = f"{path}.{key}" if path else key
|
||||
if value is None:
|
||||
# Some fields can legitimately be None, but most shouldn't
|
||||
if key not in ['decorations', 'nodeStyle', 'props']:
|
||||
issues.append(current_path)
|
||||
elif isinstance(value, dict):
|
||||
issues.extend(check_none_in_dict(value, current_path))
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if item is None:
|
||||
issues.append(f"{current_path}[{i}]")
|
||||
elif isinstance(item, dict):
|
||||
issues.extend(check_none_in_dict(item, f"{current_path}[{i}]"))
|
||||
return issues
|
||||
|
||||
none_issues = check_none_in_dict(blog_data['draftPost']['richContent'])
|
||||
if none_issues:
|
||||
logger.error(f"❌ Found None values in richContent at: {none_issues[:10]}") # Limit to first 10
|
||||
# Remove None values from critical paths
|
||||
for issue_path in none_issues[:5]: # Fix first 5
|
||||
parts = issue_path.split('.')
|
||||
try:
|
||||
obj = blog_data['draftPost']['richContent']
|
||||
for part in parts[:-1]:
|
||||
if '[' in part:
|
||||
key, idx = part.split('[')
|
||||
idx = int(idx.rstrip(']'))
|
||||
obj = obj[key][idx]
|
||||
else:
|
||||
obj = obj[part]
|
||||
final_key = parts[-1]
|
||||
if '[' in final_key:
|
||||
key, idx = final_key.split('[')
|
||||
idx = int(idx.rstrip(']'))
|
||||
obj[key][idx] = {}
|
||||
else:
|
||||
obj[final_key] = {}
|
||||
logger.warning(f"Fixed None value at {issue_path}")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Log the final payload structure one more time before sending
|
||||
logger.warning(f"📤 Final payload ready - draftPost keys: {list(blog_data['draftPost'].keys())}")
|
||||
logger.warning(f"📤 RichContent nodes count: {len(blog_data['draftPost']['richContent'].get('nodes', []))}")
|
||||
logger.warning(f"📤 RichContent has metadata: {bool(blog_data['draftPost']['richContent'].get('metadata'))}")
|
||||
logger.warning(f"📤 RichContent has documentStyle: {bool(blog_data['draftPost']['richContent'].get('documentStyle'))}")
|
||||
|
||||
# Try sending WITHOUT SEO data first to isolate the issue
|
||||
test_without_seo = False # Disabled - listItemData issue fixed
|
||||
if test_without_seo and 'seoData' in blog_data['draftPost']:
|
||||
logger.warning("🧪 TESTING WITHOUT SEO DATA to isolate issue...")
|
||||
# Clone the payload without SEO data
|
||||
test_payload_no_seo = {
|
||||
'draftPost': {
|
||||
'title': blog_data['draftPost']['title'],
|
||||
'memberId': blog_data['draftPost']['memberId'],
|
||||
'richContent': blog_data['draftPost']['richContent'],
|
||||
'excerpt': blog_data['draftPost'].get('excerpt', '')
|
||||
},
|
||||
'publish': False,
|
||||
'fieldsets': ['URL']
|
||||
}
|
||||
try:
|
||||
logger.warning("🧪 Attempting without SEO data...")
|
||||
test_result = blog_service.create_draft_post(access_token, test_payload_no_seo, extra_headers or None)
|
||||
logger.warning(f"✅ WITHOUT SEO DATA SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
|
||||
logger.error("⚠️⚠️⚠️ ISSUE IS WITH SEO DATA STRUCTURE!")
|
||||
# If this succeeds, don't send the full payload, just return this result
|
||||
return test_result
|
||||
except Exception as e:
|
||||
logger.warning(f"❌ WITHOUT SEO DATA ALSO FAILED: {e}")
|
||||
logger.warning("⚠️ Issue is NOT with SEO data, continuing with full payload...")
|
||||
|
||||
# Try sending with minimal structure first to isolate the issue
|
||||
# Create a test payload with just required fields
|
||||
minimal_test = False # Set to True to test with minimal payload
|
||||
if minimal_test:
|
||||
logger.warning("🧪 TESTING WITH MINIMAL PAYLOAD (title + memberId + simple richContent)")
|
||||
test_payload = {
|
||||
'draftPost': {
|
||||
'title': blog_data['draftPost']['title'],
|
||||
'memberId': blog_data['draftPost']['memberId'],
|
||||
'richContent': {
|
||||
'nodes': [
|
||||
{
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': [
|
||||
{
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': 'Test paragraph',
|
||||
'decorations': []
|
||||
}
|
||||
}
|
||||
],
|
||||
'paragraphData': {}
|
||||
}
|
||||
],
|
||||
'metadata': {'version': 1, 'id': str(uuid.uuid4())},
|
||||
'documentStyle': {}
|
||||
}
|
||||
},
|
||||
'publish': False,
|
||||
'fieldsets': ['URL']
|
||||
}
|
||||
logger.warning("🧪 Attempting minimal payload first...")
|
||||
try:
|
||||
test_result = blog_service.create_draft_post(access_token, test_payload, extra_headers or None)
|
||||
logger.warning(f"✅ MINIMAL PAYLOAD SUCCEEDED! Post ID: {test_result.get('draftPost', {}).get('id')}")
|
||||
logger.warning("⚠️ Issue is with complex content, not basic structure")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MINIMAL PAYLOAD ALSO FAILED: {e}")
|
||||
logger.error("⚠️ Issue is with basic structure or permissions")
|
||||
|
||||
result = blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
|
||||
|
||||
# Log response
|
||||
draft_post = result.get('draftPost', {})
|
||||
logger.warning(f"✅ Blog post created successfully! Post ID: {draft_post.get('id', 'N/A')}")
|
||||
|
||||
# Check if SEO data was preserved in response
|
||||
if 'seoData' in draft_post:
|
||||
seo_response = draft_post['seoData']
|
||||
logger.warning(f"✅ SEO data confirmed in response: {len(seo_response.get('tags', []))} tags, {len(seo_response.get('settings', {}).get('keywords', []))} keywords")
|
||||
else:
|
||||
logger.warning("⚠️ No seoData in response - it may have been filtered out by Wix API")
|
||||
logger.warning(f"📋 Response fields: {list(draft_post.keys())}")
|
||||
|
||||
return result
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to create blog post: {e}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f"Response body: {e.response.text}")
|
||||
raise
|
||||
|
||||
@@ -1,58 +1,460 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def parse_markdown_inline(text: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parse inline markdown formatting (bold, italic, links) into Ricos text nodes.
|
||||
Returns a list of text nodes with decorations.
|
||||
Handles: **bold**, *italic*, [links](url), `code`, and combinations.
|
||||
"""
|
||||
if not text:
|
||||
return [{
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {'text': '', 'decorations': []}
|
||||
}]
|
||||
|
||||
nodes = []
|
||||
|
||||
# Process text character by character to handle nested/adjacent formatting
|
||||
# This is more robust than regex for complex cases
|
||||
i = 0
|
||||
current_text = ''
|
||||
current_decorations = []
|
||||
|
||||
while i < len(text):
|
||||
# Check for bold **text** (must come before single * check)
|
||||
if i < len(text) - 1 and text[i:i+2] == '**':
|
||||
# Save any accumulated text
|
||||
if current_text:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': current_text,
|
||||
'decorations': current_decorations.copy()
|
||||
}
|
||||
})
|
||||
current_text = ''
|
||||
|
||||
# Find closing **
|
||||
end_bold = text.find('**', i + 2)
|
||||
if end_bold != -1:
|
||||
bold_text = text[i + 2:end_bold]
|
||||
# Recursively parse the bold text for nested formatting
|
||||
bold_nodes = parse_markdown_inline(bold_text)
|
||||
# Add BOLD decoration to all text nodes within
|
||||
for node in bold_nodes:
|
||||
if node['type'] == 'TEXT':
|
||||
node_decorations = node['textData'].get('decorations', []).copy()
|
||||
if 'BOLD' not in node_decorations:
|
||||
node_decorations.append('BOLD')
|
||||
node['textData']['decorations'] = node_decorations
|
||||
nodes.append(node)
|
||||
i = end_bold + 2
|
||||
continue
|
||||
|
||||
# Check for link [text](url)
|
||||
elif text[i] == '[':
|
||||
# Save any accumulated text
|
||||
if current_text:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': current_text,
|
||||
'decorations': current_decorations.copy()
|
||||
}
|
||||
})
|
||||
current_text = ''
|
||||
current_decorations = []
|
||||
|
||||
# Find matching ]
|
||||
link_end = text.find(']', i)
|
||||
if link_end != -1 and link_end < len(text) - 1 and text[link_end + 1] == '(':
|
||||
link_text = text[i + 1:link_end]
|
||||
url_start = link_end + 2
|
||||
url_end = text.find(')', url_start)
|
||||
if url_end != -1:
|
||||
url = text[url_start:url_end]
|
||||
# Create link node
|
||||
link_node_id = str(uuid.uuid4())
|
||||
text_node_id = str(uuid.uuid4())
|
||||
link_text_nodes = parse_markdown_inline(link_text)
|
||||
# Wrap link text in LINK node
|
||||
nodes.append({
|
||||
'id': link_node_id,
|
||||
'type': 'LINK',
|
||||
'nodes': link_text_nodes if link_text_nodes else [{
|
||||
'id': text_node_id,
|
||||
'type': 'TEXT',
|
||||
'textData': {'text': link_text, 'decorations': []}
|
||||
}],
|
||||
'linkData': {
|
||||
'link': {
|
||||
'url': url,
|
||||
'target': '_blank'
|
||||
}
|
||||
}
|
||||
})
|
||||
i = url_end + 1
|
||||
continue
|
||||
|
||||
# Check for code `text`
|
||||
elif text[i] == '`':
|
||||
# Save any accumulated text
|
||||
if current_text:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': current_text,
|
||||
'decorations': current_decorations.copy()
|
||||
}
|
||||
})
|
||||
current_text = ''
|
||||
current_decorations = []
|
||||
|
||||
# Find closing `
|
||||
code_end = text.find('`', i + 1)
|
||||
if code_end != -1:
|
||||
code_text = text[i + 1:code_end]
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': code_text,
|
||||
'decorations': ['CODE']
|
||||
}
|
||||
})
|
||||
i = code_end + 1
|
||||
continue
|
||||
|
||||
# Check for italic *text* (only if not part of **)
|
||||
elif text[i] == '*' and (i == 0 or text[i-1] != '*') and (i == len(text) - 1 or text[i+1] != '*'):
|
||||
# Save any accumulated text
|
||||
if current_text:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': current_text,
|
||||
'decorations': current_decorations.copy()
|
||||
}
|
||||
})
|
||||
current_text = ''
|
||||
current_decorations = []
|
||||
|
||||
# Find closing * (but not **)
|
||||
italic_end = text.find('*', i + 1)
|
||||
if italic_end != -1:
|
||||
# Make sure it's not part of **
|
||||
if italic_end == len(text) - 1 or text[italic_end + 1] != '*':
|
||||
italic_text = text[i + 1:italic_end]
|
||||
italic_nodes = parse_markdown_inline(italic_text)
|
||||
# Add ITALIC decoration
|
||||
for node in italic_nodes:
|
||||
if node['type'] == 'TEXT':
|
||||
node_decorations = node['textData'].get('decorations', []).copy()
|
||||
if 'ITALIC' not in node_decorations:
|
||||
node_decorations.append('ITALIC')
|
||||
node['textData']['decorations'] = node_decorations
|
||||
nodes.append(node)
|
||||
i = italic_end + 1
|
||||
continue
|
||||
|
||||
# Regular character
|
||||
current_text += text[i]
|
||||
i += 1
|
||||
|
||||
# Add any remaining text
|
||||
if current_text:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': current_text,
|
||||
'decorations': current_decorations.copy()
|
||||
}
|
||||
})
|
||||
|
||||
# If no nodes created, return single plain text node
|
||||
if not nodes:
|
||||
nodes.append({
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': text,
|
||||
'decorations': []
|
||||
}
|
||||
})
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert simple markdown-like text into minimal valid Ricos JSON.
|
||||
Convert markdown content into valid Ricos JSON format.
|
||||
Supports headings, paragraphs, lists, bold, italic, links, and images.
|
||||
"""
|
||||
paragraphs = content.split('\n\n')
|
||||
if not content:
|
||||
content = "This is a post from ALwrity."
|
||||
|
||||
nodes = []
|
||||
|
||||
import uuid
|
||||
|
||||
for paragraph in paragraphs:
|
||||
text = paragraph.strip()
|
||||
if not text:
|
||||
lines = content.split('\n')
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
if not line:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
text_node_id = str(uuid.uuid4())
|
||||
|
||||
if text.startswith('#'):
|
||||
level = len(text) - len(text.lstrip('#'))
|
||||
heading_text = text.lstrip('# ').strip()
|
||||
|
||||
# Check for headings
|
||||
if line.startswith('#'):
|
||||
level = len(line) - len(line.lstrip('#'))
|
||||
heading_text = line.lstrip('# ').strip()
|
||||
text_nodes = parse_markdown_inline(heading_text)
|
||||
nodes.append({
|
||||
'id': node_id,
|
||||
'type': 'HEADING',
|
||||
'nodes': [{
|
||||
'id': text_node_id,
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': heading_text,
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'headingData': { 'level': min(level, 6) }
|
||||
'nodes': text_nodes,
|
||||
'headingData': {'level': min(level, 6)}
|
||||
})
|
||||
else:
|
||||
nodes.append({
|
||||
'id': node_id,
|
||||
i += 1
|
||||
|
||||
# Check for blockquotes
|
||||
elif line.startswith('>'):
|
||||
quote_text = line.lstrip('> ').strip()
|
||||
# Continue reading consecutive blockquote lines
|
||||
quote_lines = [quote_text]
|
||||
i += 1
|
||||
while i < len(lines) and lines[i].strip().startswith('>'):
|
||||
quote_lines.append(lines[i].strip().lstrip('> ').strip())
|
||||
i += 1
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': [{
|
||||
'id': text_node_id,
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': text,
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
})
|
||||
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
# Check for unordered lists (handle both '- ' and '* ' markers)
|
||||
elif (line.startswith('- ') or line.startswith('* ') or
|
||||
(line.startswith('-') and len(line) > 1 and line[1] != '-') or
|
||||
(line.startswith('*') and len(line) > 1 and line[1] != '*')):
|
||||
list_items = []
|
||||
list_marker = '- ' if line.startswith('-') else '* '
|
||||
# Process list items
|
||||
while i < len(lines):
|
||||
current_line = lines[i].strip()
|
||||
# Check if this is a list item
|
||||
is_list_item = (current_line.startswith('- ') or current_line.startswith('* ') or
|
||||
(current_line.startswith('-') and len(current_line) > 1 and current_line[1] != '-') or
|
||||
(current_line.startswith('*') and len(current_line) > 1 and current_line[1] != '*'))
|
||||
|
||||
if not is_list_item:
|
||||
break
|
||||
|
||||
# Extract item text (handle both '- ' and '-item' formats)
|
||||
if current_line.startswith('- ') or current_line.startswith('* '):
|
||||
item_text = current_line[2:].strip()
|
||||
elif current_line.startswith('-'):
|
||||
item_text = current_line[1:].strip()
|
||||
elif current_line.startswith('*'):
|
||||
item_text = current_line[1:].strip()
|
||||
else:
|
||||
item_text = current_line
|
||||
|
||||
list_items.append(item_text)
|
||||
i += 1
|
||||
|
||||
# Check for nested items (indented with 2+ spaces)
|
||||
while i < len(lines):
|
||||
next_line = lines[i]
|
||||
# Must be indented and be a list marker
|
||||
if next_line.startswith(' ') and (next_line.strip().startswith('- ') or
|
||||
next_line.strip().startswith('* ') or
|
||||
(next_line.strip().startswith('-') and len(next_line.strip()) > 1) or
|
||||
(next_line.strip().startswith('*') and len(next_line.strip()) > 1)):
|
||||
nested_text = next_line.strip()
|
||||
if nested_text.startswith('- ') or nested_text.startswith('* '):
|
||||
nested_text = nested_text[2:].strip()
|
||||
elif nested_text.startswith('-'):
|
||||
nested_text = nested_text[1:].strip()
|
||||
elif nested_text.startswith('*'):
|
||||
nested_text = nested_text[1:].strip()
|
||||
list_items.append(nested_text)
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# Build list items with proper formatting
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM
|
||||
# NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
list_node_items = []
|
||||
for item_text in list_items:
|
||||
item_node_id = str(uuid.uuid4())
|
||||
text_nodes = parse_markdown_inline(item_text)
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
'type': 'LIST_ITEM',
|
||||
'nodes': [paragraph_node]
|
||||
}
|
||||
list_node_items.append(list_item_node)
|
||||
|
||||
bulleted_list_node = {
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
# Check for ordered lists
|
||||
elif re.match(r'^\d+\.\s+', line):
|
||||
list_items = []
|
||||
while i < len(lines) and re.match(r'^\d+\.\s+', lines[i].strip()):
|
||||
item_text = re.sub(r'^\d+\.\s+', '', lines[i].strip())
|
||||
list_items.append(item_text)
|
||||
i += 1
|
||||
# Check for nested items
|
||||
while i < len(lines) and lines[i].strip().startswith(' ') and re.match(r'^\s+\d+\.\s+', lines[i].strip()):
|
||||
nested_text = re.sub(r'^\s+\d+\.\s+', '', lines[i].strip())
|
||||
list_items.append(nested_text)
|
||||
i += 1
|
||||
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within LIST_ITEM
|
||||
# NOTE: LIST_ITEM nodes do NOT have a data field per Wix API schema
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
list_node_items = []
|
||||
for item_text in list_items:
|
||||
item_node_id = str(uuid.uuid4())
|
||||
text_nodes = parse_markdown_inline(item_text)
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
'type': 'LIST_ITEM',
|
||||
'nodes': [paragraph_node]
|
||||
}
|
||||
list_node_items.append(list_item_node)
|
||||
|
||||
ordered_list_node = {
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
# Check for images
|
||||
elif line.startswith('!['):
|
||||
img_match = re.match(r'!\[([^\]]*)\]\(([^)]+)\)', line)
|
||||
if img_match:
|
||||
alt_text = img_match.group(1)
|
||||
img_url = img_match.group(2)
|
||||
nodes.append({
|
||||
'id': node_id,
|
||||
'type': 'IMAGE',
|
||||
'nodes': [],
|
||||
'imageData': {
|
||||
'image': {
|
||||
'src': {'url': img_url},
|
||||
'altText': alt_text
|
||||
},
|
||||
'containerData': {
|
||||
'alignment': 'CENTER',
|
||||
'width': {'size': 'CONTENT'}
|
||||
}
|
||||
}
|
||||
})
|
||||
i += 1
|
||||
|
||||
# Regular paragraph
|
||||
else:
|
||||
# Collect consecutive non-empty lines as paragraph content
|
||||
para_lines = [line]
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
next_line = lines[i].strip()
|
||||
if not next_line:
|
||||
break
|
||||
# Stop if next line is a special markdown element
|
||||
if (next_line.startswith('#') or
|
||||
next_line.startswith('- ') or
|
||||
next_line.startswith('* ') or
|
||||
next_line.startswith('>') or
|
||||
next_line.startswith('![') or
|
||||
re.match(r'^\d+\.\s+', next_line)):
|
||||
break
|
||||
para_lines.append(next_line)
|
||||
i += 1
|
||||
|
||||
para_text = ' '.join(para_lines)
|
||||
text_nodes = parse_markdown_inline(para_text)
|
||||
|
||||
# Only add paragraph if there are text nodes
|
||||
if text_nodes:
|
||||
paragraph_node = {
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
# Ensure at least one node exists
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
if not nodes:
|
||||
fallback_paragraph = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': [{
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'textData': {
|
||||
'text': content[:500] if content else "This is a post from ALwrity.",
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
return {
|
||||
'type': 'DOCUMENT',
|
||||
'id': str(uuid.uuid4()),
|
||||
'nodes': nodes,
|
||||
'metadata': { 'version': 1, 'id': str(uuid.uuid4()) },
|
||||
'metadata': {'version': 1, 'id': str(uuid.uuid4())},
|
||||
'documentStyle': {
|
||||
'paragraph': { 'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5' }
|
||||
'paragraph': {'decorations': [], 'nodeStyle': {}, 'lineHeight': '1.5'}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ class WixMediaService:
|
||||
self.base_url = base_url
|
||||
|
||||
def import_image(self, access_token: str, image_url: str, display_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Import external image to Wix Media Manager.
|
||||
|
||||
Official endpoint: https://www.wixapis.com/site-media/v1/files/import
|
||||
Reference: https://dev.wix.com/docs/rest/assets/media/media-manager/files/import-file
|
||||
"""
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
@@ -16,7 +22,9 @@ class WixMediaService:
|
||||
'mediaType': 'IMAGE',
|
||||
'displayName': display_name,
|
||||
}
|
||||
response = requests.post(f"{self.base_url}/media/v1/files/import", headers=headers, json=payload)
|
||||
# Correct endpoint per Wix API documentation
|
||||
endpoint = f"{self.base_url}/site-media/v1/files/import"
|
||||
response = requests.post(endpoint, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
277
backend/services/integrations/wix/ricos_converter.py
Normal file
277
backend/services/integrations/wix/ricos_converter.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Ricos Document Converter for Wix
|
||||
|
||||
Converts markdown content to Wix Ricos JSON format using either:
|
||||
1. Wix's official Ricos Documents API (preferred)
|
||||
2. Custom markdown parser (fallback)
|
||||
"""
|
||||
|
||||
import json
|
||||
import requests
|
||||
import jwt
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def markdown_to_html(markdown_content: str) -> str:
|
||||
"""
|
||||
Convert markdown content to HTML.
|
||||
Uses a simple markdown parser for basic conversion.
|
||||
|
||||
Args:
|
||||
markdown_content: Markdown content to convert
|
||||
|
||||
Returns:
|
||||
HTML string
|
||||
"""
|
||||
try:
|
||||
# Try using markdown library if available
|
||||
import markdown
|
||||
html = markdown.markdown(markdown_content, extensions=['fenced_code', 'tables'])
|
||||
return html
|
||||
except ImportError:
|
||||
# Fallback: Simple regex-based conversion for basic markdown
|
||||
logger.warning("markdown library not available, using basic markdown-to-HTML conversion")
|
||||
import re
|
||||
|
||||
if not markdown_content or not markdown_content.strip():
|
||||
return "<p>This is a post from ALwrity.</p>"
|
||||
|
||||
lines = markdown_content.split('\n')
|
||||
result = []
|
||||
in_list = False
|
||||
list_type = None # 'ul' or 'ol'
|
||||
in_code_block = False
|
||||
code_block_content = []
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
# Handle code blocks first
|
||||
if line.startswith('```'):
|
||||
if not in_code_block:
|
||||
in_code_block = True
|
||||
code_block_content = []
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
in_code_block = False
|
||||
result.append(f'<pre><code>{"\n".join(code_block_content)}</code></pre>')
|
||||
code_block_content = []
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_code_block:
|
||||
code_block_content.append(lines[i])
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Close any open lists
|
||||
if in_list and not (line.startswith('- ') or line.startswith('* ') or re.match(r'^\d+\.\s+', line)):
|
||||
result.append(f'</{list_type}>')
|
||||
in_list = False
|
||||
list_type = None
|
||||
|
||||
if not line:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Headers
|
||||
if line.startswith('###'):
|
||||
result.append(f'<h3>{line[3:].strip()}</h3>')
|
||||
elif line.startswith('##'):
|
||||
result.append(f'<h2>{line[2:].strip()}</h2>')
|
||||
elif line.startswith('#'):
|
||||
result.append(f'<h1>{line[1:].strip()}</h1>')
|
||||
# Lists
|
||||
elif line.startswith('- ') or line.startswith('* '):
|
||||
if not in_list or list_type != 'ul':
|
||||
if in_list:
|
||||
result.append(f'</{list_type}>')
|
||||
result.append('<ul>')
|
||||
in_list = True
|
||||
list_type = 'ul'
|
||||
# Process inline formatting in list item
|
||||
item_text = line[2:].strip()
|
||||
item_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', item_text)
|
||||
item_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', item_text)
|
||||
result.append(f'<li>{item_text}</li>')
|
||||
elif re.match(r'^\d+\.\s+', line):
|
||||
if not in_list or list_type != 'ol':
|
||||
if in_list:
|
||||
result.append(f'</{list_type}>')
|
||||
result.append('<ol>')
|
||||
in_list = True
|
||||
list_type = 'ol'
|
||||
# Process inline formatting in list item
|
||||
match = re.match(r'^\d+\.\s+(.*)', line)
|
||||
if match:
|
||||
item_text = match.group(1)
|
||||
item_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', item_text)
|
||||
item_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', item_text)
|
||||
result.append(f'<li>{item_text}</li>')
|
||||
# Blockquotes
|
||||
elif line.startswith('>'):
|
||||
quote_text = line[1:].strip()
|
||||
quote_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', quote_text)
|
||||
quote_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', quote_text)
|
||||
result.append(f'<blockquote><p>{quote_text}</p></blockquote>')
|
||||
# Regular paragraphs
|
||||
else:
|
||||
para_text = line
|
||||
# Process inline formatting
|
||||
para_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', para_text)
|
||||
para_text = re.sub(r'\*(.*?)\*', r'<em>\1</em>', para_text)
|
||||
para_text = re.sub(r'\[([^\]]+)\]\(([^\)]+)\)', r'<a href="\2">\1</a>', para_text)
|
||||
para_text = re.sub(r'`([^`]+)`', r'<code>\1</code>', para_text)
|
||||
result.append(f'<p>{para_text}</p>')
|
||||
|
||||
i += 1
|
||||
|
||||
# Close any open lists
|
||||
if in_list:
|
||||
result.append(f'</{list_type}>')
|
||||
|
||||
# Ensure we have at least one paragraph
|
||||
if not result:
|
||||
result.append('<p>This is a post from ALwrity.</p>')
|
||||
|
||||
html = '\n'.join(result)
|
||||
|
||||
logger.debug(f"Converted {len(markdown_content)} chars markdown to {len(html)} chars HTML")
|
||||
return html
|
||||
|
||||
|
||||
def convert_via_wix_api(markdown_content: str, access_token: str, base_url: str = 'https://www.wixapis.com') -> Dict[str, Any]:
|
||||
"""
|
||||
Convert markdown to Ricos using Wix's official Ricos Documents API.
|
||||
Uses HTML format for better reliability (per Wix documentation, HTML is fully supported).
|
||||
|
||||
Reference: https://dev.wix.com/docs/api-reference/assets/rich-content/ricos-documents/convert-to-ricos-document
|
||||
|
||||
Args:
|
||||
markdown_content: Markdown content to convert (will be converted to HTML)
|
||||
access_token: Wix access token
|
||||
base_url: Wix API base URL (default: https://www.wixapis.com)
|
||||
|
||||
Returns:
|
||||
Ricos JSON document
|
||||
"""
|
||||
# Validate content is not empty
|
||||
markdown_stripped = markdown_content.strip() if markdown_content else ""
|
||||
if not markdown_stripped:
|
||||
logger.error("Markdown content is empty or whitespace-only")
|
||||
raise ValueError("Content cannot be empty for Wix Ricos API conversion")
|
||||
|
||||
logger.debug(f"Converting markdown to HTML: input_length={len(markdown_stripped)} chars")
|
||||
|
||||
# Convert markdown to HTML for better reliability with Wix API
|
||||
# HTML format is more structured and less prone to parsing errors
|
||||
html_content = markdown_to_html(markdown_stripped)
|
||||
|
||||
# Validate HTML content is not empty - CRITICAL for Wix API
|
||||
html_stripped = html_content.strip() if html_content else ""
|
||||
if not html_stripped or len(html_stripped) == 0:
|
||||
logger.error(f"HTML conversion produced empty content! Markdown length: {len(markdown_stripped)}")
|
||||
logger.error(f"Markdown sample: {markdown_stripped[:500]}...")
|
||||
logger.error(f"HTML result: '{html_content}' (type: {type(html_content)})")
|
||||
# Fallback: use a minimal valid HTML if conversion failed
|
||||
html_content = "<p>Content from ALwrity blog writer.</p>"
|
||||
logger.warning("Using fallback HTML due to empty conversion result")
|
||||
else:
|
||||
html_content = html_stripped
|
||||
|
||||
logger.debug(f"✅ Converted markdown to HTML: {len(html_content)} chars, preview: {html_content[:200]}...")
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# Add wix-site-id if available from token
|
||||
try:
|
||||
token_str = str(access_token)
|
||||
if token_str and token_str.startswith('OauthNG.JWS.'):
|
||||
jwt_part = token_str[12:]
|
||||
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
|
||||
data_payload = payload.get('data', {})
|
||||
if isinstance(data_payload, str):
|
||||
try:
|
||||
data_payload = json.loads(data_payload)
|
||||
except:
|
||||
pass
|
||||
instance_data = data_payload.get('instance', {})
|
||||
meta_site_id = instance_data.get('metaSiteId')
|
||||
if isinstance(meta_site_id, str) and meta_site_id:
|
||||
headers['wix-site-id'] = meta_site_id
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract site ID from token: {e}")
|
||||
|
||||
# Call Wix Ricos Documents API: Convert to Ricos Document
|
||||
# Official endpoint: https://www.wixapis.com/ricos/v1/ricos-document/convert/to-ricos
|
||||
# Reference: https://dev.wix.com/docs/rest/assets/rich-content/ricos-documents/convert-to-ricos-document
|
||||
endpoint = f"{base_url}/ricos/v1/ricos-document/convert/to-ricos"
|
||||
|
||||
# Ensure HTML content is not empty or just whitespace
|
||||
html_stripped = html_content.strip() if html_content else ""
|
||||
if not html_stripped or len(html_stripped) == 0:
|
||||
logger.error(f"HTML content is empty after conversion. Markdown length: {len(markdown_content)}")
|
||||
logger.error(f"Markdown preview (first 500 chars): {markdown_content[:500] if markdown_content else 'N/A'}")
|
||||
raise ValueError(f"HTML content cannot be empty. Original markdown had {len(markdown_content)} characters.")
|
||||
|
||||
# Payload structure per Wix API: html/markdown/plainText field at root, optional plugins
|
||||
payload = {
|
||||
'html': html_stripped, # Direct field, not nested in options
|
||||
'plugins': [] # Optional: empty array uses default plugins
|
||||
}
|
||||
|
||||
logger.warning(f"📤 Sending to Wix Ricos API: html_length={len(payload['html'])}, plugins_count={len(payload['plugins'])}")
|
||||
logger.debug(f"HTML preview (first 300 chars): {html_stripped[:300]}...")
|
||||
|
||||
try:
|
||||
# Log the exact payload being sent (for debugging)
|
||||
logger.warning(f"📤 Wix Ricos API Request:")
|
||||
logger.warning(f" Endpoint: {endpoint}")
|
||||
logger.warning(f" Payload keys: {list(payload.keys())}")
|
||||
logger.warning(f" HTML length: {len(payload.get('html', ''))}")
|
||||
logger.warning(f" Plugins: {payload.get('plugins', [])}")
|
||||
logger.debug(f" Full payload (first 500 chars of HTML): {str(payload)[:500]}")
|
||||
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Extract the ricos document from response
|
||||
# Response structure: { "document": { "nodes": [...], "metadata": {...}, "documentStyle": {...} } }
|
||||
ricos_document = result.get('document')
|
||||
if not ricos_document:
|
||||
# Fallback: try other possible response fields
|
||||
ricos_document = result.get('ricosDocument') or result.get('ricos') or result
|
||||
|
||||
if not ricos_document:
|
||||
logger.error(f"Unexpected response structure from Wix API: {list(result.keys())}")
|
||||
logger.error(f"Response: {result}")
|
||||
raise ValueError("Wix API did not return a valid Ricos document")
|
||||
|
||||
logger.warning(f"✅ Successfully converted HTML to Ricos via Wix API: {len(ricos_document.get('nodes', []))} nodes")
|
||||
return ricos_document
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"❌ Wix Ricos API conversion failed: {e}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f" Response status: {e.response.status_code}")
|
||||
logger.error(f" Response headers: {dict(e.response.headers)}")
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
logger.error(f" Response JSON: {error_body}")
|
||||
except:
|
||||
logger.error(f" Response text: {e.response.text}")
|
||||
logger.error(f" Request payload was: {json.dumps(payload, indent=2)[:1000]}...") # First 1000 chars
|
||||
raise
|
||||
|
||||
300
backend/services/integrations/wix/seo.py
Normal file
300
backend/services/integrations/wix/seo.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
SEO Data Builder for Wix Blog Posts
|
||||
|
||||
Builds Wix-compatible seoData objects from ALwrity SEO metadata.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def build_seo_data(seo_metadata: Dict[str, Any], default_title: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Build Wix seoData object from our SEO metadata format.
|
||||
|
||||
Args:
|
||||
seo_metadata: SEO metadata dict with fields like:
|
||||
- seo_title: SEO optimized title
|
||||
- meta_description: Meta description
|
||||
- focus_keyword: Main keyword
|
||||
- blog_tags: List of tag strings (for keywords)
|
||||
- open_graph: Open Graph data dict
|
||||
- canonical_url: Canonical URL
|
||||
default_title: Fallback title if seo_title not provided
|
||||
|
||||
Returns:
|
||||
Wix seoData object with settings.keywords and tags array, or None if empty
|
||||
"""
|
||||
seo_data = {
|
||||
'settings': {
|
||||
'keywords': []
|
||||
},
|
||||
'tags': []
|
||||
}
|
||||
|
||||
# Build keywords array
|
||||
keywords_list = []
|
||||
|
||||
# Add main keyword (focus_keyword) if provided
|
||||
focus_keyword = seo_metadata.get('focus_keyword')
|
||||
if focus_keyword:
|
||||
keywords_list.append({
|
||||
'term': str(focus_keyword),
|
||||
'isMain': True
|
||||
})
|
||||
|
||||
# Add additional keywords from blog_tags or other sources
|
||||
blog_tags = seo_metadata.get('blog_tags', [])
|
||||
if isinstance(blog_tags, list):
|
||||
for tag in blog_tags:
|
||||
tag_str = str(tag).strip()
|
||||
if tag_str and tag_str != focus_keyword: # Don't duplicate main keyword
|
||||
keywords_list.append({
|
||||
'term': tag_str,
|
||||
'isMain': False
|
||||
})
|
||||
|
||||
# Add social hashtags as keywords if available
|
||||
social_hashtags = seo_metadata.get('social_hashtags', [])
|
||||
if isinstance(social_hashtags, list):
|
||||
for hashtag in social_hashtags:
|
||||
# Remove # if present
|
||||
hashtag_str = str(hashtag).strip().lstrip('#')
|
||||
if hashtag_str and hashtag_str != focus_keyword:
|
||||
keywords_list.append({
|
||||
'term': hashtag_str,
|
||||
'isMain': False
|
||||
})
|
||||
|
||||
seo_data['settings']['keywords'] = keywords_list
|
||||
|
||||
# Validate keywords list is not empty (or ensure at least one keyword exists)
|
||||
if not seo_data['settings']['keywords']:
|
||||
logger.warning("No keywords found in SEO metadata, adding empty keywords array")
|
||||
|
||||
# Build tags array (meta tags, Open Graph, etc.)
|
||||
tags_list = []
|
||||
|
||||
# Meta description
|
||||
meta_description = seo_metadata.get('meta_description')
|
||||
if meta_description:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'name': 'description',
|
||||
'content': str(meta_description)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# SEO title - 'title' type uses 'children' field, not 'props.content'
|
||||
seo_title = seo_metadata.get('seo_title') or default_title
|
||||
if seo_title:
|
||||
tags_list.append({
|
||||
'type': 'title',
|
||||
'children': str(seo_title), # Title tags use 'children', not 'props.content'
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# Open Graph tags
|
||||
open_graph = seo_metadata.get('open_graph', {})
|
||||
if isinstance(open_graph, dict):
|
||||
# OG Title
|
||||
og_title = open_graph.get('title') or seo_title
|
||||
if og_title:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'property': 'og:title',
|
||||
'content': str(og_title)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# OG Description
|
||||
og_description = open_graph.get('description') or meta_description
|
||||
if og_description:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'property': 'og:description',
|
||||
'content': str(og_description)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# OG Image
|
||||
og_image = open_graph.get('image')
|
||||
if og_image:
|
||||
# Skip base64 images for OG tags (Wix needs URLs)
|
||||
if isinstance(og_image, str) and (og_image.startswith('http://') or og_image.startswith('https://')):
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'property': 'og:image',
|
||||
'content': og_image
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# OG Type
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'property': 'og:type',
|
||||
'content': 'article'
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# OG URL (canonical or provided URL)
|
||||
og_url = open_graph.get('url') or seo_metadata.get('canonical_url')
|
||||
if og_url:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'property': 'og:url',
|
||||
'content': str(og_url)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# Twitter Card tags
|
||||
twitter_card = seo_metadata.get('twitter_card', {})
|
||||
if isinstance(twitter_card, dict):
|
||||
twitter_title = twitter_card.get('title') or seo_title
|
||||
if twitter_title:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'name': 'twitter:title',
|
||||
'content': str(twitter_title)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
twitter_description = twitter_card.get('description') or meta_description
|
||||
if twitter_description:
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'name': 'twitter:description',
|
||||
'content': str(twitter_description)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
twitter_image = twitter_card.get('image')
|
||||
if twitter_image and isinstance(twitter_image, str) and (twitter_image.startswith('http://') or twitter_image.startswith('https://')):
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'name': 'twitter:image',
|
||||
'content': twitter_image
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
twitter_card_type = twitter_card.get('card', 'summary_large_image')
|
||||
tags_list.append({
|
||||
'type': 'meta',
|
||||
'props': {
|
||||
'name': 'twitter:card',
|
||||
'content': str(twitter_card_type)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# Canonical URL as link tag
|
||||
canonical_url = seo_metadata.get('canonical_url')
|
||||
if canonical_url:
|
||||
tags_list.append({
|
||||
'type': 'link',
|
||||
'props': {
|
||||
'rel': 'canonical',
|
||||
'href': str(canonical_url)
|
||||
},
|
||||
'custom': True,
|
||||
'disabled': False
|
||||
})
|
||||
|
||||
# Validate all tags have required fields before adding
|
||||
validated_tags = []
|
||||
for tag in tags_list:
|
||||
if not isinstance(tag, dict):
|
||||
logger.warning(f"Skipping invalid tag (not a dict): {type(tag)}")
|
||||
continue
|
||||
# Ensure required fields exist
|
||||
if 'type' not in tag:
|
||||
logger.warning("Skipping tag missing 'type' field")
|
||||
continue
|
||||
# Ensure 'custom' and 'disabled' fields exist
|
||||
if 'custom' not in tag:
|
||||
tag['custom'] = True
|
||||
if 'disabled' not in tag:
|
||||
tag['disabled'] = False
|
||||
# Validate tag structure based on type
|
||||
tag_type = tag.get('type')
|
||||
if tag_type == 'title':
|
||||
if 'children' not in tag or not tag['children']:
|
||||
logger.warning("Skipping title tag with missing/invalid 'children' field")
|
||||
continue
|
||||
elif tag_type == 'meta':
|
||||
if 'props' not in tag or not isinstance(tag['props'], dict):
|
||||
logger.warning("Skipping meta tag with missing/invalid 'props' field")
|
||||
continue
|
||||
if 'name' not in tag['props'] and 'property' not in tag['props']:
|
||||
logger.warning("Skipping meta tag with missing 'name' or 'property' in props")
|
||||
continue
|
||||
# Ensure 'content' exists and is not empty
|
||||
if 'content' not in tag['props'] or not str(tag['props'].get('content', '')).strip():
|
||||
logger.warning(f"Skipping meta tag with missing/empty 'content': {tag.get('props', {})}")
|
||||
continue
|
||||
elif tag_type == 'link':
|
||||
if 'props' not in tag or not isinstance(tag['props'], dict):
|
||||
logger.warning("Skipping link tag with missing/invalid 'props' field")
|
||||
continue
|
||||
# Ensure 'href' exists and is not empty for link tags
|
||||
if 'href' not in tag['props'] or not str(tag['props'].get('href', '')).strip():
|
||||
logger.warning(f"Skipping link tag with missing/empty 'href': {tag.get('props', {})}")
|
||||
continue
|
||||
validated_tags.append(tag)
|
||||
|
||||
seo_data['tags'] = validated_tags
|
||||
|
||||
# Final validation: ensure seoData structure is complete
|
||||
if not isinstance(seo_data['settings'], dict):
|
||||
logger.error("seoData.settings is not a dict, creating default")
|
||||
seo_data['settings'] = {'keywords': []}
|
||||
if not isinstance(seo_data['settings'].get('keywords'), list):
|
||||
logger.error("seoData.settings.keywords is not a list, creating empty list")
|
||||
seo_data['settings']['keywords'] = []
|
||||
if not isinstance(seo_data['tags'], list):
|
||||
logger.error("seoData.tags is not a list, creating empty list")
|
||||
seo_data['tags'] = []
|
||||
|
||||
# CRITICAL: Per Wix API patterns, omit empty structures instead of including them as {}
|
||||
# If keywords is empty, omit settings entirely
|
||||
if not seo_data['settings'].get('keywords'):
|
||||
logger.debug("No keywords found, omitting settings from seoData")
|
||||
seo_data.pop('settings', None)
|
||||
|
||||
logger.debug(f"Built SEO data: {len(validated_tags)} tags, {len(keywords_list)} keywords")
|
||||
|
||||
# Only return seoData if we have at least keywords or tags
|
||||
if keywords_list or validated_tags:
|
||||
return seo_data
|
||||
|
||||
return None
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
@@ -129,11 +130,16 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Estimate tokens from prompt (input tokens)
|
||||
# Note: We estimate output tokens conservatively (assume response is similar length to prompt)
|
||||
# This prevents underestimating total token usage
|
||||
# CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse
|
||||
# This ensures we block requests that would exceed limits even if response is longer than expected
|
||||
input_tokens = int(len(prompt.split()) * 1.3)
|
||||
# Conservative estimate: assume output tokens ≈ input tokens * 1.0 (can be up to max_tokens)
|
||||
estimated_output_tokens = min(input_tokens, max_tokens) if max_tokens else int(input_tokens * 0.8)
|
||||
# Worst-case estimate: assume maximum possible output tokens (max_tokens if specified)
|
||||
# This prevents abuse where actual response tokens exceed the estimate
|
||||
if max_tokens:
|
||||
estimated_output_tokens = max_tokens # Use maximum allowed output tokens
|
||||
else:
|
||||
# If max_tokens not specified, use conservative estimate (input * 1.5)
|
||||
estimated_output_tokens = int(input_tokens * 1.5)
|
||||
estimated_total_tokens = input_tokens + estimated_output_tokens
|
||||
|
||||
# Check limits using sync method from pricing service (strict enforcement)
|
||||
@@ -146,7 +152,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
raise RuntimeError(f"Subscription limit exceeded: {message}")
|
||||
# Raise HTTPException(429) with usage info so frontend can display subscription modal
|
||||
error_detail = {
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': actual_provider_name or provider_enum.value,
|
||||
'usage_info': usage_info if usage_info else {}
|
||||
}
|
||||
raise HTTPException(status_code=429, detail=error_detail)
|
||||
|
||||
# Get current usage for limit checking only
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
@@ -159,6 +172,9 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
raise
|
||||
except RuntimeError:
|
||||
# Re-raise subscription limit errors
|
||||
raise
|
||||
@@ -244,7 +260,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens # Already calculated above
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
@@ -259,45 +276,186 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
# This ensures we always get the absolute latest committed values, even across different sessions
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
record_count = 0 # Initialize to ensure it's always defined
|
||||
|
||||
# CRITICAL: First check if record exists using COUNT query
|
||||
try:
|
||||
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
|
||||
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
|
||||
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
|
||||
except Exception as count_error:
|
||||
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
|
||||
record_count = 0
|
||||
|
||||
if record_count and record_count > 0:
|
||||
# Record exists - read current values with raw SQL
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection (whitelist approach)
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
raw_calls = result[0] if result[0] is not None else 0
|
||||
raw_tokens = result[1] if result[1] is not None else 0
|
||||
current_calls_before = raw_calls
|
||||
current_tokens_before = raw_tokens
|
||||
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ℹ️ No record exists yet (will create new) - user={user_id}, period={current_period}")
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[llm_text_gen] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
# New record - values are already 0, no need to set
|
||||
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
|
||||
# Using raw SQL UPDATE ensures the change is persisted
|
||||
from sqlalchemy import text
|
||||
update_calls_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_calls = :new_calls
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_calls_query, {
|
||||
'new_calls': new_calls,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
|
||||
# The ORM object may have stale values, but raw SQL always has the latest committed values
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens: {current_tokens_before} -> {new_tokens}")
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
update_tokens_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_tokens = :new_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_tokens_query, {
|
||||
'new_tokens': new_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
# Update totals using SQL UPDATE
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
summary.total_tokens = old_total_tokens + tokens_total
|
||||
logger.debug(f"[llm_text_gen] Updated totals: calls {old_total_calls} -> {summary.total_calls}, tokens {old_total_tokens} -> {summary.total_tokens}")
|
||||
new_total_calls = old_total_calls + 1
|
||||
new_total_tokens = old_total_tokens + tokens_total
|
||||
|
||||
update_totals_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET total_calls = :total_calls, total_tokens = :total_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_totals_query, {
|
||||
'total_calls': new_total_calls,
|
||||
'total_tokens': new_total_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
@@ -310,12 +468,52 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
print(debug_msg, flush=True)
|
||||
sys.stdout.flush()
|
||||
logger.debug(f"[llm_text_gen] {debug_msg}")
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
|
||||
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
|
||||
|
||||
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
|
||||
try:
|
||||
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result:
|
||||
verified_calls = verify_result[0] if verify_result[0] is not None else 0
|
||||
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
|
||||
if verified_calls != new_calls or verified_tokens != new_tokens:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
|
||||
# Force another commit attempt
|
||||
db_track.commit()
|
||||
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result2:
|
||||
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
|
||||
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
|
||||
except Exception as verify_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
# DEBUG: Log the actual values being used
|
||||
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
|
||||
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
@@ -407,7 +605,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
tokens_input = estimated_tokens
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3)
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
@@ -418,6 +617,49 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
current_calls_before = result[0] if result[0] is not None else 0
|
||||
current_tokens_before = result[1] if result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
|
||||
# Fallback to ORM query if raw SQL fails
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary:
|
||||
db_track.refresh(summary)
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
@@ -430,41 +672,68 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
|
||||
# Get "before" state for unified log
|
||||
provider_name = provider_enum.value
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
# Get "before" state for unified log (from raw SQL query)
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
# Update token usage for LLM providers with safety check
|
||||
# Use current_tokens_before from raw SQL query (most reliable)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
new_tokens = current_tokens_before + tokens_total
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Update totals
|
||||
# Update totals (using potentially capped tokens_total from safety check)
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
# Get plan details for unified log (limits already retrieved above)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens")
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG for fallback
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
|
||||
725
backend/services/subscription/limit_validation.py
Normal file
725
backend/services/subscription/limit_validation.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
Limit Validation Module
|
||||
Handles subscription limit checking and validation logic.
|
||||
Extracted from pricing_service.py for better modularity.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
UserSubscription, UsageSummary, SubscriptionPlan,
|
||||
APIProvider, SubscriptionTier
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pricing_service import PricingService
|
||||
|
||||
|
||||
class LimitValidator:
|
||||
"""Validates subscription limits for API usage."""
|
||||
|
||||
def __init__(self, pricing_service: 'PricingService'):
|
||||
"""
|
||||
Initialize limit validator with reference to PricingService.
|
||||
|
||||
Args:
|
||||
pricing_service: Instance of PricingService to access helper methods and cache
|
||||
"""
|
||||
self.pricing_service = pricing_service
|
||||
self.db = pricing_service.db
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self.pricing_service._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self.pricing_service._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
|
||||
try:
|
||||
# Force expire subscription and plan objects to avoid stale cache
|
||||
if subscription and subscription.plan_id:
|
||||
plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
|
||||
if plan_obj:
|
||||
self.db.expire(plan_obj)
|
||||
logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
|
||||
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
# Log token limits for debugging
|
||||
token_limits = limits.get('limits', {})
|
||||
logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self.pricing_service._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Expire all objects to force fresh read from DB (critical after renewal)
|
||||
self.db.expire_all()
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
|
||||
def check_comprehensive_limits(
|
||||
self,
|
||||
user_id: str,
|
||||
operations: List[Dict[str, Any]]
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
|
||||
|
||||
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
|
||||
before making the first external API call.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
operations: List of operations to validate, each with:
|
||||
- 'provider': APIProvider enum
|
||||
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
|
||||
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
|
||||
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, error_message explains which limit would be exceeded
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
|
||||
|
||||
# Explicitly expire any cached objects and refresh from DB to ensure fresh data
|
||||
self.db.expire_all()
|
||||
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# CRITICAL: Explicitly refresh from database to get latest values (clears SQLAlchemy cache)
|
||||
if usage:
|
||||
self.db.refresh(usage)
|
||||
|
||||
# Log what we actually read from database
|
||||
if usage:
|
||||
logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
|
||||
logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
|
||||
else:
|
||||
logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
|
||||
# Get user limits
|
||||
limits_dict = self.pricing_service.get_user_limits(user_id)
|
||||
if not limits_dict:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
limits_dict = self.pricing_service._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
limits = limits_dict.get('limits', {})
|
||||
|
||||
# Track cumulative usage across all operations
|
||||
total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
total_llm_tokens = {}
|
||||
total_images = usage.stability_calls or 0
|
||||
|
||||
# Log current usage summary
|
||||
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
|
||||
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
|
||||
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
|
||||
logger.info(f" └─ Image Calls: {total_images}")
|
||||
|
||||
# Validate each operation
|
||||
for op_idx, operation in enumerate(operations):
|
||||
provider = operation.get('provider')
|
||||
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
|
||||
tokens_requested = operation.get('tokens_requested', 0)
|
||||
actual_provider_name = operation.get('actual_provider_name')
|
||||
operation_type = operation.get('operation_type', 'unknown')
|
||||
|
||||
display_provider_name = actual_provider_name or provider_name
|
||||
|
||||
logger.error(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
|
||||
logger.error(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
|
||||
logger.error(f" ├─ Operation Index: {op_idx}")
|
||||
logger.error(f" └─ Estimated Tokens Requested: {tokens_requested}")
|
||||
|
||||
# Check if this is an LLM provider
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
# Check unified AI text generation limit for LLM providers
|
||||
if is_llm_provider:
|
||||
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
|
||||
if ai_text_gen_limit == 0:
|
||||
# Fallback to provider-specific limit
|
||||
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Count this operation as an LLM call
|
||||
projected_total_llm_calls = total_llm_calls + 1
|
||||
|
||||
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
|
||||
error_info = {
|
||||
'current_calls': total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check token limits for this provider
|
||||
# CRITICAL: Always query fresh from DB for each operation to avoid SQLAlchemy cache issues
|
||||
# This ensures we get the latest values after subscription renewal, even for cumulative tracking
|
||||
provider_tokens_key = f"{provider_name}_tokens"
|
||||
|
||||
# Try to get fresh value from DB with comprehensive error handling
|
||||
base_current_tokens = 0
|
||||
query_succeeded = False
|
||||
|
||||
try:
|
||||
# Validate column name is safe (only allow known provider token columns)
|
||||
valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens']
|
||||
|
||||
if provider_tokens_key not in valid_token_columns:
|
||||
logger.error(f" └─ Invalid provider tokens key: {provider_tokens_key}")
|
||||
query_succeeded = True # Treat as success with 0 value
|
||||
else:
|
||||
# Method 1: Try raw SQL query to completely bypass ORM cache
|
||||
try:
|
||||
logger.debug(f" └─ Attempting raw SQL query for {provider_tokens_key}")
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_tokens_key}
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id
|
||||
AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
logger.debug(f" └─ SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}")
|
||||
|
||||
result = self.db.execute(sql_query, {
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
}).first()
|
||||
|
||||
if result:
|
||||
base_current_tokens = result[0] if result[0] is not None else 0
|
||||
logger.error(f"[Pre-flight Check] ✅ Raw SQL query returned result: {result[0]} -> {base_current_tokens}")
|
||||
else:
|
||||
base_current_tokens = 0
|
||||
logger.error(f"[Pre-flight Check] ⚠️ Raw SQL query returned None (no rows found)")
|
||||
|
||||
query_succeeded = True
|
||||
logger.error(f"[Pre-flight Check] ✅ Raw SQL query succeeded for {provider_tokens_key}: {base_current_tokens}")
|
||||
|
||||
except Exception as sql_error:
|
||||
logger.error(f" └─ Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True)
|
||||
query_succeeded = False # Will try ORM fallback
|
||||
|
||||
# Method 2: Fallback to fresh ORM query if raw SQL fails
|
||||
if not query_succeeded:
|
||||
try:
|
||||
# Expire all cached objects and do fresh query
|
||||
self.db.expire_all()
|
||||
fresh_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if fresh_usage:
|
||||
# Explicitly refresh to get latest from DB
|
||||
self.db.refresh(fresh_usage)
|
||||
base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0
|
||||
else:
|
||||
base_current_tokens = 0
|
||||
|
||||
query_succeeded = True
|
||||
logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}")
|
||||
|
||||
except Exception as orm_error:
|
||||
logger.error(f" └─ ORM query also failed: {orm_error}", exc_info=True)
|
||||
query_succeeded = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" └─ Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True)
|
||||
base_current_tokens = 0 # Fail safe - assume 0 if we can't query
|
||||
|
||||
if not query_succeeded:
|
||||
logger.warning(f" └─ Both query methods failed, using 0 as fallback")
|
||||
|
||||
# CRITICAL LOG: Always log what we got from DB - this helps debug renewal issues
|
||||
# Use ERROR level to ensure it shows even if INFO is filtered
|
||||
logger.error(f"[Pre-flight Check] 🔍 Fresh DB Query for {display_provider_name}:")
|
||||
logger.error(f" ├─ Column: {provider_tokens_key}")
|
||||
logger.error(f" ├─ Billing Period: {current_period}")
|
||||
logger.error(f" ├─ User ID: {user_id}")
|
||||
logger.error(f" ├─ Method: {'Raw SQL' if query_succeeded and base_current_tokens >= 0 else 'ORM' if query_succeeded else 'Failed - using 0'}")
|
||||
logger.error(f" └─ Value from DB: {base_current_tokens}")
|
||||
|
||||
# Add any projected tokens from previous operations in this validation run
|
||||
# Note: total_llm_tokens tracks ONLY projected tokens from this run, not base DB value
|
||||
projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)
|
||||
|
||||
# Current tokens = base from DB + projected from previous operations in this run
|
||||
current_provider_tokens = base_current_tokens + projected_from_previous
|
||||
|
||||
# Use ERROR level to ensure visibility
|
||||
logger.error(f"[Pre-flight Check] 📊 Token Calculation for {display_provider_name}:")
|
||||
logger.error(f" ├─ Base from DB (fresh query): {base_current_tokens}")
|
||||
logger.error(f" ├─ Projected from previous ops in this run: {projected_from_previous}")
|
||||
logger.error(f" └─ Total current tokens (base + projected): {current_provider_tokens}")
|
||||
|
||||
# Also check the initial usage object to see if it's being used incorrectly
|
||||
if usage and hasattr(usage, provider_tokens_key):
|
||||
initial_usage_value = getattr(usage, provider_tokens_key, 0) or 0
|
||||
logger.error(f" ⚠️ Initial usage object value: {initial_usage_value} (this should NOT be used for fresh query)")
|
||||
|
||||
token_limit = limits.get(provider_tokens_key, 0) or 0
|
||||
|
||||
if token_limit > 0 and tokens_requested > 0:
|
||||
projected_tokens = current_provider_tokens + tokens_requested
|
||||
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
|
||||
|
||||
if projected_tokens > token_limit:
|
||||
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
|
||||
error_info = {
|
||||
'current_tokens': current_provider_tokens,
|
||||
'base_tokens_from_db': base_current_tokens,
|
||||
'projected_from_previous_ops': projected_from_previous,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
# Make error message clearer: show actual DB usage vs projected
|
||||
if projected_from_previous > 0:
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Base usage: {base_current_tokens}/{token_limit}, "
|
||||
f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
|
||||
f"This operation would add: {tokens_requested}, "
|
||||
f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Current: {current_provider_tokens}/{token_limit}, "
|
||||
f"Requested: {tokens_requested}, "
|
||||
f"Would exceed by: {projected_tokens - token_limit} tokens "
|
||||
f"({usage_percentage:.1f}% of limit)"
|
||||
)
|
||||
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
|
||||
return False, error_msg, {
|
||||
'error_type': 'token_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
else:
|
||||
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
|
||||
|
||||
# Update cumulative counts for next operation
|
||||
total_llm_calls = projected_total_llm_calls
|
||||
# Update cumulative projected tokens from this validation run
|
||||
# This represents only projected tokens from previous operations in this run
|
||||
# Base DB value is always queried fresh, so we only track the projection delta
|
||||
old_projected = total_llm_tokens.get(provider_tokens_key, 0)
|
||||
if tokens_requested > 0:
|
||||
# Add this operation's tokens to cumulative projected tokens
|
||||
total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
|
||||
logger.error(f"[Pre-flight Check] 📝 Updated cumulative projected tokens for {display_provider_name}:")
|
||||
logger.error(f" ├─ Previous projected: {projected_from_previous}")
|
||||
logger.error(f" ├─ This operation requested: {tokens_requested}")
|
||||
logger.error(f" ├─ New cumulative projected: {total_llm_tokens[provider_tokens_key]}")
|
||||
logger.error(f" └─ Old value in dict was: {old_projected}")
|
||||
else:
|
||||
# No tokens requested, keep existing projected tokens (or 0 if first operation)
|
||||
total_llm_tokens[provider_tokens_key] = projected_from_previous
|
||||
logger.error(f"[Pre-flight Check] 📝 No tokens requested, keeping projected at: {projected_from_previous}")
|
||||
|
||||
# Check image generation limits
|
||||
elif provider == APIProvider.STABILITY:
|
||||
image_limit = limits.get('stability_calls', 0) or 0
|
||||
projected_images = total_images + 1
|
||||
|
||||
if image_limit > 0 and projected_images > image_limit:
|
||||
error_info = {
|
||||
'current_images': total_images,
|
||||
'limit': image_limit,
|
||||
'provider': 'stability',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
|
||||
'error_type': 'image_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
|
||||
call_limit = limits.get(provider_calls_key, 0) or 0
|
||||
|
||||
if call_limit > 0:
|
||||
projected_calls = current_provider_calls + 1
|
||||
if projected_calls > call_limit:
|
||||
error_info = {
|
||||
'current_calls': current_provider_calls,
|
||||
'limit': call_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
|
||||
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
|
||||
return True, None, None
|
||||
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_message = str(e)
|
||||
logger.error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {error_message}", exc_info=True)
|
||||
logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
|
||||
return False, f"Failed to validate limits: {error_type}: {error_message}", {}
|
||||
|
||||
@@ -44,15 +44,17 @@ def validate_research_operations(
|
||||
llm_provider_name = "gemini"
|
||||
|
||||
# Estimate tokens for each operation in research workflow
|
||||
# Google Grounding call: ~2000 tokens (input + output)
|
||||
# Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results)
|
||||
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
|
||||
# Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation
|
||||
# to prevent wasteful API calls if the workflow would exceed limits.
|
||||
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
|
||||
'tokens_requested': 2000,
|
||||
'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate
|
||||
'actual_provider_name': 'gemini',
|
||||
'operation_type': 'google_grounding'
|
||||
},
|
||||
@@ -126,6 +128,120 @@ def validate_research_operations(
|
||||
)
|
||||
|
||||
|
||||
def validate_exa_research_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str,
|
||||
gpt_provider: str = "google"
|
||||
) -> None:
|
||||
"""
|
||||
Validate all operations for an Exa research workflow before making ANY API calls.
|
||||
|
||||
This prevents wasteful external API calls (like Exa search) if subsequent
|
||||
LLM calls would fail due to token or call limits.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
gpt_provider: GPT provider from env var (defaults to "google")
|
||||
|
||||
Returns:
|
||||
None
|
||||
If validation fails, raises HTTPException with 429 status
|
||||
"""
|
||||
try:
|
||||
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
|
||||
gpt_provider_lower = gpt_provider.lower()
|
||||
if gpt_provider_lower == "huggingface":
|
||||
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
||||
llm_provider_name = "huggingface"
|
||||
else:
|
||||
llm_provider_enum = APIProvider.GEMINI
|
||||
llm_provider_name = "gemini"
|
||||
|
||||
# Estimate tokens for each operation in Exa research workflow
|
||||
# Exa Search call: 1 Exa API call (not token-based)
|
||||
# Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON)
|
||||
# Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON)
|
||||
# Content angle generator: ~1000 tokens (input: research results, output: list of angles)
|
||||
# Note: These are conservative estimates for pre-flight validation
|
||||
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.EXA, # Exa API call
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'exa',
|
||||
'operation_type': 'exa_neural_search'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'keyword_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'competitor_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'content_angle_generation'
|
||||
}
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
|
||||
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
|
||||
operation_type = usage_info.get('operation_type', 'unknown')
|
||||
|
||||
logger.error(f"[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
|
||||
logger.error(f" ├─ User: {user_id}")
|
||||
logger.error(f" ├─ Blocked at: {operation_type}")
|
||||
logger.error(f" ├─ Provider: {provider}")
|
||||
logger.error(f" └─ Reason: {message}")
|
||||
|
||||
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate operations: {str(e)}",
|
||||
'message': f"Failed to validate operations: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_image_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
|
||||
@@ -258,6 +258,12 @@ class PricingService:
|
||||
"model_name": "stable-diffusion",
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"description": "Stability AI Image Generation"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.EXA,
|
||||
"model_name": "exa-search",
|
||||
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
|
||||
"description": "Exa Neural Search API"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -296,6 +302,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 10,
|
||||
"firecrawl_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 100,
|
||||
"gemini_tokens_limit": 100000,
|
||||
"monthly_cost_limit": 0.0,
|
||||
"features": ["basic_content_generation", "limited_research"],
|
||||
@@ -316,10 +323,11 @@ class PricingService:
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"gemini_tokens_limit": 2000,
|
||||
"openai_tokens_limit": 2000,
|
||||
"anthropic_tokens_limit": 2000,
|
||||
"mistral_tokens_limit": 2000,
|
||||
"exa_calls_limit": 500,
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
@@ -338,6 +346,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 500,
|
||||
"firecrawl_calls_limit": 500,
|
||||
"stability_calls_limit": 200,
|
||||
"exa_calls_limit": 2000,
|
||||
"gemini_tokens_limit": 5000000,
|
||||
"openai_tokens_limit": 2500000,
|
||||
"anthropic_tokens_limit": 1000000,
|
||||
@@ -360,6 +369,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 0,
|
||||
"firecrawl_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"exa_calls_limit": 0, # Unlimited
|
||||
"gemini_tokens_limit": 0,
|
||||
"openai_tokens_limit": 0,
|
||||
"anthropic_tokens_limit": 0,
|
||||
@@ -423,11 +433,14 @@ class PricingService:
|
||||
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get usage limits for a user based on their subscription."""
|
||||
|
||||
# CRITICAL: Expire all objects first to ensure fresh data after renewal
|
||||
self.db.expire_all()
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
|
||||
if not subscription:
|
||||
# Return free tier limits
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
@@ -439,7 +452,23 @@ class PricingService:
|
||||
|
||||
# Ensure current period before returning limits
|
||||
self._ensure_subscription_current(subscription)
|
||||
return self._plan_to_limits_dict(subscription.plan)
|
||||
|
||||
# CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship
|
||||
self.db.refresh(subscription)
|
||||
|
||||
# Re-query plan directly to ensure fresh data (bypass relationship cache)
|
||||
plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == subscription.plan_id
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
|
||||
return None
|
||||
|
||||
# Refresh plan to ensure fresh limits
|
||||
self.db.refresh(plan)
|
||||
|
||||
return self._plan_to_limits_dict(plan)
|
||||
|
||||
def _ensure_ai_text_gen_column_detection(self) -> None:
|
||||
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
|
||||
@@ -508,290 +537,20 @@ class PricingService:
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Delegates to LimitValidator for actual validation logic.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
try:
|
||||
limits = self.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
try:
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
from .limit_validation import LimitValidator
|
||||
validator = LimitValidator(self)
|
||||
return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)
|
||||
|
||||
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
|
||||
"""Estimate token count for text based on provider."""
|
||||
@@ -827,6 +586,16 @@ class PricingService:
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
# Return pricing info as dict
|
||||
return {
|
||||
'provider': pricing.provider.value,
|
||||
'model_name': pricing.model_name,
|
||||
'cost_per_input_token': pricing.cost_per_input_token,
|
||||
'cost_per_output_token': pricing.cost_per_output_token,
|
||||
'cost_per_request': pricing.cost_per_request,
|
||||
'description': pricing.description
|
||||
}
|
||||
|
||||
def check_comprehensive_limits(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -835,6 +604,7 @@ class PricingService:
|
||||
"""
|
||||
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
|
||||
|
||||
Delegates to LimitValidator for actual validation logic.
|
||||
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
|
||||
before making the first external API call.
|
||||
|
||||
@@ -850,202 +620,9 @@ class PricingService:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, error_message explains which limit would be exceeded
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
|
||||
# Get user limits
|
||||
limits_dict = self.get_user_limits(user_id)
|
||||
if not limits_dict:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
limits_dict = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
limits = limits_dict.get('limits', {})
|
||||
|
||||
# Track cumulative usage across all operations
|
||||
total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
total_llm_tokens = {}
|
||||
total_images = usage.stability_calls or 0
|
||||
|
||||
# Log current usage summary
|
||||
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
|
||||
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
|
||||
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
|
||||
logger.info(f" └─ Image Calls: {total_images}")
|
||||
|
||||
# Validate each operation
|
||||
for op_idx, operation in enumerate(operations):
|
||||
provider = operation.get('provider')
|
||||
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
|
||||
tokens_requested = operation.get('tokens_requested', 0)
|
||||
actual_provider_name = operation.get('actual_provider_name')
|
||||
operation_type = operation.get('operation_type', 'unknown')
|
||||
|
||||
display_provider_name = actual_provider_name or provider_name
|
||||
|
||||
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
|
||||
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
|
||||
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
|
||||
|
||||
# Check if this is an LLM provider
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
# Check unified AI text generation limit for LLM providers
|
||||
if is_llm_provider:
|
||||
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
|
||||
if ai_text_gen_limit == 0:
|
||||
# Fallback to provider-specific limit
|
||||
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Count this operation as an LLM call
|
||||
projected_total_llm_calls = total_llm_calls + 1
|
||||
|
||||
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
|
||||
error_info = {
|
||||
'current_calls': total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check token limits for this provider
|
||||
# Use cumulative projected tokens from previous operations, or current from DB if first operation
|
||||
provider_tokens_key = f"{provider_name}_tokens"
|
||||
if provider_tokens_key in total_llm_tokens:
|
||||
# Use cumulative projected tokens from previous operations
|
||||
current_provider_tokens = total_llm_tokens[provider_tokens_key]
|
||||
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
|
||||
else:
|
||||
# First operation for this provider - get current from database
|
||||
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
|
||||
total_llm_tokens[provider_tokens_key] = current_provider_tokens
|
||||
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
|
||||
|
||||
token_limit = limits.get(provider_tokens_key, 0) or 0
|
||||
|
||||
if token_limit > 0 and tokens_requested > 0:
|
||||
projected_tokens = current_provider_tokens + tokens_requested
|
||||
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
|
||||
|
||||
if projected_tokens > token_limit:
|
||||
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
|
||||
error_info = {
|
||||
'current_tokens': current_provider_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Current: {current_provider_tokens}/{token_limit}, "
|
||||
f"Requested: {tokens_requested}, "
|
||||
f"Would exceed by: {projected_tokens - token_limit} tokens "
|
||||
f"({usage_percentage:.1f}% of limit)"
|
||||
)
|
||||
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
|
||||
return False, error_msg, {
|
||||
'error_type': 'token_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
else:
|
||||
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
|
||||
|
||||
# Update cumulative counts for next operation
|
||||
total_llm_calls = projected_total_llm_calls
|
||||
total_llm_tokens[provider_tokens_key] += tokens_requested
|
||||
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
|
||||
|
||||
# Check image generation limits
|
||||
elif provider == APIProvider.STABILITY:
|
||||
image_limit = limits.get('stability_calls', 0) or 0
|
||||
projected_images = total_images + 1
|
||||
|
||||
if image_limit > 0 and projected_images > image_limit:
|
||||
error_info = {
|
||||
'current_images': total_images,
|
||||
'limit': image_limit,
|
||||
'provider': 'stability',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
|
||||
'error_type': 'image_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
|
||||
call_limit = limits.get(provider_calls_key, 0) or 0
|
||||
|
||||
if call_limit > 0:
|
||||
projected_calls = current_provider_calls + 1
|
||||
if projected_calls > call_limit:
|
||||
error_info = {
|
||||
'current_calls': current_provider_calls,
|
||||
'limit': call_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
|
||||
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
|
||||
return True, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
|
||||
return False, f"Failed to validate limits: {str(e)}", {}
|
||||
from .limit_validation import LimitValidator
|
||||
validator = LimitValidator(self)
|
||||
return validator.check_comprehensive_limits(user_id, operations)
|
||||
|
||||
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get pricing configuration for a specific provider and model."""
|
||||
|
||||
39
backend/services/subscription/schema_utils.py
Normal file
39
backend/services/subscription/schema_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Set
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
_checked_subscription_plan_columns: bool = False
|
||||
|
||||
|
||||
def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
"""Ensure required columns exist on subscription_plans for runtime safety.
|
||||
|
||||
This is a defensive guard for environments where migrations have not yet
|
||||
been applied. If columns are missing (e.g., exa_calls_limit), we add them
|
||||
with a safe default so ORM queries do not fail.
|
||||
"""
|
||||
global _checked_subscription_plan_columns
|
||||
if _checked_subscription_plan_columns:
|
||||
return
|
||||
|
||||
try:
|
||||
# Discover existing columns
|
||||
result = db.execute("PRAGMA table_info(subscription_plans)")
|
||||
cols: Set[str] = {row[1] for row in result}
|
||||
|
||||
# Columns we may reference in models but might be missing in older DBs
|
||||
required_columns = {
|
||||
"exa_calls_limit": "INTEGER DEFAULT 0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
if col_name not in cols:
|
||||
db.execute(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}")
|
||||
db.commit()
|
||||
except Exception:
|
||||
# Do not block app if pragma/alter fails; let normal errors surface
|
||||
db.rollback()
|
||||
finally:
|
||||
_checked_subscription_plan_columns = True
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ Handles authentication, permission checking, and blog publishing to Wix websites
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
@@ -19,6 +20,9 @@ from services.integrations.wix.media import WixMediaService
|
||||
from services.integrations.wix.utils import extract_meta_from_token, normalize_token_string, extract_member_id_from_access_token as utils_extract_member
|
||||
from services.integrations.wix.content import convert_content_to_ricos as ricos_builder
|
||||
from services.integrations.wix.auth import WixAuthService
|
||||
from services.integrations.wix.seo import build_seo_data
|
||||
from services.integrations.wix.ricos_converter import markdown_to_html, convert_via_wix_api
|
||||
from services.integrations.wix.blog_publisher import create_blog_post as publish_blog_post
|
||||
|
||||
class WixService:
|
||||
"""Service for interacting with Wix APIs"""
|
||||
@@ -237,13 +241,35 @@ class WixService:
|
||||
logger.error(f"Failed to import image to Wix: {e}")
|
||||
raise
|
||||
|
||||
def convert_content_to_ricos(self, content: str, images: List[str] = None) -> Dict[str, Any]:
|
||||
def convert_content_to_ricos(self, content: str, images: List[str] = None,
|
||||
use_wix_api: bool = False, access_token: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert markdown content to Ricos JSON format.
|
||||
|
||||
Args:
|
||||
content: Markdown content to convert
|
||||
images: Optional list of image URLs
|
||||
use_wix_api: If True, use Wix's official Ricos Documents API (requires access_token)
|
||||
access_token: Wix access token (required if use_wix_api=True)
|
||||
|
||||
Returns:
|
||||
Ricos JSON document
|
||||
"""
|
||||
if use_wix_api and access_token:
|
||||
try:
|
||||
return convert_via_wix_api(content, access_token, self.base_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert via Wix API, falling back to custom parser: {e}")
|
||||
# Fall back to custom parser
|
||||
|
||||
# Use custom parser (current implementation)
|
||||
return ricos_builder(content, images)
|
||||
|
||||
|
||||
def create_blog_post(self, access_token: str, title: str, content: str,
|
||||
cover_image_url: str = None, category_ids: List[str] = None,
|
||||
tag_ids: List[str] = None, publish: bool = True,
|
||||
member_id: str = None) -> Dict[str, Any]:
|
||||
member_id: str = None, seo_metadata: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Create and optionally publish a blog post on Wix
|
||||
|
||||
@@ -256,101 +282,33 @@ class WixService:
|
||||
tag_ids: Optional list of tag IDs
|
||||
publish: Whether to publish immediately or save as draft
|
||||
member_id: Required for third-party apps - the member ID of the post author
|
||||
seo_metadata: Optional SEO metadata dict with fields like:
|
||||
- seo_title: SEO optimized title
|
||||
- meta_description: Meta description
|
||||
- focus_keyword: Main keyword
|
||||
- blog_tags: List of tag strings (for keywords)
|
||||
- open_graph: Open Graph data
|
||||
- canonical_url: Canonical URL
|
||||
|
||||
Returns:
|
||||
Created blog post information
|
||||
"""
|
||||
if not member_id:
|
||||
raise ValueError("memberId is required for third-party apps creating blog posts")
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# Build valid Ricos rich content (minimum: one paragraph with text)
|
||||
ricos_content = self.convert_content_to_ricos(content or "This is a post from ALwrity.", None)
|
||||
|
||||
# Minimal payload per Wix docs: title, memberId, and richContent
|
||||
blog_data = {
|
||||
'draftPost': {
|
||||
'title': title,
|
||||
'memberId': member_id, # Required for third-party apps
|
||||
'richContent': ricos_content,
|
||||
'excerpt': (content or '').strip()[:200]
|
||||
},
|
||||
'publish': publish,
|
||||
'fieldsets': ['URL'] # Simplified fieldsets
|
||||
}
|
||||
|
||||
# Add cover image if provided
|
||||
if cover_image_url:
|
||||
try:
|
||||
media_id = self.import_image_to_wix(access_token, cover_image_url, f'Cover: {title}')
|
||||
blog_data['draftPost']['media'] = {
|
||||
'wixMedia': {
|
||||
'image': {'id': media_id}
|
||||
},
|
||||
'displayed': True,
|
||||
'custom': True
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import cover image: {e}")
|
||||
|
||||
# Add categories if provided
|
||||
if category_ids:
|
||||
blog_data['draftPost']['categoryIds'] = category_ids
|
||||
|
||||
# Add tags if provided
|
||||
if tag_ids:
|
||||
blog_data['draftPost']['tagIds'] = tag_ids
|
||||
|
||||
try:
|
||||
# Check what permissions we have in the token
|
||||
logger.info("DEBUG: Checking token permissions...")
|
||||
try:
|
||||
import jwt
|
||||
# Extract token string manually since _normalize_access_token doesn't exist
|
||||
token_str = str(access_token)
|
||||
if token_str and token_str.startswith('OauthNG.JWS.'):
|
||||
jwt_part = token_str[12:]
|
||||
payload = jwt.decode(jwt_part, options={"verify_signature": False, "verify_aud": False})
|
||||
logger.info(f"DEBUG: Full token payload: {payload}")
|
||||
|
||||
# Check for permissions in various possible locations
|
||||
data_payload = payload.get('data', {})
|
||||
if isinstance(data_payload, str):
|
||||
try:
|
||||
data_payload = json.loads(data_payload)
|
||||
except:
|
||||
pass
|
||||
|
||||
instance_data = data_payload.get('instance', {})
|
||||
permissions = instance_data.get('permissions', '')
|
||||
scopes = instance_data.get('scopes', [])
|
||||
meta_site_id = instance_data.get('metaSiteId')
|
||||
if isinstance(meta_site_id, str) and meta_site_id:
|
||||
headers['wix-site-id'] = meta_site_id
|
||||
logger.info(f"DEBUG: Added wix-site-id header: {meta_site_id}")
|
||||
logger.info(f"DEBUG: Token permissions: {permissions}")
|
||||
logger.info(f"DEBUG: Token scopes: {scopes}")
|
||||
else:
|
||||
logger.info("DEBUG: Could not decode token for permission check")
|
||||
except Exception as perm_e:
|
||||
logger.warning(f"DEBUG: Failed to check permissions: {perm_e}")
|
||||
|
||||
logger.info(f"DEBUG: Sending simplified blog data: {json.dumps(blog_data, indent=2)}")
|
||||
extra_headers = {}
|
||||
if 'wix-site-id' in headers:
|
||||
extra_headers['wix-site-id'] = headers['wix-site-id']
|
||||
result = self.blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
|
||||
logger.info(f"DEBUG: Create draft result: {result}")
|
||||
return result
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to create blog post: {e}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f"Response body: {e.response.text}")
|
||||
raise
|
||||
return publish_blog_post(
|
||||
blog_service=self.blog_service,
|
||||
access_token=access_token,
|
||||
title=title,
|
||||
content=content,
|
||||
member_id=member_id,
|
||||
cover_image_url=cover_image_url,
|
||||
category_ids=category_ids,
|
||||
tag_ids=tag_ids,
|
||||
publish=publish,
|
||||
seo_metadata=seo_metadata,
|
||||
import_image_func=self.import_image_to_wix,
|
||||
lookup_categories_func=self.lookup_or_create_categories,
|
||||
lookup_tags_func=self.lookup_or_create_tags,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
def get_blog_categories(self, access_token: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -383,6 +341,138 @@ class WixService:
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to get blog tags: {e}")
|
||||
raise
|
||||
|
||||
def lookup_or_create_categories(self, access_token: str, category_names: List[str],
|
||||
extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
|
||||
"""
|
||||
Lookup existing categories by name or create new ones, return their IDs.
|
||||
|
||||
Args:
|
||||
access_token: Valid access token
|
||||
category_names: List of category name strings
|
||||
extra_headers: Optional extra headers (e.g., wix-site-id)
|
||||
|
||||
Returns:
|
||||
List of category UUIDs
|
||||
"""
|
||||
if not category_names:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get existing categories
|
||||
existing_categories = self.blog_service.list_categories(access_token, extra_headers)
|
||||
|
||||
# Create name -> ID mapping (case-insensitive)
|
||||
category_map = {}
|
||||
for cat in existing_categories:
|
||||
cat_label = cat.get('label', '').strip()
|
||||
cat_id = cat.get('id')
|
||||
if cat_label and cat_id:
|
||||
category_map[cat_label.lower()] = cat_id
|
||||
|
||||
category_ids = []
|
||||
for category_name in category_names:
|
||||
category_name_clean = str(category_name).strip()
|
||||
if not category_name_clean:
|
||||
continue
|
||||
|
||||
# Lookup existing category (case-insensitive)
|
||||
category_id = category_map.get(category_name_clean.lower())
|
||||
|
||||
if not category_id:
|
||||
# Create new category
|
||||
try:
|
||||
logger.info(f"Creating new category: {category_name_clean}")
|
||||
result = self.blog_service.create_category(
|
||||
access_token,
|
||||
label=category_name_clean,
|
||||
extra_headers=extra_headers
|
||||
)
|
||||
new_category = result.get('category', {})
|
||||
category_id = new_category.get('id')
|
||||
if category_id:
|
||||
category_ids.append(category_id)
|
||||
# Update map to avoid duplicate creates
|
||||
category_map[category_name_clean.lower()] = category_id
|
||||
logger.info(f"Created category '{category_name_clean}' with ID: {category_id}")
|
||||
except Exception as create_error:
|
||||
logger.warning(f"Failed to create category '{category_name_clean}': {create_error}")
|
||||
# Continue with other categories
|
||||
else:
|
||||
category_ids.append(category_id)
|
||||
logger.info(f"Found existing category '{category_name_clean}' with ID: {category_id}")
|
||||
|
||||
return category_ids
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to lookup/create categories: {e}")
|
||||
return []
|
||||
|
||||
def lookup_or_create_tags(self, access_token: str, tag_names: List[str],
|
||||
extra_headers: Optional[Dict[str, str]] = None) -> List[str]:
|
||||
"""
|
||||
Lookup existing tags by name or create new ones, return their IDs.
|
||||
|
||||
Args:
|
||||
access_token: Valid access token
|
||||
tag_names: List of tag name strings
|
||||
extra_headers: Optional extra headers (e.g., wix-site-id)
|
||||
|
||||
Returns:
|
||||
List of tag UUIDs
|
||||
"""
|
||||
if not tag_names:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get existing tags
|
||||
existing_tags = self.blog_service.list_tags(access_token, extra_headers)
|
||||
|
||||
# Create name -> ID mapping (case-insensitive)
|
||||
tag_map = {}
|
||||
for tag in existing_tags:
|
||||
tag_label = tag.get('label', '').strip()
|
||||
tag_id = tag.get('id')
|
||||
if tag_label and tag_id:
|
||||
tag_map[tag_label.lower()] = tag_id
|
||||
|
||||
tag_ids = []
|
||||
for tag_name in tag_names:
|
||||
tag_name_clean = str(tag_name).strip()
|
||||
if not tag_name_clean:
|
||||
continue
|
||||
|
||||
# Lookup existing tag (case-insensitive)
|
||||
tag_id = tag_map.get(tag_name_clean.lower())
|
||||
|
||||
if not tag_id:
|
||||
# Create new tag
|
||||
try:
|
||||
logger.info(f"Creating new tag: {tag_name_clean}")
|
||||
result = self.blog_service.create_tag(
|
||||
access_token,
|
||||
label=tag_name_clean,
|
||||
extra_headers=extra_headers
|
||||
)
|
||||
new_tag = result.get('tag', {})
|
||||
tag_id = new_tag.get('id')
|
||||
if tag_id:
|
||||
tag_ids.append(tag_id)
|
||||
# Update map to avoid duplicate creates
|
||||
tag_map[tag_name_clean.lower()] = tag_id
|
||||
logger.info(f"Created tag '{tag_name_clean}' with ID: {tag_id}")
|
||||
except Exception as create_error:
|
||||
logger.warning(f"Failed to create tag '{tag_name_clean}': {create_error}")
|
||||
# Continue with other tags
|
||||
else:
|
||||
tag_ids.append(tag_id)
|
||||
logger.info(f"Found existing tag '{tag_name_clean}' with ID: {tag_id}")
|
||||
|
||||
return tag_ids
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to lookup/create tags: {e}")
|
||||
return []
|
||||
|
||||
def publish_draft_post(self, access_token: str, draft_post_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user