Save local changes (GSC/Bing integrations) before merging PR #354

This commit is contained in:
ajaysi
2026-02-13 13:11:27 +05:30
parent 43e66835ac
commit 08a1f4a1d8
144 changed files with 8310 additions and 2748 deletions

View File

@@ -206,6 +206,13 @@ class RouterManager:
except Exception as e:
logger.warning(f"Persona router not mounted: {e}")
# Video Studio router
try:
from api.video_studio.router import router as video_studio_router
self.include_router_safely(video_studio_router, "video_studio")
except Exception as e:
logger.warning(f"Video Studio router not mounted: {e}")
# Stability AI routers
try:
from routers.stability import router as stability_router

View File

@@ -0,0 +1,52 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
import os
from pathlib import Path
from services.database import WORKSPACE_DIR, get_user_db_path
router = APIRouter(prefix="/api/assets", tags=["Assets Serving"])
@router.get("/{user_id}/avatars/{filename}")
async def serve_avatar(user_id: str, filename: str):
"""
Serve avatar images directly.
Public endpoint relying on unguessable filenames.
"""
# Sanitize user_id (simple check to prevent directory traversal)
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
if safe_user_id != user_id:
raise HTTPException(status_code=400, detail="Invalid user ID")
# Sanitize filename
safe_filename = os.path.basename(filename)
# Construct path
# workspace/workspace_{user_id}/assets/avatars/{filename}
file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "avatars" / safe_filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="Asset not found")
return FileResponse(file_path)
@router.get("/{user_id}/voice_samples/{filename}")
async def serve_voice_sample(user_id: str, filename: str):
"""
Serve voice sample audio files directly.
"""
# Sanitize user_id
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
if safe_user_id != user_id:
raise HTTPException(status_code=400, detail="Invalid user ID")
# Sanitize filename
safe_filename = os.path.basename(filename)
# Construct path
# workspace/workspace_{user_id}/assets/voice_samples/{filename}
file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "voice_samples" / safe_filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="Asset not found")
return FileResponse(file_path)

View File

@@ -0,0 +1,97 @@
# Brand Avatar API Documentation
## Overview
The Brand Avatar API provides endpoints for generating, varying, and enhancing brand avatars using WaveSpeed AI.
**Base URL**: `/api/onboarding/assets`
## Endpoints
### 1. Generate Avatar
Generate a new brand avatar from a text prompt.
- **URL**: `/generate-avatar`
- **Method**: `POST`
- **Body** (`application/json`):
```json
{
"prompt": "A professional tech entrepreneur, studio lighting",
"style_preset": "Cinematic",
"aspect_ratio": "1:1",
"model": "ideogram-v3-turbo",
"provider": "wavespeed"
}
```
- **Response**:
```json
{
"success": true,
"image_url": "/api/assets/{user_id}/avatars/{filename}.png",
"image_base64": "...",
"asset_id": 123
}
```
### 2. Create Variation
Create a variation of an existing avatar/image.
- **URL**: `/create-variation`
- **Method**: `POST`
- **Content-Type**: `multipart/form-data`
- **Form Data**:
- `prompt` (text): Description of the variation (e.g., "same person but smiling")
- `file` (file): The reference image file
- **Response**:
```json
{
"success": true,
"image_url": "/api/assets/{user_id}/avatars/{filename}.png",
"image_base64": "...",
"asset_id": 124
}
```
### 3. Enhance Avatar
Upscale and enhance an existing avatar image.
- **URL**: `/enhance-avatar`
- **Method**: `POST`
- **Content-Type**: `multipart/form-data`
- **Form Data**:
- `file` (file): The image file to enhance
- **Response**:
```json
{
"success": true,
"image_url": "/api/assets/{user_id}/avatars/{filename}.png",
"image_base64": "...",
"asset_id": 125
}
```
### 4. Enhance Prompt
Optimize a simple prompt into a detailed, high-quality prompt using WaveSpeed.
- **URL**: `/enhance-prompt`
- **Method**: `POST`
- **Body**:
```json
{
"prompt": "man in suit"
}
```
- **Response**:
```json
{
"success": true,
"original_prompt": "man in suit",
"optimized_prompt": "A professional portrait of a man in a tailored navy blue suit, confident expression, studio lighting, 4k resolution..."
}
```
## Providers
- **Default Provider**: `wavespeed`
- **Models**:
- Generation: `ideogram-v3-turbo` (default), `qwen-image`
- Editing/Variation: `qwen-edit-plus` (default)
- Enhancement: `nano-banana-pro-edit-ultra` (4K upscale)

View File

@@ -100,6 +100,8 @@ class OnboardingCompletionService:
except Exception as e:
logger.warning(f"Failed to schedule website analysis task creation for user {user_id}: {e}")
# Schedule onboarding full-site SEO audit (non-blocking) ~10 minutes after completion
try:
from services.database import SessionLocal

View File

@@ -10,22 +10,36 @@ from sqlalchemy.orm import Session
from pydantic import BaseModel
from loguru import logger
from .step4_persona_routes import _extract_user_id
from middleware.auth_middleware import get_current_user
import base64
import os
from pathlib import Path
from utils.file_storage import save_file_safely, generate_unique_filename
from services.database import get_db, WORKSPACE_DIR
from utils.asset_tracker import save_asset_to_library
from models.content_asset_models import ContentAsset, AssetType, AssetSource
from sqlalchemy import desc
from services.llm_providers.main_image_generation import (
generate_image_with_provider,
enhance_image_prompt,
generate_image_variation
generate_image_variation,
generate_image_enhance
)
from services.llm_providers.main_audio_generation import clone_voice, qwen3_voice_clone, cosyvoice_voice_clone, qwen3_voice_design
import asyncio
import random
import string
router = APIRouter()
router = APIRouter(prefix="/onboarding/assets")
# --- Models ---
class VoiceDesignRequest(BaseModel):
user_id: Optional[str] = None
text: str
voice_description: str
language: str = "auto"
class AvatarPromptRequest(BaseModel):
user_id: Optional[str] = None
prompt: str
@@ -34,6 +48,9 @@ class AvatarPromptRequest(BaseModel):
negative_prompt: Optional[str] = None
num_inference_steps: int = 30
guidance_scale: float = 7.5
model: Optional[str] = None
rendering_speed: Optional[str] = None
provider: Optional[str] = None
class AvatarEnhanceRequest(BaseModel):
user_id: Optional[str] = None
@@ -47,14 +64,108 @@ class VoiceCloneRequest(BaseModel):
# --- Routes ---
@router.get("/latest-avatar")
async def get_latest_avatar(
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get the latest generated brand avatar for the user."""
try:
user_id = _extract_user_id(current_user)
# Search for assets that are either:
# 1. Saved with source_module=BRAND_AVATAR_GENERATOR (new)
# 2. Saved with source_module=STORY_WRITER but have metadata category='brand_avatar' (legacy)
# Fetch candidates (limit to recent 20 to avoid performance issues)
candidates = db.query(ContentAsset).filter(
ContentAsset.user_id == user_id,
ContentAsset.asset_type == AssetType.IMAGE,
ContentAsset.source_module.in_([
AssetSource.BRAND_AVATAR_GENERATOR,
AssetSource.STORY_WRITER
])
).order_by(desc(ContentAsset.created_at)).limit(50).all()
asset = None
for candidate in candidates:
# Check for direct match (new assets)
if candidate.source_module == AssetSource.BRAND_AVATAR_GENERATOR:
asset = candidate
break
# Check for legacy match (metadata category)
if candidate.source_module == AssetSource.STORY_WRITER:
meta = candidate.asset_metadata or {}
if meta.get('category') == 'brand_avatar':
asset = candidate
break
if not asset:
return {"success": False, "message": "No avatar found"}
# Fallback to metadata prompt if main column is empty (legacy support)
prompt = asset.prompt
if not prompt and asset.asset_metadata:
prompt = asset.asset_metadata.get('prompt', '')
return {
"success": True,
"image_url": asset.file_url,
"prompt": prompt,
"asset_id": asset.id,
"provider": asset.provider
}
except Exception as e:
logger.error(f"Failed to fetch latest avatar: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/latest-voice-clone")
async def get_latest_voice_clone(
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get the latest generated voice clone for the user."""
try:
user_id = _extract_user_id(current_user)
# Fetch latest voice clone asset
asset = db.query(ContentAsset).filter(
ContentAsset.user_id == user_id,
ContentAsset.asset_type == AssetType.AUDIO,
ContentAsset.source_module == AssetSource.VOICE_CLONER
).order_by(desc(ContentAsset.created_at)).first()
if not asset:
# Try to find legacy assets or assets that might have been saved differently
# For example, voice designs might be saved as VOICE_CLONER too?
# Or check for 'voice_design' logic if needed, but 'voice_cloner' is primary
return {"success": False, "message": "No voice clone found"}
meta = asset.asset_metadata or {}
return {
"success": True,
"custom_voice_id": meta.get("custom_voice_id"),
"preview_audio_url": meta.get("preview_url") or asset.file_url,
"asset_id": asset.id,
"voice_name": meta.get("voice_name"),
"engine": meta.get("engine")
}
except Exception as e:
logger.error(f"Failed to fetch latest voice clone: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-avatar")
async def generate_avatar(
request: AvatarPromptRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a brand avatar using available image providers."""
try:
user_id = _extract_user_id(request.user_id)
user_id = _extract_user_id(current_user)
logger.info(f"Generating avatar for user {user_id} with prompt: {request.prompt}")
@@ -66,6 +177,9 @@ async def generate_avatar(
num_inference_steps=request.num_inference_steps,
guidance_scale=request.guidance_scale,
style_preset=request.style_preset,
model=request.model,
rendering_speed=request.rendering_speed,
provider=request.provider,
user_id=user_id
)
@@ -78,42 +192,66 @@ async def generate_avatar(
image_data = result.get("image_base64")
if not image_data and result.get("image_url"):
# TODO: Download image from URL if needed, or just store URL
pass
try:
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(result["image_url"], timeout=30.0)
response.raise_for_status()
image_data = response.content
except ImportError:
# Fallback to requests if httpx is not installed
import requests
response = requests.get(result["image_url"], timeout=30.0)
response.raise_for_status()
image_data = response.content
except Exception as e:
logger.error(f"Failed to download image from URL: {e}")
raise HTTPException(status_code=500, detail=f"Failed to download generated image: {str(e)}")
if image_data:
# Decode if needed (usually it's already base64 string)
# Save file
filename = generate_unique_filename("avatar", "png")
file_path = save_file_safely(
base64.b64decode(image_data) if isinstance(image_data, str) else image_data,
user_id,
"avatars",
# If image_data is bytes (from URL download), pass it directly
# If it's base64 string (from API), decode it
content_to_save = base64.b64decode(image_data) if isinstance(image_data, str) else image_data
# Construct user assets directory
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
saved_path, error = save_file_safely(
content_to_save,
user_assets_dir,
filename
)
if error or not saved_path:
raise HTTPException(status_code=500, detail=f"Failed to save image file: {error}")
# Construct public URL
image_url = f"/api/assets/{user_id}/avatars/{filename}"
# Save to Asset Library
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
file_path=file_path,
asset_type="image",
category="brand_avatar",
meta_data={
"prompt": request.prompt,
source_module="brand_avatar_generator",
filename=filename,
file_url=image_url,
file_path=str(saved_path),
prompt=request.prompt,
asset_metadata={
"provider": result.get("provider", "unknown"),
"style": request.style_preset
"style": request.style_preset,
"category": "brand_avatar"
}
)
# Construct public URL (this depends on your static file serving setup)
# Assuming /api/assets/{user_id}/avatars/{filename}
image_url = f"/api/assets/{user_id}/avatars/{filename}"
return {
"success": True,
"image_url": image_url,
"image_base64": image_data, # Optional: return base64 for immediate display
"image_base64": image_data if isinstance(image_data, str) else base64.b64encode(image_data).decode('utf-8'),
"asset_id": asset_id
}
@@ -126,14 +264,15 @@ async def generate_avatar(
@router.post("/enhance-prompt")
async def enhance_prompt_route(
request: AvatarEnhanceRequest
request: AvatarEnhanceRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Enhance a simple prompt into a detailed midjourney-style prompt."""
try:
user_id = _extract_user_id(request.user_id)
user_id = _extract_user_id(current_user)
logger.info(f"Enhancing prompt for user {user_id}: {request.prompt}")
enhanced_prompt = await enhance_image_prompt(request.prompt)
enhanced_prompt = await enhance_image_prompt(request.prompt, user_id=user_id)
return {
"success": True,
@@ -145,52 +284,347 @@ async def enhance_prompt_route(
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-variation")
async def create_variation_route(
prompt: str = Form(...),
file: UploadFile = File(...),
user_id: Optional[str] = Form(None), # Ignored in favor of authenticated user
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Generate a variation of an existing avatar."""
try:
user_id = _extract_user_id(current_user)
logger.info(f"Creating variation for user {user_id} with prompt: {prompt}")
# Read file
file_content = await file.read()
result = await generate_image_variation(
image=file_content,
prompt=prompt,
user_id=user_id
)
if not result.get("success"):
raise HTTPException(status_code=500, detail=result.get("error", "Variation generation failed"))
# Save result
image_data = result.get("image_base64")
if image_data:
filename = generate_unique_filename("avatar_variation", "png")
content_to_save = base64.b64decode(image_data)
# Construct user assets directory
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
saved_path, error = save_file_safely(
content_to_save,
user_assets_dir,
filename
)
if error or not saved_path:
raise HTTPException(status_code=500, detail=f"Failed to save variation file: {error}")
# Construct public URL
image_url = f"/api/assets/{user_id}/avatars/{filename}"
# Save to Asset Library
asset_id = save_asset_to_library(
db=next(get_db()),
user_id=user_id,
asset_type="image",
source_module="brand_avatar_variation",
filename=filename,
file_url=image_url,
file_path=str(saved_path),
asset_metadata={
"prompt": prompt,
"provider": "wavespeed",
"original_filename": file.filename,
"category": "brand_avatar_variation"
}
)
return {
"success": True,
"image_url": image_url,
"image_base64": image_data,
"asset_id": asset_id
}
return {"success": False, "error": "No image data returned"}
except Exception as e:
logger.error(f"Variation generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/enhance-avatar")
async def enhance_avatar_route(
file: UploadFile = File(...),
user_id: Optional[str] = Form(None), # Ignored in favor of authenticated user
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Enhance/Upscale an existing avatar."""
try:
user_id = _extract_user_id(current_user)
logger.info(f"Enhancing avatar for user {user_id}")
# Read file
file_content = await file.read()
result = await generate_image_enhance(
image=file_content,
user_id=user_id
)
if not result.get("success"):
raise HTTPException(status_code=500, detail=result.get("error", "Enhancement failed"))
# Save result
image_data = result.get("image_base64")
if image_data:
filename = generate_unique_filename("avatar_enhanced", "png")
content_to_save = base64.b64decode(image_data)
# Construct user assets directory
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
saved_path, error = save_file_safely(
content_to_save,
user_assets_dir,
filename
)
if error or not saved_path:
raise HTTPException(status_code=500, detail=f"Failed to save enhanced file: {error}")
# Construct public URL
image_url = f"/api/assets/{user_id}/avatars/{filename}"
# Save to Asset Library
asset_id = save_asset_to_library(
db=next(get_db()),
user_id=user_id,
asset_type="image",
source_module="brand_avatar_enhancer",
filename=filename,
file_url=image_url,
file_path=str(saved_path),
asset_metadata={
"provider": "wavespeed",
"category": "brand_avatar_enhanced",
"original_filename": file.filename
}
)
return {
"success": True,
"image_url": image_url,
"image_base64": image_data,
"asset_id": asset_id
}
return {"success": False, "error": "No image data returned"}
except Exception as e:
logger.error(f"Avatar enhancement failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-voice-clone")
async def create_voice_clone(
voice_name: str = Form(...),
description: str = Form(None),
engine: str = Form("qwen3"),
file: UploadFile = File(...),
user_id: Optional[str] = Form(None),
user_id: Optional[str] = Form(None), # Ignored in favor of authenticated user
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Create a voice clone from an audio file."""
try:
user_id = _extract_user_id(user_id)
logger.info(f"Creating voice clone '{voice_name}' for user {user_id}")
user_id = _extract_user_id(current_user)
logger.info(f"Creating voice clone '{voice_name}' (engine={engine}) for user {user_id}")
# 1. Save uploaded audio file
file_content = await file.read()
filename = generate_unique_filename("voice_sample", Path(file.filename).suffix.lstrip("."))
file_path = save_file_safely(file_content, user_id, "voice_samples", filename)
# 2. Call Voice Cloning API (Placeholder for actual implementation)
# TODO: Integrate with Minimax or CosyVoice API
# For now, we simulate success
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
saved_path, error = save_file_safely(file_content, user_voice_dir, filename)
# 3. Save to Asset Library
if error or not saved_path:
raise HTTPException(status_code=500, detail=f"Failed to save voice sample: {error}")
file_path = str(saved_path)
# 2. Call Voice Cloning API
preview_audio_bytes = None
custom_voice_id = None
loop = asyncio.get_event_loop()
# Default preview text
preview_text = "Hello! This is a preview of my cloned voice using AI technology. I hope you like it!"
if engine.lower() == "minimax":
# Generate valid voice ID for Minimax (alphanumeric, starts with letter, 8+ chars)
random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
custom_voice_id = f"vc_{random_suffix}"
logger.info(f"Cloning voice with Minimax, ID: {custom_voice_id}")
# Run blocking call in executor
result = await loop.run_in_executor(
None,
lambda: clone_voice(
audio_bytes=file_content,
custom_voice_id=custom_voice_id,
text=preview_text,
user_id=user_id
)
)
preview_audio_bytes = result.preview_audio_bytes
elif engine.lower() == "cosyvoice":
logger.info("Cloning voice with CosyVoice")
result = await loop.run_in_executor(
None,
lambda: cosyvoice_voice_clone(
audio_bytes=file_content,
text=preview_text,
user_id=user_id
)
)
preview_audio_bytes = result.preview_audio_bytes
# CosyVoice doesn't persist ID on provider side, but we need one for DB
asset_uuid = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
custom_voice_id = f"vc_cosy_{asset_uuid}"
else: # qwen3 (default)
logger.info("Cloning voice with Qwen3")
result = await loop.run_in_executor(
None,
lambda: qwen3_voice_clone(
audio_bytes=file_content,
text=preview_text,
user_id=user_id
)
)
preview_audio_bytes = result.preview_audio_bytes
# Qwen3 doesn't persist ID on provider side
asset_uuid = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
custom_voice_id = f"vc_qwen_{asset_uuid}"
# 3. Save Preview Audio (if generated)
preview_url = None
if preview_audio_bytes:
preview_filename = f"preview_{filename}"
# Ensure it ends with .wav
if not preview_filename.endswith(".wav"):
preview_filename = str(Path(preview_filename).with_suffix('.wav'))
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
saved_preview_path, error = save_file_safely(preview_audio_bytes, user_voice_dir, preview_filename)
if not error and saved_preview_path:
preview_url = f"/api/assets/{user_id}/voice_samples/{preview_filename}"
# 4. Save to Asset Library
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
file_path=file_path,
asset_type="audio",
category="voice_clone",
meta_data={
source_module="voice_cloner",
filename=filename,
file_url=f"/api/assets/{user_id}/voice_samples/{filename}",
asset_metadata={
"voice_name": voice_name,
"engine": engine,
"description": description,
"original_filename": file.filename
"original_filename": file.filename,
"custom_voice_id": custom_voice_id,
"preview_url": preview_url,
"category": "voice_clone"
}
)
return {
"success": True,
"custom_voice_id": f"vc_{asset_id}", # Mock ID
"preview_audio_url": f"/api/assets/{user_id}/voice_samples/{filename}",
"custom_voice_id": custom_voice_id,
"preview_audio_url": preview_url or f"/api/assets/{user_id}/voice_samples/{filename}",
"asset_id": asset_id,
"message": "Voice clone created successfully (simulated)"
"message": "Voice clone created successfully"
}
except Exception as e:
logger.error(f"Voice cloning failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-voice-design")
async def create_voice_design(
request: VoiceDesignRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Create a voice from text description (Voice Design)."""
try:
user_id = _extract_user_id(current_user)
logger.info(f"Designing voice for user {user_id}")
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: qwen3_voice_design(
text=request.text,
voice_description=request.voice_description,
language=request.language,
user_id=user_id
)
)
# Save the result to a temporary file
filename = generate_unique_filename("voice_design_preview", "wav")
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
saved_path, error = save_file_safely(result.preview_audio_bytes, user_voice_dir, filename)
if error or not saved_path:
raise HTTPException(status_code=500, detail=f"Failed to save voice design: {error}")
# Generate URL
preview_url = f"/api/assets/{user_id}/voice_samples/{filename}"
# Save to Asset Library
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
file_path=str(saved_path),
asset_type="audio",
source_module="voice_cloner",
filename=filename,
file_url=preview_url,
asset_metadata={
"voice_description": request.voice_description,
"text": request.text,
"language": request.language,
"engine": "qwen3-design",
"category": "voice_design",
"preview_url": preview_url
}
)
return {
"success": True,
"preview_audio_url": preview_url,
"asset_id": asset_id,
"message": "Voice generated successfully"
}
except Exception as e:
logger.error(f"Voice design failed: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -4,10 +4,10 @@ Podcast Audio Handlers
Audio generation, combining, and serving endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from typing import Dict, Any
from typing import Dict, Any, Optional
from pathlib import Path
from urllib.parse import urlparse
import tempfile
@@ -31,6 +31,83 @@ from ..models import (
router = APIRouter()
@router.post("/audio/upload")
async def upload_podcast_audio(
file: UploadFile = File(...),
project_id: Optional[str] = Form(None),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Upload an audio file (voice sample) for a podcast project.
Returns the audio URL for use in video generation.
"""
user_id = require_authenticated_user(current_user)
# Validate file type
if not file.content_type or not file.content_type.startswith('audio/'):
# Allow octet-stream if extension is audio
allowed_exts = ['.mp3', '.wav', '.m4a', '.aac']
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_exts and file.content_type != 'application/octet-stream':
raise HTTPException(status_code=400, detail="File must be an audio file")
# Validate file size (max 20MB)
file_content = await file.read()
if len(file_content) > 20 * 1024 * 1024:
raise HTTPException(status_code=400, detail="Audio file size must be less than 20MB")
try:
# Generate filename
file_ext = Path(file.filename).suffix or '.mp3'
unique_id = str(uuid.uuid4())[:8]
audio_filename = f"audio_{project_id or 'temp'}_{unique_id}{file_ext}"
audio_path = PODCAST_AUDIO_DIR / audio_filename
# Save file
with open(audio_path, "wb") as f:
f.write(file_content)
logger.info(f"[Podcast] Audio uploaded: {audio_path}")
# Create audio URL
audio_url = f"/api/podcast/audio/{audio_filename}"
# Save to asset library if project_id provided
if project_id:
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="audio",
source_module="podcast_maker",
filename=audio_filename,
file_url=audio_url,
file_path=str(audio_path),
file_size=len(file_content),
mime_type=file.content_type,
title=f"Uploaded Audio - {project_id}",
description="Uploaded podcast audio/voice sample",
tags=["podcast", "audio", "upload", project_id],
asset_metadata={
"project_id": project_id,
"type": "uploaded_audio",
"status": "completed",
},
)
except Exception as e:
logger.warning(f"[Podcast] Failed to save audio asset: {e}")
return {
"audio_url": audio_url,
"audio_filename": audio_filename,
"message": "Audio uploaded successfully"
}
except Exception as exc:
logger.error(f"[Podcast] Audio upload failed: {exc}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Audio upload failed: {str(exc)}")
@router.post("/audio", response_model=PodcastAudioResponse)
async def generate_podcast_audio(
request: PodcastAudioRequest,

View File

@@ -10,6 +10,7 @@ from fastapi import HTTPException
from loguru import logger
from .constants import PODCAST_AUDIO_DIR, PODCAST_IMAGES_DIR
from utils.media_utils import load_media_bytes
def load_podcast_audio_bytes(audio_url: str) -> bytes:
@@ -54,49 +55,23 @@ def load_podcast_audio_bytes(audio_url: str) -> bytes:
def load_podcast_image_bytes(image_url: str) -> bytes:
"""Load podcast image bytes from URL. Only handles /api/podcast/images/ URLs."""
"""Load podcast image bytes from URL. Uses centralized media loader."""
if not image_url:
raise HTTPException(status_code=400, detail="Image URL is required")
logger.info(f"[Podcast] Loading image from URL: {image_url}")
try:
parsed = urlparse(image_url)
path = parsed.path if parsed.scheme else image_url
# REUSE: Use centralized media loader which handles cross-module lookups
image_bytes = load_media_bytes(image_url)
# Only handle /api/podcast/images/ URLs
prefix = "/api/podcast/images/"
if prefix not in path:
logger.error(f"[Podcast] Unsupported image URL format: {image_url}")
raise HTTPException(
status_code=400,
detail=f"Unsupported image URL format: {image_url}. Only /api/podcast/images/ URLs are supported."
)
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
logger.error(f"[Podcast] Could not extract filename from URL: {image_url}")
raise HTTPException(status_code=400, detail=f"Could not extract filename from URL: {image_url}")
logger.info(f"[Podcast] Extracted filename: {filename}")
logger.info(f"[Podcast] PODCAST_IMAGES_DIR: {PODCAST_IMAGES_DIR}")
# Podcast images are stored in podcast_images directory
image_path = (PODCAST_IMAGES_DIR / filename).resolve()
logger.info(f"[Podcast] Resolved image path: {image_path}")
# Security check: ensure path is within PODCAST_IMAGES_DIR
if not str(image_path).startswith(str(PODCAST_IMAGES_DIR)):
logger.error(f"[Podcast] Attempted path traversal when resolving image: {image_url} -> {image_path}")
raise HTTPException(status_code=403, detail="Invalid image path")
if not image_path.exists():
logger.error(f"[Podcast] Image file not found: {image_path}")
raise HTTPException(status_code=404, detail=f"Image file not found: {filename}")
image_bytes = image_path.read_bytes()
logger.info(f"[Podcast] ✅ Successfully loaded image: {len(image_bytes)} bytes from {image_path}")
if not image_bytes:
logger.error(f"[Podcast] Image file not found for URL: {image_url}")
raise HTTPException(status_code=404, detail=f"Image file not found: {image_url}")
logger.info(f"[Podcast] ✅ Successfully loaded image: {len(image_bytes)} bytes")
return image_bytes
except HTTPException:
raise
except Exception as exc:

View File

@@ -56,6 +56,8 @@ async def preflight_check(
provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
elif provider_str == "video":
provider_enum = APIProvider.VIDEO
elif provider_str == "fal-ai" or provider_str == "fal":
provider_enum = APIProvider.VIDEO # Map fal-ai to VIDEO as it's primarily used for media gen
elif provider_str == "image_edit":
provider_enum = APIProvider.IMAGE_EDIT
elif provider_str == "stability":

View File

@@ -0,0 +1 @@
# Video Studio API Module

View File

@@ -0,0 +1,173 @@
from fastapi import APIRouter, UploadFile, File, Form, BackgroundTasks, HTTPException, Depends
from fastapi.responses import FileResponse
from typing import Optional, Dict, Any
import shutil
import os
from pathlib import Path
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from ..task_manager import task_manager
from middleware.auth_middleware import get_current_user
from loguru import logger
from services.database import get_engine_for_user
from sqlalchemy.orm import sessionmaker
from utils.asset_tracker import save_asset_to_library
router = APIRouter()
# Define storage directory
UPLOAD_DIR = Path("backend/data/video_studio/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
def _process_avatar_generation(task_id: str, image_path: Path, audio_path: Path, user_id: str, resolution: str, model: str):
"""
Background task to process avatar generation using shared InfiniteTalk service.
"""
try:
task_manager.update_task(task_id, "processing", user_id=user_id)
# Read file bytes
with open(image_path, "rb") as f:
image_bytes = f.read()
with open(audio_path, "rb") as f:
audio_bytes = f.read()
# Dummy scene data required by the service (used for prompt generation)
scene_data = {
"title": "Test Persona",
"description": "A talking avatar video generated via Video Studio."
}
story_context = {}
# Call the common interface function
logger.info(f"[VideoStudio] Starting InfiniteTalk generation for task {task_id}")
result = animate_scene_with_voiceover(
image_bytes=image_bytes,
audio_bytes=audio_bytes,
scene_data=scene_data,
story_context=story_context,
user_id=user_id,
resolution=resolution
)
# Save the resulting video bytes to a file
video_filename = f"video_{task_id}.mp4"
video_path = UPLOAD_DIR / video_filename
with open(video_path, "wb") as f:
f.write(result["video_bytes"])
# Prepare result for frontend (remove raw bytes)
result.pop("video_bytes", None)
# Add local download URL
video_url = f"/api/video-studio/download/{video_filename}"
result["video_url"] = video_url
# Save asset to library
try:
engine = get_engine_for_user(user_id)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="video_studio",
filename=video_filename,
file_url=video_url,
file_path=str(video_path),
file_size=video_path.stat().st_size,
mime_type="video/mp4",
title=f"Avatar Video {task_id}",
description=f"Generated avatar video using {model}",
model=model,
cost=result.get("cost", 0.0),
generation_time=result.get("generation_time", 0.0)
)
finally:
db.close()
except Exception as e:
logger.error(f"[VideoStudio] Failed to save asset to library: {e}")
logger.info(f"[VideoStudio] Task {task_id} completed successfully")
task_manager.update_task(task_id, "completed", result=result, user_id=user_id)
except Exception as e:
logger.error(f"[VideoStudio] Avatar generation failed for task {task_id}: {e}", exc_info=True)
task_manager.update_task(task_id, "failed", error=str(e), user_id=user_id)
finally:
# Cleanup temp upload files
try:
if image_path.exists(): image_path.unlink()
if audio_path.exists(): audio_path.unlink()
except Exception as e:
logger.warning(f"[VideoStudio] Failed to cleanup temp files: {e}")
@router.post("/avatar/create-async")
async def create_avatar_video(
background_tasks: BackgroundTasks,
image: UploadFile = File(...),
audio: UploadFile = File(...),
resolution: str = Form("720p"),
model: str = Form("infinitetalk"),
current_user: dict = Depends(get_current_user)
):
"""
Create a talking avatar video using InfiniteTalk (WaveSpeed).
Directly uses the common backend service without Podcast Maker dependencies.
"""
user_id = current_user.get("id", "anonymous")
# Validate file types roughly
if not image.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="Invalid image file type")
task_id = task_manager.create_task("avatar_generation", user_id=user_id)
# Generate temp paths
image_ext = Path(image.filename).suffix or ".png"
audio_ext = Path(audio.filename).suffix or ".mp3"
image_path = UPLOAD_DIR / f"img_{task_id}{image_ext}"
audio_path = UPLOAD_DIR / f"aud_{task_id}{audio_ext}"
try:
# Save uploaded files
with open(image_path, "wb") as f:
shutil.copyfileobj(image.file, f)
with open(audio_path, "wb") as f:
shutil.copyfileobj(audio.file, f)
# Start background task
background_tasks.add_task(
_process_avatar_generation,
task_id,
image_path,
audio_path,
user_id,
resolution,
model
)
return {"task_id": task_id, "status": "pending", "message": "Video generation started successfully."}
except Exception as e:
# Cleanup if immediate failure
if image_path.exists(): image_path.unlink()
if audio_path.exists(): audio_path.unlink()
logger.error(f"[VideoStudio] Failed to start generation: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start generation: {str(e)}")
@router.get("/task/{task_id}")
async def get_task_status(task_id: str, current_user: dict = Depends(get_current_user)):
user_id = current_user.get("id", "anonymous")
task = task_manager.get_task(task_id, user_id=user_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@router.get("/download/{filename}")
async def download_video(filename: str):
file_path = UPLOAD_DIR / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(file_path)

View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter
from .handlers import avatar
router = APIRouter(prefix="/api/video-studio", tags=["Video Studio"])
router.include_router(avatar.router)

View File

@@ -0,0 +1,126 @@
from typing import Dict, Any, Optional
import uuid
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import sessionmaker
from services.database import get_engine_for_user
from models.video_models import VideoGenerationTask, VideoTaskStatus, Base
class TaskManager:
def __init__(self):
pass
def create_task(self, task_type: str, user_id: str, request_data: Optional[Dict] = None) -> str:
"""Create a new persistent task."""
task_id = str(uuid.uuid4())
try:
engine = get_engine_for_user(user_id)
# Ensure table exists
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
try:
task = VideoGenerationTask(
task_id=task_id,
user_id=user_id,
status=VideoTaskStatus.PENDING,
request_data=request_data
)
db.add(task)
db.commit()
logger.info(f"[VideoStudio] Created persistent task {task_id} for user {user_id}")
return task_id
finally:
db.close()
except Exception as e:
logger.error(f"[VideoStudio] Failed to create task: {e}")
raise
def update_task(self, task_id: str, status: str, result: Optional[Dict] = None, error: Optional[str] = None, user_id: str = None, progress: float = None, message: str = None):
"""Update an existing task."""
if not user_id:
logger.error(f"[VideoStudio] Cannot update task {task_id} without user_id")
return
try:
engine = get_engine_for_user(user_id)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
try:
task = db.query(VideoGenerationTask).filter(VideoGenerationTask.task_id == task_id).first()
if not task:
logger.warning(f"[VideoStudio] Task {task_id} not found in DB for update")
return
# Map string status to Enum
try:
# Handle case-insensitive status mapping
status_upper = status.upper()
if status_upper == "RUNNING":
status_upper = "PROCESSING"
enum_status = VideoTaskStatus[status_upper]
except KeyError:
logger.warning(f"[VideoStudio] Invalid status {status}, defaulting to PROCESSING")
enum_status = VideoTaskStatus.PROCESSING
task.status = enum_status
task.updated_at = datetime.utcnow()
if result:
task.result = result
if error:
task.error = error
if progress is not None:
task.progress = progress
if message:
task.message = message
db.commit()
logger.debug(f"[VideoStudio] Updated task {task_id} to {status}")
finally:
db.close()
except Exception as e:
logger.error(f"[VideoStudio] Failed to update task {task_id}: {e}")
def get_task(self, task_id: str, user_id: str = None) -> Optional[Dict[str, Any]]:
"""Retrieve task status."""
if not user_id:
logger.error(f"[VideoStudio] Cannot get task {task_id} without user_id")
return None
try:
engine = get_engine_for_user(user_id)
SessionLocal = sessionmaker(bind=engine)
db = SessionLocal()
try:
task = db.query(VideoGenerationTask).filter(VideoGenerationTask.task_id == task_id).first()
if not task:
return None
# Map internal status to frontend status
status_val = task.status.value
if status_val == "processing":
status_val = "running"
return {
"task_id": task.task_id,
"status": status_val,
"result": task.result,
"error": task.error,
"progress": task.progress,
"message": task.message,
"created_at": task.created_at,
"updated_at": task.updated_at
}
finally:
db.close()
except Exception as e:
logger.error(f"[VideoStudio] Failed to get task {task_id}: {e}")
return None
task_manager = TaskManager()

View File

@@ -17,6 +17,7 @@ from services.subscription.preflight_validator import validate_image_generation_
from services.llm_providers.main_image_generation import generate_image, generate_character_image
from utils.asset_tracker import save_asset_to_library
from utils.logger_utils import get_service_logger
from utils.media_utils import load_media_bytes
from ..task_manager import task_manager
router = APIRouter(tags=["youtube-image"])
@@ -59,36 +60,15 @@ def require_authenticated_user(current_user: Dict[str, Any]) -> str:
def _load_base_avatar_bytes(avatar_url: str) -> Optional[bytes]:
"""Load base avatar bytes for character consistency."""
try:
# Handle different avatar URL formats
if avatar_url.startswith("/api/youtube/avatars/"):
# YouTube avatar
filename = avatar_url.split("/")[-1].split("?")[0]
avatar_path = YOUTUBE_AVATARS_DIR / filename
elif avatar_url.startswith("/api/podcast/avatars/"):
# Podcast avatar (cross-module usage)
filename = avatar_url.split("/")[-1].split("?")[0]
from pathlib import Path
podcast_avatars_dir = Path(__file__).parent.parent.parent.parent / "podcast_avatars"
avatar_path = podcast_avatars_dir / filename
else:
# Try to extract filename and check YouTube avatars first
filename = avatar_url.split("/")[-1].split("?")[0]
avatar_path = YOUTUBE_AVATARS_DIR / filename
if not avatar_path.exists():
# Fallback to podcast avatars
podcast_avatars_dir = Path(__file__).parent.parent.parent.parent / "podcast_avatars"
avatar_path = podcast_avatars_dir / filename
if not avatar_path.exists() or not avatar_path.is_file():
logger.warning(f"[YouTube] Avatar file not found: {avatar_path}")
return None
logger.info(f"[YouTube] Successfully loaded avatar: {avatar_path}")
return avatar_path.read_bytes()
except Exception as e:
logger.error(f"[YouTube] Error loading avatar from {avatar_url}: {e}")
return None
# REUSE: Use centralized media loader
avatar_bytes = load_media_bytes(avatar_url)
if avatar_bytes:
logger.info(f"[YouTube] Successfully loaded avatar from: {avatar_url}")
return avatar_bytes
logger.warning(f"[YouTube] Avatar file not found for URL: {avatar_url}")
return None
def _save_scene_image(image_bytes: bytes, scene_id: str) -> Dict[str, str]:

View File

@@ -74,6 +74,7 @@ from routers.linkedin import router as linkedin_router
from api.linkedin_image_generation import router as linkedin_image_router
from api.brainstorm import router as brainstorm_router
from api.images import router as images_router
from api.assets_serving import router as assets_serving_router
from routers.image_studio import router as image_studio_router
from routers.product_marketing import router as product_marketing_router
from routers.campaign_creator import router as campaign_creator_router
@@ -132,6 +133,7 @@ from api.seo_dashboard import (
get_semantic_health # Phase 2B: Semantic health monitoring
)
# Initialize FastAPI app
app = FastAPI(
title="ALwrity Backend API",
@@ -244,6 +246,9 @@ async def onboarding_status():
router_manager.include_core_routers()
router_manager.include_optional_routers()
# Include assets serving router (must be mounted to serve generated images)
app.include_router(assets_serving_router)
# SEO Dashboard endpoints
@app.get("/api/seo-dashboard/data")
async def seo_dashboard_data():

View File

@@ -38,6 +38,7 @@ class ClerkAuthMiddleware:
)
self.clerk_publishable_key = publishable_key.strip() if publishable_key else None
self.disable_auth = os.getenv('DISABLE_AUTH', 'false').lower() == 'true'
self.allow_unverified_dev = os.getenv('ALLOW_UNVERIFIED_JWT_DEV', 'false').lower() == 'true'
# Cache for PyJWKClient to avoid repeated JWKS fetches
self._jwks_client_cache = {}
@@ -67,6 +68,7 @@ class ClerkAuthMiddleware:
# Create ClerkHTTPBearer instance for dependency injection
self.clerk_bearer = ClerkHTTPBearer(clerk_config)
logger.info(f"fastapi-clerk-auth initialized successfully with JWKS URL: {jwks_url}")
self._jwks_url_cache = jwks_url
else:
logger.warning("Could not extract instance from publishable key")
self.clerk_bearer = None
@@ -113,7 +115,9 @@ class ClerkAuthMiddleware:
issuer = unverified_claims.get('iss', '')
# Construct JWKS URL from issuer
jwks_url = f"{issuer}/.well-known/jwks.json"
jwks_url = f"{issuer}/.well-known/jwks.json" if issuer else self._jwks_url_cache or ""
if not jwks_url:
raise Exception("Unable to resolve JWKS URL for Clerk verification")
# Use cached PyJWKClient to avoid repeated JWKS fetches
if jwks_url not in self._jwks_client_cache:
@@ -162,11 +166,37 @@ class ClerkAuthMiddleware:
if 'expired' in error_msg or 'signature has expired' in error_msg:
logger.debug(f"Token expired (expected): {e}")
else:
logger.warning(f"fastapi-clerk-auth verification error: {e}")
logger.warning(f"fastapi-clerk-auth verification error: {e}. Attempting fallback decoding.")
# Fallback to unverified decoding on verification failure (DEV MODE ONLY)
try:
import jwt
# Decode the JWT without verification to get claims
decoded_token = jwt.decode(token, options={"verify_signature": False}, leeway=300)
user_id = decoded_token.get('sub')
email = decoded_token.get('email')
first_name = decoded_token.get('first_name') or decoded_token.get('given_name')
last_name = decoded_token.get('last_name') or decoded_token.get('family_name')
if user_id and self.allow_unverified_dev:
logger.debug(f"Unverified token accepted (dev) for user: {email or 'unknown'} (ID: {user_id})")
return {
'id': user_id,
'email': email,
'first_name': first_name,
'last_name': last_name,
'clerk_user_id': user_id
}
elif user_id and not self.allow_unverified_dev:
logger.error("Unverified token rejected (production).")
return None
except Exception as fallback_e:
logger.warning(f"Fallback decoding failed: {fallback_e}")
return None
else:
# Fallback to custom implementation (not secure for production)
logger.warning("Using fallback JWT decoding without signature verification")
logger.debug("Using fallback JWT decoding without signature verification")
try:
import jwt
# Decode the JWT without verification to get claims
@@ -188,14 +218,17 @@ class ClerkAuthMiddleware:
logger.warning("No user ID found in token")
return None
logger.info(f"Token decoded successfully (fallback) for user: {email} (ID: {user_id})")
return {
'id': user_id,
'email': email,
'first_name': first_name,
'last_name': last_name,
'clerk_user_id': user_id
}
if self.allow_unverified_dev:
logger.debug(f"Token decoded successfully (fallback dev) for user: {email} (ID: {user_id})")
return {
'id': user_id,
'email': email,
'first_name': first_name,
'last_name': last_name,
'clerk_user_id': user_id
}
logger.error("Fallback decoding is disabled in production.")
return None
except Exception as e:
logger.warning(f"Fallback JWT decode error: {e}")

View File

@@ -55,6 +55,15 @@ class AssetSource(enum.Enum):
# YouTube Creator
YOUTUBE_CREATOR = "youtube_creator"
# Brand Avatar Generator
BRAND_AVATAR_GENERATOR = "brand_avatar_generator"
# Video Studio
VIDEO_STUDIO = "video_studio"
# Voice Cloner
VOICE_CLONER = "voice_cloner"
class ContentAsset(Base):
"""

View File

@@ -0,0 +1,36 @@
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, Float, Enum
from datetime import datetime
import enum
from models.subscription_models import Base
class VideoTaskStatus(enum.Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class VideoGenerationTask(Base):
"""
Model for tracking video generation tasks (Video Studio).
"""
__tablename__ = "video_generation_tasks"
id = Column(Integer, primary_key=True, index=True)
task_id = Column(String(36), unique=True, index=True, nullable=False) # UUID
user_id = Column(String(255), nullable=False, index=True)
status = Column(Enum(VideoTaskStatus), default=VideoTaskStatus.PENDING)
# Task inputs (stored as JSON)
request_data = Column(JSON, nullable=True)
# Task results
result = Column(JSON, nullable=True)
error = Column(Text, nullable=True)
# Progress tracking
progress = Column(Float, default=0.0)
message = Column(String(255), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

View File

@@ -43,6 +43,12 @@ async def get_gsc_auth_url(user: dict = Depends(get_current_user)):
logger.info(f"OAuth URL: {auth_url[:100]}...")
return {"auth_url": auth_url}
except FileNotFoundError as e:
logger.error(f"GSC credentials not found: {e}")
raise HTTPException(
status_code=503,
detail="Google Search Console integration is not configured. Please add gsc_credentials.json to the backend directory or set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET environment variables."
)
except Exception as e:
logger.error(f"Error generating GSC OAuth URL: {e}")
logger.error(f"Error details: {str(e)}")
@@ -73,34 +79,29 @@ async def handle_gsc_callback(
from services.platform_insights_monitoring_service import create_platform_insights_task
# Get user_id from state (stored during OAuth flow)
# Note: handle_oauth_callback already deleted state, so we need to get user_id from recent credentials
db = SessionLocal()
try:
# Get user_id from most recent GSC credentials (since state was deleted)
import sqlite3
with sqlite3.connect(gsc_service.db_path) as conn:
cursor = conn.cursor()
cursor.execute('SELECT user_id FROM gsc_credentials ORDER BY updated_at DESC LIMIT 1')
result = cursor.fetchone()
if result:
user_id = result[0]
# Don't fetch site_url here - it requires API calls
# The executor will fetch it when the task runs (weekly)
# Create insights task without site_url to avoid API calls
task_result = create_platform_insights_task(
user_id=user_id,
platform='gsc',
site_url=None, # Will be fetched by executor when task runs
db=db
)
if task_result.get('success'):
logger.info(f"Created GSC insights task for user {user_id}")
else:
logger.warning(f"Failed to create GSC insights task: {task_result.get('error')}")
finally:
db.close()
# Format is "user_id:random_string"
user_id = state.split(':')[0] if ':' in state else None
if user_id:
db = SessionLocal()
try:
# Create insights task without site_url to avoid API calls
# The executor will fetch it when the task runs (weekly)
task_result = create_platform_insights_task(
user_id=user_id,
platform='gsc',
site_url=None, # Will be fetched by executor when task runs
db=db
)
if task_result.get('success'):
logger.info(f"Created GSC insights task for user {user_id}")
else:
logger.warning(f"Failed to create GSC insights task: {task_result.get('error')}")
finally:
db.close()
else:
logger.warning(f"Could not extract user_id from state: {state}")
except Exception as e:
# Non-critical: log but don't fail OAuth callback
logger.warning(f"Failed to create GSC insights task after OAuth: {e}", exc_info=True)

View File

@@ -3,8 +3,8 @@ WordPress OAuth2 Routes
Handles WordPress.com OAuth2 authentication flow.
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from fastapi.responses import RedirectResponse, HTMLResponse, JSONResponse
from typing import Dict, Any, Optional
from pydantic import BaseModel
from loguru import logger
@@ -61,14 +61,23 @@ async def get_wordpress_auth_url(
@router.get("/callback")
async def handle_wordpress_callback(
request: Request,
code: str = Query(..., description="Authorization code from WordPress"),
state: str = Query(..., description="State parameter for security"),
error: Optional[str] = Query(None, description="Error from WordPress OAuth")
):
"""Handle WordPress OAuth2 callback."""
try:
# Check if JSON response is requested
wants_json = request.headers.get("accept") == "application/json" or request.query_params.get("format") == "json"
if error:
logger.error(f"WordPress OAuth error: {error}")
if wants_json:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"success": False, "error": error}
)
html_content = f"""
<!DOCTYPE html>
<html>
@@ -77,7 +86,7 @@ async def handle_wordpress_callback(
<script>
// Send error message to parent window
window.onload = function() {{
window.parent.postMessage({{
(window.opener || window.parent).postMessage({{
type: 'WPCOM_OAUTH_ERROR',
success: false,
error: '{error}'
@@ -100,6 +109,11 @@ async def handle_wordpress_callback(
if not code or not state:
logger.error("Missing code or state parameter in WordPress OAuth callback")
if wants_json:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"success": False, "error": "Missing parameters"}
)
html_content = """
<!DOCTYPE html>
<html>
@@ -134,6 +148,11 @@ async def handle_wordpress_callback(
if not result or not result.get('success'):
logger.error("Failed to exchange WordPress OAuth code for token")
if wants_json:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"success": False, "error": "Token exchange failed"}
)
html_content = """
<!DOCTYPE html>
<html>
@@ -162,6 +181,18 @@ async def handle_wordpress_callback(
# Return success page with postMessage script
blog_url = result.get('blog_url', '')
blog_id = result.get('blog_id', '')
if wants_json:
return JSONResponse(
content={
"success": True,
"blog_url": blog_url,
"blog_id": blog_id,
"sites": [{"blog_url": blog_url, "blog_id": blog_id}] # Simplified for now
}
)
html_content = f"""
<!DOCTYPE html>
<html>
@@ -174,7 +205,7 @@ async def handle_wordpress_callback(
type: 'WPCOM_OAUTH_SUCCESS',
success: true,
blogUrl: '{blog_url}',
blogId: '{result.get('blog_id', '')}'
blogId: '{blog_id}'
}}, '*');
window.close();
}};

View File

@@ -0,0 +1,122 @@
import asyncio
import time
import os
import sys
from typing import Dict, Any, List
from tabulate import tabulate
from loguru import logger
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from services.llm_providers.main_image_generation import generate_image_with_provider
from services.llm_providers.image_generation.wavespeed_provider import WaveSpeedImageProvider
async def benchmark_provider(provider_name: str, model: str, prompt: str) -> Dict[str, Any]:
"""Benchmark a single provider/model combination."""
logger.info(f"Benchmarking {provider_name} ({model})...")
start_time = time.time()
try:
# We use a mocked user_id for validation bypass if needed,
# or rely on the system to handle "benchmark_user"
result = await generate_image_with_provider(
prompt=prompt,
provider=provider_name,
model=model,
width=1024,
height=1024,
user_id="benchmark_user"
)
duration = time.time() - start_time
success = result.get("success", False)
return {
"provider": provider_name,
"model": model,
"duration": duration,
"success": success,
"error": result.get("error")
}
except Exception as e:
return {
"provider": provider_name,
"model": model,
"duration": time.time() - start_time,
"success": False,
"error": str(e)
}
async def run_benchmarks():
"""Run benchmarks across configured providers."""
# Check configured providers
wavespeed_key = os.getenv("WAVESPEED_API_KEY")
stability_key = os.getenv("STABILITY_API_KEY")
hf_token = os.getenv("HF_TOKEN")
logger.info("Checking configured providers...")
logger.info(f"WaveSpeed: {'✅ Configured' if wavespeed_key else '❌ Missing API Key'}")
logger.info(f"Stability: {'✅ Configured' if stability_key else '❌ Missing API Key'}")
logger.info(f"HuggingFace: {'✅ Configured' if hf_token else '❌ Missing API Key'}")
prompt = "A professional brand avatar of a creative designer, minimalist style, clean background, high resolution"
tasks = []
# WaveSpeed Models
if wavespeed_key:
tasks.append(benchmark_provider("wavespeed", "ideogram-v3-turbo", prompt))
tasks.append(benchmark_provider("wavespeed", "qwen-image", prompt))
tasks.append(benchmark_provider("wavespeed", "flux-kontext-pro", prompt))
# Stability Models
if stability_key:
tasks.append(benchmark_provider("stability", "core", prompt))
# HuggingFace Models
if hf_token:
tasks.append(benchmark_provider("huggingface", "black-forest-labs/FLUX.1-dev", prompt))
if not tasks:
logger.warning("No providers configured for benchmarking.")
return
logger.info(f"Starting benchmark for {len(tasks)} configurations...")
results = await asyncio.gather(*tasks)
# Display results
table_data = []
for r in results:
status = "✅ Success" if r["success"] else f"❌ Failed: {r['error'][:30]}..."
table_data.append([
r["provider"],
r["model"],
f"{r['duration']:.2f}s",
status
])
print("\n" + "="*60)
print("AVATAR GENERATION PERFORMANCE BENCHMARK")
print("="*60)
print(tabulate(table_data, headers=["Provider", "Model", "Time", "Status"], tablefmt="grid"))
print("\nRecommendation:")
# Simple recommendation logic
successful = [r for r in results if r["success"]]
if successful:
fastest = min(successful, key=lambda x: x["duration"])
print(f"Fastest provider: {fastest['provider']} ({fastest['model']}) at {fastest['duration']:.2f}s")
# Check WaveSpeed specifically
wavespeed_results = [r for r in successful if r["provider"] == "wavespeed"]
if wavespeed_results:
avg_wavespeed = sum(r["duration"] for r in wavespeed_results) / len(wavespeed_results)
print(f"WaveSpeed Average: {avg_wavespeed:.2f}s")
else:
print("No successful generations to analyze.")
if __name__ == "__main__":
asyncio.run(run_benchmarks())

View File

@@ -0,0 +1,88 @@
import sys
import os
import logging
# Add backend to path
sys.path.append(os.path.join(os.getcwd(), 'backend'))
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from models.subscription_models import APIUsageLog, UserSubscription, APIProvider
from services.subscription import UsageTrackingService, PricingService
USER_ID = "user_33Gz1FPI86VDXhRY8QN4ragRFGN"
def get_db_path(user_id):
# Logic from database.py to resolve path
base_path = os.getcwd()
# Sanitize user_id
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
user_workspace = os.path.join(base_path, "workspace", f"workspace_{safe_user_id}")
# Try both naming conventions
db_path_v1 = os.path.join(user_workspace, "db", "alwrity.db")
db_path_v2 = os.path.join(user_workspace, "db", f"alwrity_{safe_user_id}.db")
if os.path.exists(db_path_v2):
return db_path_v2
return db_path_v1
def check_user_data():
db_path = get_db_path(USER_ID)
logger.info(f"Checking DB at: {db_path}")
if not os.path.exists(db_path):
logger.error(f"DB file not found at {db_path}")
# Check default DB as fallback
default_db = os.path.join(os.getcwd(), 'backend', 'data', 'alwrity.db')
if os.path.exists(default_db):
logger.info(f"Falling back to default DB: {default_db}")
db_url = f"sqlite:///{default_db}"
else:
return
else:
db_url = f"sqlite:///{db_path}"
engine = create_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# 1. Check API Usage Logs
logs = db.query(APIUsageLog).filter(APIUsageLog.user_id == USER_ID).all()
logger.info(f"Found {len(logs)} usage logs for user {USER_ID}")
if logs:
last_log = logs[-1]
logger.info(f"Last log: {last_log.timestamp} - {last_log.provider} - {last_log.cost_total}")
# Print provider breakdown
from collections import Counter
providers = Counter([l.provider for l in logs])
logger.info(f"Provider breakdown: {providers}")
# 2. Check Subscription
sub = db.query(UserSubscription).filter(UserSubscription.user_id == USER_ID).first()
if sub:
logger.info(f"Subscription found: {sub.plan_type} ({sub.status})")
else:
logger.warning("No subscription found")
# 3. Run Usage Service
logger.info("Running UsageTrackingService.get_user_usage_stats...")
service = UsageTrackingService(db)
stats = service.get_user_usage_stats(USER_ID)
logger.info(f"Service Stats: {stats}")
except Exception as e:
logger.error(f"Error: {e}")
finally:
db.close()
if __name__ == "__main__":
check_user_data()

View File

@@ -0,0 +1,48 @@
import os
import sys
from pathlib import Path
# Add the backend directory to the Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
def diagnose():
print(f"Current working directory: {os.getcwd()}")
# Replicate database.py logic
file_path = os.path.abspath(__file__)
# backend/scripts/diagnose.py -> backend/scripts -> backend -> root
# Wait, in database.py it is services/database.py -> services -> backend -> root
# So 3 levels up.
# Here: scripts/diagnose.py -> scripts -> backend -> root.
# So also 3 levels up.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(file_path)))
print(f"Calculated ROOT_DIR: {ROOT_DIR}")
workspace_dir = os.path.join(ROOT_DIR, 'workspace')
print(f"Calculated WORKSPACE_DIR: {workspace_dir}")
if os.path.exists(workspace_dir):
print(f"Workspace directory exists.")
print("Contents:")
try:
for item in os.listdir(workspace_dir):
print(f" - {item}")
except Exception as e:
print(f"Error listing workspace: {e}")
else:
print(f"Workspace directory DOES NOT exist.")
# Check for alwrity.db in backend
backend_db = os.path.join(backend_dir, 'alwrity.db')
if os.path.exists(backend_db):
print(f"Found legacy DB in backend: {backend_db}")
# Check for alwrity.db in root
root_db = os.path.join(ROOT_DIR, 'alwrity.db')
if os.path.exists(root_db):
print(f"Found legacy DB in root: {root_db}")
if __name__ == "__main__":
diagnose()

View File

@@ -0,0 +1,57 @@
import os
import sqlite3
import pandas as pd
from pathlib import Path
def inspect_dbs():
root = Path(os.getcwd())
workspace_dir = root / 'workspace'
if not workspace_dir.exists():
print("No workspace directory found.")
return
print(f"Scanning {workspace_dir} for databases...")
for user_dir in workspace_dir.iterdir():
if user_dir.is_dir() and user_dir.name.startswith('workspace_'):
db_dir = user_dir / 'db'
if db_dir.exists():
for db_file in db_dir.glob('*.db'):
print(f"\n--- Checking DB: {db_file} ---")
try:
conn = sqlite3.connect(db_file)
# Check tables
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [r[0] for r in cursor.fetchall()]
print(f"Tables: {len(tables)}")
if 'api_usage_logs' in tables:
count = cursor.execute("SELECT count(*) FROM api_usage_logs").fetchone()[0]
print(f"api_usage_logs count: {count}")
if count > 0:
# Show last 5 logs
print("Last 5 logs:")
df = pd.read_sql_query("SELECT * FROM api_usage_logs ORDER BY created_at DESC LIMIT 5", conn)
print(df[['id', 'provider', 'model_used', 'cost_total', 'created_at']].to_string())
else:
print("Table 'api_usage_logs' NOT found.")
if 'usage_summaries' in tables:
print("Usage Summaries:")
df = pd.read_sql_query("SELECT * FROM usage_summaries", conn)
if not df.empty:
print(df.to_string())
else:
print("Table 'usage_summaries' is empty.")
else:
print("Table 'usage_summaries' NOT found.")
conn.close()
except Exception as e:
print(f"Error reading DB: {e}")
if __name__ == "__main__":
inspect_dbs()

View File

@@ -0,0 +1,42 @@
import os
import sys
import sqlite3
user_id = "user_33Gz1FPI86VDXhRY8QN4ragRFGN"
base_path = os.getcwd()
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
user_workspace = os.path.join(base_path, "workspace", f"workspace_{safe_user_id}")
db_path = os.path.join(user_workspace, "db", f"alwrity_{safe_user_id}.db")
print(f"Reading from: {db_path}")
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
print("\n--- API Usage Logs ---")
cursor.execute("SELECT * FROM api_usage_logs")
rows = cursor.fetchall()
for row in rows:
print(dict(row))
print("\n--- Subscription Plans ---")
try:
cursor.execute("SELECT * FROM subscription_plans")
rows = cursor.fetchall()
for row in rows:
print(dict(row))
except Exception as e:
print(f"Error reading subscription_plans: {e}")
print("\n--- Usage Summaries ---")
try:
cursor.execute("SELECT * FROM usage_summaries")
rows = cursor.fetchall()
for row in rows:
print(dict(row))
except Exception as e:
print(f"Error reading usage_summaries: {e}")
conn.close()

View File

@@ -0,0 +1,89 @@
import asyncio
import logging
import sys
import os
# Add backend directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from unittest.mock import MagicMock, AsyncMock
from services.intelligence.agents.specialized_agents import ContentGuardianAgent, StrategyArchitectAgent
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_content_guardian():
print("\n=== Testing ContentGuardianAgent ===")
# Mock Intelligence Service
mock_intelligence = MagicMock()
mock_intelligence.is_initialized.return_value = True
# Mock search for cannibalization check
# Scenario 1: No cannibalization
mock_intelligence.search = AsyncMock(return_value=[])
agent = ContentGuardianAgent(mock_intelligence, user_id="test_user")
content = "This is a unique piece of content about AI agents." + " word" * 50 # Make it long enough
print(f"Testing assess_content_quality with content length: {len(content)}")
result = await agent.assess_content_quality(content)
print("Result:", result)
if result.get("quality_score", 0) > 0:
print("✅ assess_content_quality returned a valid score.")
else:
print("❌ assess_content_quality failed to return a valid score.")
# Scenario 2: Cannibalization detected
mock_intelligence.search = AsyncMock(return_value=[{'id': 'existing_doc', 'score': 0.9}])
print("\nTesting assess_content_quality with cannibalization...")
result_cannibal = await agent.assess_content_quality(content)
print("Result (Cannibalization):", result_cannibal)
if result_cannibal.get("cannibalization_risk", {}).get("warning"):
print("✅ Cannibalization correctly detected.")
else:
print("❌ Cannibalization NOT detected when it should be.")
async def test_strategy_architect():
print("\n=== Testing StrategyArchitectAgent ===")
mock_intelligence = MagicMock()
mock_intelligence.is_initialized.return_value = True
# Scenario 1: No clusters
mock_intelligence.cluster = AsyncMock(return_value=[])
agent = StrategyArchitectAgent(mock_intelligence, user_id="test_user")
print("Testing discover_pillars (No clusters)...")
pillars = await agent.discover_pillars()
print(f"Pillars found: {len(pillars)}")
if len(pillars) == 0:
print("✅ Correctly handled no clusters.")
else:
print("❌ Should have returned 0 pillars.")
# Scenario 2: Clusters found
mock_intelligence.cluster = AsyncMock(return_value=[[0, 1, 2], [3, 4]])
print("\nTesting discover_pillars (With clusters)...")
pillars = await agent.discover_pillars()
print(f"Pillars found: {len(pillars)}")
if len(pillars) == 2:
print("✅ Correctly identified pillars.")
print("Pillar 1 size:", pillars[0]['size'])
else:
print("❌ Failed to identify pillars.")
if __name__ == "__main__":
asyncio.run(test_content_guardian())
asyncio.run(test_strategy_architect())

View File

@@ -0,0 +1,43 @@
import os
import sys
import sqlite3
# Add backend to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from services.database import get_user_db_path
user_id = "user_33Gz1FPI86VDXhRY8QN4ragRFGN"
base_path = os.getcwd()
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
user_workspace = os.path.join(base_path, "workspace", f"workspace_{safe_user_id}")
path1 = os.path.join(user_workspace, "db", "alwrity.db")
path2 = os.path.join(user_workspace, "db", f"alwrity_{safe_user_id}.db")
print(f"Checking paths for user {user_id}:")
print(f"Legacy: {path1}")
print(f"Specific: {path2}")
def check_db(path):
if not os.path.exists(path):
print(f" [MISSING] {path}")
return
try:
conn = sqlite3.connect(path)
cursor = conn.cursor()
cursor.execute("SELECT count(*) FROM api_usage_logs")
count = cursor.fetchone()[0]
print(f" [EXISTS] {path} - Rows in api_usage_logs: {count}")
conn.close()
except Exception as e:
print(f" [EXISTS] {path} - Error reading: {e}")
check_db(path1)
check_db(path2)
print("-" * 30)
resolved = get_user_db_path(user_id)
print(f"Application resolves to: {resolved}")

View File

@@ -20,12 +20,13 @@ class BaseAnalyticsHandler(ABC):
self.platform_name = platform_type.value
@abstractmethod
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, **kwargs) -> AnalyticsData:
"""
Get analytics data for the platform
Args:
user_id: User ID to get analytics for
**kwargs: Additional arguments for specific handlers
Returns:
AnalyticsData object with platform metrics

View File

@@ -42,7 +42,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
db_url = f'sqlite:///{db_path}'
return BingInsightsService(db_url)
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
"""
Get Bing Webmaster analytics data using Bing Webmaster API
"""
@@ -83,9 +83,32 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
if not access_token:
return self.create_error_response('Bing Webmaster access token not available')
# Select site: Prefer target_url match, otherwise first site
selected_site = sites[0] if sites else None
if not selected_site:
return self.create_error_response('No Bing sites found')
if target_url and sites:
logger.info(f"Attempting to match target URL: {target_url}")
# Normalize target URL (remove protocol, trailing slash)
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
for site in sites:
# Bing uses 'Url' key
site_url = site.get('Url', '')
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
if normalized_target in normalized_site or normalized_site in normalized_target:
selected_site = site
logger.info(f"Found matching Bing site: {site_url}")
break
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
logger.info(f"Using Bing site URL: {site_url_for_storage}")
query_stats = {}
try:
site_url_for_storage = sites[0].get('Url', '') if (sites and isinstance(sites[0], dict)) else None
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
if stored and isinstance(stored, dict):
query_stats = {
@@ -99,7 +122,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
logger.warning(f"Bing analytics: Failed to read stored analytics summary: {e}")
# Get enhanced insights
insights = self._get_enhanced_insights_with_service(insights_service, user_id, sites[0].get('Url', '') if sites else '')
insights = self._get_enhanced_insights_with_service(insights_service, user_id, site_url_for_storage)
metrics = {
'connection_status': 'connected',

View File

@@ -22,16 +22,22 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
super().__init__(PlatformType.GSC)
self.gsc_service = GSCService()
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
"""
Get Google Search Console analytics data with caching
Args:
user_id: User ID to get analytics for
target_url: Optional URL to prefer when selecting GSC site
Returns comprehensive SEO metrics including clicks, impressions, CTR, and position data.
"""
self.log_analytics_request(user_id, "get_analytics")
# Check cache first - GSC API calls can be expensive
cached_data = analytics_cache.get('gsc_analytics', user_id)
# Include target_url in cache key if provided
cache_key = f"{user_id}_{target_url}" if target_url else user_id
cached_data = analytics_cache.get('gsc_analytics', cache_key)
if cached_data:
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
return AnalyticsData(**cached_data)
@@ -45,8 +51,23 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
logger.warning(f"No GSC sites found for user {user_id}")
return self.create_error_response('No GSC sites found')
# Get analytics for the first site (or combine all sites)
site_url = sites[0]['siteUrl']
# Select site: Prefer target_url match, otherwise first site
selected_site = sites[0]
if target_url:
logger.info(f"Attempting to match target URL: {target_url}")
# Normalize target URL (remove protocol, trailing slash)
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
for site in sites:
site_url = site['siteUrl']
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
if normalized_target in normalized_site or normalized_site in normalized_target:
selected_site = site
logger.info(f"Found matching GSC site: {site_url}")
break
site_url = selected_site['siteUrl']
logger.info(f"Using GSC site URL: {site_url}")
# Get search analytics for last 30 days
@@ -71,7 +92,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
)
# Cache the result to avoid expensive API calls
analytics_cache.set('gsc_analytics', user_id, result.__dict__)
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
logger.info("Cached GSC analytics data for user {user_id}", user_id=user_id)
return result
@@ -81,7 +102,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
error_result = self.create_error_response(str(e))
# Cache error result for shorter time to retry sooner
analytics_cache.set('gsc_analytics', user_id, error_result.__dict__, ttl_override=300) # 5 minutes
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
return error_result
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
@@ -117,111 +138,93 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
# New structure from updated GSC service
overall_rows = search_analytics.get('overall_metrics', {}).get('rows', [])
query_rows = search_analytics.get('query_data', {}).get('rows', [])
verification_rows = search_analytics.get('verification_data', {}).get('rows', [])
logger.info(f"GSC Overall metrics rows: {len(overall_rows)}")
logger.info(f"GSC Query data rows: {len(query_rows)}")
logger.info(f"GSC Verification rows: {len(verification_rows)}")
# Calculate totals from overall_rows (most accurate as it includes anonymized queries)
total_clicks = 0
total_impressions = 0
total_position = 0
valid_position_rows = 0
if overall_rows:
logger.info(f"GSC Overall first row: {overall_rows[0]}")
if query_rows:
logger.info(f"GSC Query first row: {query_rows[0]}")
# Use overall_rows for totals if available, otherwise fallback to query_rows
calc_rows = overall_rows if overall_rows else query_rows
for row in calc_rows:
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
total_clicks += clicks
total_impressions += impressions
if position and position > 0:
total_position += position * impressions # Weighted average
# Calculate weighted average position
avg_position = total_position / total_impressions if total_impressions > 0 else 0
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
# Use query_rows for top queries list
top_queries_source = query_rows
# Use query_rows for detailed insights, overall_rows for summary
rows = query_rows if query_rows else overall_rows
else:
# Legacy structure
rows = search_analytics.get('rows', [])
logger.info(f"GSC Legacy rows count: {len(rows)}")
if rows:
logger.info(f"GSC Legacy first row structure: {rows[0]}")
logger.info(f"GSC Legacy first row keys: {list(rows[0].keys()) if rows[0] else 'No rows'}")
# Calculate summary metrics - handle different response formats
total_clicks = 0
total_impressions = 0
total_position = 0
valid_rows = 0
for row in rows:
# Handle different possible response formats
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
# ... existing legacy logic ...
calc_rows = rows
top_queries_source = rows
# If position is 0 or None, skip it from average calculation
if position and position > 0:
total_position += position
valid_rows += 1
total_clicks = 0
total_impressions = 0
total_position = 0
valid_position_rows = 0
total_clicks += clicks
total_impressions += impressions
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
avg_position = total_position / valid_rows if valid_rows > 0 else 0
logger.info(f"GSC Calculated metrics - clicks: {total_clicks}, impressions: {total_impressions}, ctr: {avg_ctr}, position: {avg_position}, valid_rows: {valid_rows}")
# Get top performing queries - handle different data structures
if rows and 'keys' in rows[0]:
# New GSC API format with keys array
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
# Get top performing pages (if we have page data)
page_data = {}
for row in rows:
# Handle different key structures
keys = row.get('keys', [])
if len(keys) > 1 and keys[1]: # Page data available
page = keys[1].get('keys', ['Unknown'])[0] if isinstance(keys[1], dict) else str(keys[1])
else:
page = 'Unknown'
for row in calc_rows:
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
if page not in page_data:
page_data[page] = {'clicks': 0, 'impressions': 0, 'ctr': 0, 'position': 0}
page_data[page]['clicks'] += row.get('clicks', 0)
page_data[page]['impressions'] += row.get('impressions', 0)
else:
# Legacy format or no keys structure
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
page_data = {}
total_clicks += clicks
total_impressions += impressions
if position and position > 0:
# Simple average for legacy/unknown structure if we can't do weighted
total_position += position
valid_position_rows += 1
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
avg_position = total_position / valid_position_rows if valid_position_rows > 0 else 0
# Calculate page metrics
for page in page_data:
if page_data[page]['impressions'] > 0:
page_data[page]['ctr'] = page_data[page]['clicks'] / page_data[page]['impressions'] * 100
top_pages = sorted(page_data.items(), key=lambda x: x[1]['clicks'], reverse=True)[:10]
return {
'connection_status': 'connected',
'connected_sites': 1, # GSC typically has one site per user
'total_clicks': total_clicks,
'total_impressions': total_impressions,
'avg_ctr': round(avg_ctr, 2),
'avg_position': round(avg_position, 2),
'total_queries': len(rows),
'top_queries': [
{
# Get top performing queries
top_queries = []
if top_queries_source:
# Sort by clicks
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
for row in sorted_queries:
top_queries.append({
'query': self._extract_query_from_row(row),
'clicks': row.get('clicks', 0),
'impressions': row.get('impressions', 0),
'ctr': round(row.get('ctr', 0) * 100, 2),
'position': round(row.get('position', 0), 2)
}
for row in top_queries
],
'top_pages': [
{
'page': page,
'clicks': data['clicks'],
'impressions': data['impressions'],
'ctr': round(data['ctr'], 2)
}
for page, data in top_pages
],
'note': 'Google Search Console provides search performance data, keyword rankings, and SEO insights'
})
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
# To get top pages, we would need another API call with dimension=['page']
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
top_pages = []
return {
'connection_status': 'connected',
'connected_sites': 1,
'total_clicks': total_clicks,
'total_impressions': total_impressions,
'avg_ctr': round(avg_ctr, 2),
'avg_position': round(avg_position, 2),
'total_queries': len(top_queries_source) if top_queries_source else 0,
'top_queries': top_queries,
'top_pages': top_pages
}
except Exception as e:

View File

@@ -59,6 +59,32 @@ class PlatformAnalyticsService:
logger.info(f"Getting comprehensive analytics for user {user_id}, platforms: {platforms}")
analytics_data = {}
# Determine target URL from Wix/WP for GSC site selection
target_url = None
try:
status = await self.get_platform_connection_status(user_id)
# Check Wix
if status.get('wix', {}).get('connected'):
sites = status['wix'].get('sites', [])
if sites:
# Assuming site object has 'blog_url' or 'url'
site = sites[0]
target_url = site.get('blog_url') or site.get('url')
# Check WordPress if no Wix
if not target_url and status.get('wordpress', {}).get('connected'):
sites = status['wordpress'].get('sites', [])
if sites:
site = sites[0]
target_url = site.get('blog_url') or site.get('url')
if target_url:
logger.info(f"Identified primary site URL for GSC matching: {target_url}")
except Exception as e:
logger.warning(f"Failed to determine target URL for GSC: {e}")
for platform_name in platforms:
try:
# Convert string to PlatformType enum
@@ -66,7 +92,10 @@ class PlatformAnalyticsService:
handler = self.handlers.get(platform_type)
if handler:
analytics_data[platform_name] = await handler.get_analytics(user_id)
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
else:
analytics_data[platform_name] = await handler.get_analytics(user_id)
else:
logger.warning(f"Unknown platform: {platform_name}")
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")

View File

@@ -30,6 +30,8 @@ from models.product_asset_models import ProductAsset, ProductStyleTemplate, Ecom
from models.podcast_models import PodcastProject
# Research models use SubscriptionBase
from models.research_models import ResearchProject
# Video Studio models
from models.video_models import VideoGenerationTask
# Bing Analytics models
from models.bing_analytics_models import Base as BingAnalyticsBase
@@ -54,7 +56,22 @@ def get_user_db_path(user_id: str) -> str:
# Sanitize user_id to be safe for filesystem
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
user_workspace = os.path.join(WORKSPACE_DIR, f"workspace_{safe_user_id}")
return os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
# Check for legacy naming convention first (to support existing data)
# Some older workspaces might have 'alwrity.db' instead of 'alwrity_{user_id}.db'
legacy_db_path = os.path.join(user_workspace, 'db', 'alwrity.db')
specific_db_path = os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
# If the specific one exists, use it (preferred)
if os.path.exists(specific_db_path):
return specific_db_path
# If legacy exists and specific doesn't, use legacy
if os.path.exists(legacy_db_path):
return legacy_db_path
# Default to specific for new databases
return specific_db_path
def get_all_user_ids() -> List[str]:
"""

View File

@@ -14,6 +14,8 @@ from loguru import logger
from services.database import get_user_db_path
from dotenv import load_dotenv
class GSCService:
"""Service for Google Search Console integration."""
@@ -31,10 +33,62 @@ class GSCService:
services_dir = os.path.dirname(__file__)
backend_dir = os.path.abspath(os.path.join(services_dir, os.pardir))
self.credentials_file = os.path.join(backend_dir, "gsc_credentials.json")
logger.info(f"GSC credentials file path set to: {self.credentials_file}")
# Load client config from file or environment variables
self.client_config = self._load_client_config()
if self.client_config:
logger.info("GSC client configuration loaded successfully")
else:
logger.warning(f"GSC credentials not found in {self.credentials_file} or environment variables")
self.scopes = ['https://www.googleapis.com/auth/webmasters.readonly']
# Note: Tables are initialized lazily per user
logger.info("GSC Service initialized successfully")
def _load_client_config(self) -> Optional[Dict[str, Any]]:
"""Load Google client configuration from environment variables or file."""
# Reload environment variables to catch any runtime changes (e.g. .env updates)
load_dotenv(override=True)
# 1. Check Environment Variables (Priority)
client_id = os.getenv("GOOGLE_CLIENT_ID")
client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
if client_id and client_secret:
redirect_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
logger.info("Loading GSC credentials from environment variables")
# Construct the config dictionary expected by google_auth_oauthlib
return {
"web": {
"client_id": client_id,
"client_secret": client_secret,
"project_id": os.getenv("GOOGLE_PROJECT_ID", "alwrity"),
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"redirect_uris": [
"http://localhost:5173/onboarding",
redirect_uri
],
"javascript_origins": [
"http://localhost:5173",
"http://localhost:8000"
]
}
}
# 2. Fallback to File
if os.path.exists(self.credentials_file):
try:
with open(self.credentials_file, 'r') as f:
config = json.load(f)
logger.info(f"Loading GSC credentials from file: {self.credentials_file}")
return config
except Exception as e:
logger.warning(f"Failed to load GSC credentials from file: {e}")
return None
def _get_db_path(self, user_id: str) -> str:
return get_user_db_path(user_id)
@@ -94,11 +148,11 @@ class GSCService:
self._init_gsc_tables(user_id)
db_path = self._get_db_path(user_id)
# Read client credentials from file to ensure we have all required fields
with open(self.credentials_file, 'r') as f:
client_config = json.load(f)
if not self.client_config:
logger.error("Cannot save credentials: Client configuration not loaded")
return False
web_config = client_config.get('web', {})
web_config = self.client_config.get('web', {})
credentials_json = json.dumps({
'token': credentials.token,
@@ -184,12 +238,17 @@ class GSCService:
try:
logger.info(f"Generating OAuth URL for user: {user_id}")
if not os.path.exists(self.credentials_file):
raise FileNotFoundError(f"GSC credentials file not found: {self.credentials_file}")
# Retry loading config if missing (in case .env was added later)
if not self.client_config:
self.client_config = self._load_client_config()
if not self.client_config:
raise FileNotFoundError("GSC credentials not found in file or environment variables.")
redirect_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
flow = Flow.from_client_secrets_file(
self.credentials_file,
flow = Flow.from_client_config(
self.client_config,
scopes=self.scopes,
redirect_uri=redirect_uri
)
@@ -256,8 +315,12 @@ class GSCService:
conn.commit()
# Exchange code for credentials
flow = Flow.from_client_secrets_file(
self.credentials_file,
if not self.client_config:
logger.error("Cannot handle callback: Client configuration not loaded")
return False
flow = Flow.from_client_config(
self.client_config,
scopes=self.scopes,
redirect_uri=os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
)
@@ -283,7 +346,11 @@ class GSCService:
service = build('searchconsole', 'v1', credentials=credentials)
logger.info(f"Authenticated GSC service created for user: {user_id}")
return service
except ValueError as e:
# Log as warning only, as this is expected for unconnected users
# logger.warning(f"Cannot create GSC service for user {user_id}: {e}")
raise e
except Exception as e:
logger.error(f"Error creating authenticated GSC service for user {user_id}: {e}")
raise
@@ -291,7 +358,13 @@ class GSCService:
def get_site_list(self, user_id: str) -> List[Dict[str, Any]]:
"""Get list of sites from GSC."""
try:
service = self.get_authenticated_service(user_id)
try:
service = self.get_authenticated_service(user_id)
except ValueError:
# User not connected or credentials invalid
logger.warning(f"User {user_id} not connected to GSC. Returning empty site list.")
return []
sites = service.sites().list().execute()
site_list = []
@@ -306,7 +379,8 @@ class GSCService:
except Exception as e:
logger.error(f"Error getting site list for user {user_id}: {e}")
raise
# Return empty list instead of raising to prevent frontend 500s
return []
def get_search_analytics(self, user_id: str, site_url: str,
start_date: str = None, end_date: str = None) -> Dict[str, Any]:
@@ -325,7 +399,12 @@ class GSCService:
logger.info(f"Returning cached analytics data for user: {user_id}")
return cached_data
service = self.get_authenticated_service(user_id)
try:
service = self.get_authenticated_service(user_id)
except ValueError:
logger.warning(f"User {user_id} not connected to GSC. Returning empty analytics.")
return {'error': 'User not connected to GSC', 'rows': [], 'rowCount': 0}
if not service:
logger.error(f"Failed to get authenticated GSC service for user: {user_id}")
return {'error': 'Authentication failed', 'rows': [], 'rowCount': 0}
@@ -359,11 +438,11 @@ class GSCService:
logger.error(f"GSC Data verification failed for user {user_id}: {verification_error}")
return {'error': f'Data verification failed: {str(verification_error)}', 'rows': [], 'rowCount': 0}
# Step 2: Get overall metrics (no dimensions)
# Step 2: Get daily metrics for charting (ensure we have rows)
request = {
'startDate': start_date,
'endDate': end_date,
'dimensions': [], # No dimensions for overall metrics
'dimensions': ['date'], # Use date dimension to get time-series data
'rowLimit': 1000
}
@@ -472,7 +551,11 @@ class GSCService:
def revoke_user_access(self, user_id: str) -> bool:
"""Revoke user's GSC access."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return True
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
# Delete credentials
@@ -496,7 +579,11 @@ class GSCService:
def clear_incomplete_credentials(self, user_id: str) -> bool:
"""Clear incomplete GSC credentials that are missing required fields."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return True
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('DELETE FROM gsc_credentials WHERE user_id = ?', (user_id,))
conn.commit()
@@ -511,7 +598,11 @@ class GSCService:
def _get_cached_data(self, user_id: str, site_url: str, data_type: str, cache_key: str) -> Optional[Dict]:
"""Get cached data if not expired."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return None
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT data_json FROM gsc_data_cache
@@ -531,9 +622,12 @@ class GSCService:
def _cache_data(self, user_id: str, site_url: str, data_type: str, data: Dict, cache_key: str):
"""Cache data with expiration."""
try:
self._init_gsc_tables(user_id)
db_path = self._get_db_path(user_id)
expires_at = datetime.now() + timedelta(hours=1) # Cache for 1 hour
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO gsc_data_cache

View File

@@ -24,7 +24,16 @@ class WordPressOAuthService:
# WordPress.com OAuth2 credentials
self.client_id = os.getenv('WORDPRESS_CLIENT_ID', '')
self.client_secret = os.getenv('WORDPRESS_CLIENT_SECRET', '')
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', 'https://alwrity-ai.vercel.app/wp/callback')
# Determine redirect URI dynamically
default_redirect = 'https://alwrity-ai.vercel.app/wp/callback'
frontend_url = os.getenv('FRONTEND_URL')
if frontend_url:
self.redirect_uri = f"{frontend_url.rstrip('/')}/wp/callback"
else:
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', default_redirect)
self.base_url = "https://public-api.wordpress.com"
# Validate configuration

View File

@@ -17,8 +17,7 @@ from .core_agent_framework import (
# Market signal detection
from .market_signal_detector import (
MarketSignal,
MarketSignalDetector,
MarketTrendAnalyzer
MarketSignalDetector
)
# Performance monitoring

View File

@@ -105,6 +105,18 @@ class ALwrityAgentOrchestrator:
def _create_specialized_agents(self):
"""Create specialized marketing agents"""
try:
# Check if onboarding is complete before initializing heavy agents
try:
from services.onboarding.progress_service import OnboardingProgressService
onboarding_service = OnboardingProgressService()
status = onboarding_service.get_onboarding_status(self.user_id)
if not status.get("is_completed", False):
logger.info(f"Skipping agent initialization for user {self.user_id} - Onboarding incomplete")
return
except Exception as e:
logger.warning(f"Could not check onboarding status for {self.user_id}: {e}")
# Fallthrough to attempt initialization if check fails
enabled_by_key = {}
db = None
try:
@@ -159,6 +171,26 @@ class ALwrityAgentOrchestrator:
self.trend_surfer_agent = TrendSurferAgent(intel_service, self.user_id)
self.agents['trend'] = self.trend_surfer_agent
# Content Guardian Agent
if enabled_by_key.get("content_guardian", True):
try:
from services.intelligence.sif_agents import ContentGuardianAgent
from services.intelligence.txtai_service import TxtaiIntelligenceService
# Initialize intelligence service if not already available
intel_service = TxtaiIntelligenceService(self.user_id)
# Initialize Content Guardian Agent
self.content_guardian_agent = ContentGuardianAgent(
intelligence_service=intel_service,
user_id=self.user_id,
sif_service=None # SIF service is optional/circular dependency handling
)
self.agents['guardian'] = self.content_guardian_agent
logger.info(f"Initialized ContentGuardianAgent for user {self.user_id}")
except Exception as e:
logger.error(f"Failed to initialize ContentGuardianAgent: {e}")
logger.info(f"Created {len(self.agents)} specialized agents for user {self.user_id}")
except Exception as e:

View File

@@ -0,0 +1,213 @@
import logging
import time
from datetime import datetime
from sqlalchemy import text
from services.database import get_session_for_user
from models.subscription_models import APIProvider, UsageSummary
from services.subscription import PricingService
logger = logging.getLogger(__name__)
def track_agent_usage_sync(user_id: str, model_name: str, prompt: str, response_text: str, duration: float):
"""
Synchronously track agent LLM usage.
This mimics the logic in llm_text_gen to ensure consistency and robustness.
"""
try:
# Detect provider
provider_enum = APIProvider.GEMINI # Default
actual_provider_name = "gemini"
model_lower = model_name.lower()
if "gemini" in model_lower:
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
elif "gpt" in model_lower or "openai" in model_lower or "mistral" in model_lower:
# HuggingFace/Mistral often mapped to gpt-oss or mistral
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
elif "claude" in model_lower or "anthropic" in model_lower:
provider_enum = APIProvider.ANTHROPIC
actual_provider_name = "anthropic"
logger.info(f"[AgentTracking] Tracking usage for user {user_id}, provider {actual_provider_name}, model {model_name}")
db = get_session_for_user(user_id)
if not db:
logger.error(f"[AgentTracking] Could not get database session for user {user_id}")
return
try:
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
pricing = PricingService(db)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits
limits = pricing.get_user_limits(user_id)
token_limit = 0
provider_key = provider_enum.value
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_key}_tokens", 0) or 0
# Check for existing record
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
current_calls_before = 0
current_tokens_before = 0
if record_count and record_count > 0:
# Read current values
sql_query = text(f"""
SELECT {provider_key}_calls, {provider_key}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
else:
# Create new summary
summary = UsageSummary(user_id=user_id, billing_period=current_period)
db.add(summary)
db.flush()
# Update calls
new_calls = current_calls_before + 1
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_key}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
# Update tokens with limit check
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
projected_new_tokens = current_tokens_before + tokens_total
if token_limit > 0 and projected_new_tokens > token_limit:
new_tokens = token_limit
tokens_total = max(0, token_limit - current_tokens_before)
else:
new_tokens = projected_new_tokens
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_key}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
else:
tokens_total = 0
# Calculate cost
try:
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(0, tokens_total - tracked_tokens_input)
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=model_name,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
cost_input = cost_info.get('cost_input', 0.0) or 0.0
cost_output = cost_info.get('cost_output', 0.0) or 0.0
except Exception as e:
logger.error(f"[AgentTracking] Cost calculation failed: {e}")
cost_total = 0.0
cost_input = 0.0
cost_output = 0.0
# Insert into APIUsageLog
try:
log_query = text("""
INSERT INTO api_usage_logs (
user_id, provider, endpoint, method, model_used,
tokens_input, tokens_output, tokens_total,
cost_input, cost_output, cost_total,
response_time, status_code, billing_period,
timestamp, actual_provider_name
) VALUES (
:user_id, :provider, :endpoint, :method, :model_used,
:tokens_input, :tokens_output, :tokens_total,
:cost_input, :cost_output, :cost_total,
:response_time, :status_code, :billing_period,
:created_at, :actual_provider_name
)
""")
db.execute(log_query, {
'user_id': user_id,
'provider': provider_enum.name, # Use name (GEMINI) not value (gemini) for SQLAlchemy Enum
'endpoint': 'agent_action',
'method': 'GENERATE',
'model_used': model_name,
'tokens_input': tracked_tokens_input,
'tokens_output': tracked_tokens_output,
'tokens_total': tracked_tokens_input + tracked_tokens_output,
'cost_input': cost_input,
'cost_output': cost_output,
'cost_total': cost_total,
'response_time': duration,
'status_code': 200,
'billing_period': current_period,
'created_at': datetime.utcnow(),
'actual_provider_name': actual_provider_name
})
except Exception as log_e:
logger.error(f"[AgentTracking] Failed to insert usage log: {log_e}")
if cost_total > 0:
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_key}_cost = COALESCE({provider_key}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
# Update totals
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = COALESCE(total_calls, 0) + 1,
total_tokens = COALESCE(total_tokens, 0) + :tokens_total
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_totals_query, {
'tokens_total': tokens_total,
'user_id': user_id,
'period': current_period
})
db.commit()
logger.info(f"[AgentTracking] ✅ Usage tracked: {new_calls} calls, {cost_total} cost")
except Exception as e:
logger.error(f"[AgentTracking] Error tracking usage: {e}", exc_info=True)
db.rollback()
finally:
db.close()
except Exception as e:
logger.error(f"[AgentTracking] Top level error: {e}", exc_info=True)

View File

@@ -32,9 +32,64 @@ from services.database import get_session_for_user
from services.intelligence.monitoring.semantic_dashboard import RealTimeSemanticMonitor
from services.intelligence.agents.safety_framework import get_safety_framework
from services.agent_activity_service import AgentActivityService
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import time
logger = get_service_logger(__name__)
class TrackingLLMWrapper:
"""
Wrapper for LLM instances to transparently track usage.
Intercepts calls to __call__ and generate() to log metrics.
"""
def __init__(self, llm: Any, user_id: str, model_name: str):
self.llm = llm
self.user_id = user_id
self.model_name = model_name
def __call__(self, prompt: str, *args, **kwargs) -> Any:
return self.generate(prompt, *args, **kwargs)
def generate(self, prompt: str, *args, **kwargs) -> str:
start_time = time.time()
try:
# Delegate to the underlying LLM
if hasattr(self.llm, "generate"):
response = self.llm.generate(prompt, *args, **kwargs)
else:
response = self.llm(prompt, *args, **kwargs)
# Handle response format (some might return list of dicts)
response_text = str(response)
if isinstance(response, list):
if response and isinstance(response[0], dict) and 'generated_text' in response[0]:
response_text = response[0]['generated_text']
else:
response_text = str(response[0])
# Track usage
duration = time.time() - start_time
try:
track_agent_usage_sync(
user_id=self.user_id,
model_name=self.model_name,
prompt=prompt,
response_text=response_text,
duration=duration
)
except Exception as e:
logger.warning(f"Failed to track agent usage in wrapper: {e}")
return response
except Exception as e:
logger.error(f"LLM generation failed in tracking wrapper: {e}")
raise e
def __getattr__(self, name):
# Delegate other attribute access to the underlying LLM
return getattr(self.llm, name)
@dataclass
class AgentAction:
"""Represents an action taken by an agent"""
@@ -114,6 +169,10 @@ class BaseALwrityAgent(ABC):
self.txtai_agent = None
self.llm = llm # Ensure llm is set if provided, regardless of txtai availability
# Wrap LLM with tracking if it exists
if self.llm:
self.llm = TrackingLLMWrapper(self.llm, self.user_id, self.model_name)
self.agent_key = self._resolve_agent_key(agent_type)
self._agent_profile = self._load_agent_profile_overrides()
self._prompt_context = self._load_prompt_context()
@@ -121,10 +180,17 @@ class BaseALwrityAgent(ABC):
if TXTAI_AVAILABLE:
try:
if not self.llm:
self.llm = LLM(model_name)
self.txtai_agent = self._create_txtai_agent()
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
# Create new LLM if not provided
raw_llm = LLM(model_name)
# Wrap it
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
try:
self.txtai_agent = self._create_txtai_agent()
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
except Exception as inner_e:
logger.warning(f"Could not initialize specific txtai agent for {agent_type}: {inner_e}")
self.txtai_agent = self._create_fallback_agent()
except Exception as e:
logger.error(f"Failed to initialize txtai agent for {agent_type}: {e}")
self.txtai_agent = self._create_fallback_agent()
@@ -134,6 +200,38 @@ class BaseALwrityAgent(ABC):
# Initialize safety framework
self.safety_framework = get_safety_framework(user_id)
async def _generate_llm_response(self, prompt: str) -> str:
"""
Helper to generate text using the agent's LLM with usage tracking.
Centralized method for all agents inheriting from BaseALwrityAgent.
"""
if not self.llm:
return "[LLM Unavailable]"
try:
# Run in executor to avoid blocking if LLM is synchronous
loop = asyncio.get_event_loop()
# Use the wrapped LLM's generate method (which handles tracking)
if hasattr(self.llm, "generate"):
response = await loop.run_in_executor(None, lambda: self.llm.generate(prompt))
else:
response = await loop.run_in_executor(None, lambda: self.llm(prompt))
# Handle list output (some models return list of dicts)
response_text = str(response)
if isinstance(response, list):
if response and isinstance(response[0], dict) and 'generated_text' in response[0]:
response_text = response[0]['generated_text']
else:
response_text = str(response[0])
return response_text
except Exception as e:
logger.error(f"LLM generation failed in agent {self.agent_type}: {e}")
return "[Generation Failed]"
def _resolve_agent_key(self, agent_type: str) -> str:
value = str(agent_type or "").strip()
if value.lower() == "strategyorchestrator".lower():

View File

@@ -758,6 +758,11 @@ async def get_agent_performance_summary(user_id: str, agent_id: str) -> Dict[str
"""Get comprehensive performance summary for an agent"""
return await performance_service.get_agent_performance_summary(user_id, agent_id)
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
"""Get performance summary for all agents for a user"""
return await performance_service.get_all_agents_performance_summary(user_id)
return await performance_service.get_all_agents_performance_summary(user_id)
# Alias for backward compatibility
PerformanceMonitor = AgentPerformanceMonitor
performance_monitor = performance_service
AgentPerformanceMetrics = AgentPerformanceSnapshot

View File

@@ -13,6 +13,7 @@ from loguru import logger
from ..txtai_service import TxtaiIntelligenceService
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
from services.seo_tools.content_strategy_service import ContentStrategyService
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
try:
from services.intelligence.sif_integration import SIFIntegrationService
SIF_AVAILABLE = True
@@ -20,14 +21,36 @@ except ImportError:
SIF_AVAILABLE = False
try:
from txtai import Agent, LLM
# Try importing from pipeline first (standard location)
from txtai.pipeline import Agent, LLM
TXTAI_AVAILABLE = True
except ImportError:
TXTAI_AVAILABLE = False
logger.warning("txtai not available, using fallback implementation")
try:
# Fallback to top-level import
from txtai import Agent, LLM
TXTAI_AVAILABLE = True
except ImportError:
TXTAI_AVAILABLE = False
Agent = None
LLM = None
logger.warning("txtai not available, using fallback implementation")
class SIFBaseAgent:
def __init__(self, intelligence_service: TxtaiIntelligenceService):
class SIFBaseAgent(BaseALwrityAgent):
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
# Hybrid LLM Strategy:
# 1. Shared LLM for external/high-quality generation
self.shared_llm = SharedLLMWrapper(user_id)
# 2. Local LLM for internal agent work (default for SIF agents)
if llm is None:
if TXTAI_AVAILABLE:
# Use Lazy Local LLM
llm = LocalLLMWrapper(model_name)
else:
# Fallback to Shared if txtai not available
llm = self.shared_llm
super().__init__(user_id, agent_type, model_name, llm)
self.intelligence = intelligence_service
def _log_agent_operation(self, operation: str, **kwargs):
@@ -36,9 +59,27 @@ class SIFBaseAgent:
if kwargs:
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
def _create_txtai_agent(self):
"""
SIF agents use the intelligence service directly, but we can expose
capabilities via a standard agent interface if needed.
"""
if not TXTAI_AVAILABLE or Agent is None:
return None
# Return a simple agent that can use the LLM
try:
return Agent(llm=self.llm, tools=[])
except Exception as e:
logger.warning(f"Failed to create txtai Agent: {e}")
return None
class StrategyArchitectAgent(SIFBaseAgent):
"""Agent for discovering content pillars and identifying strategic gaps."""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
super().__init__(intelligence_service, user_id, agent_type="strategy_architect")
async def discover_pillars(self) -> List[Dict[str, Any]]:
"""Identify content pillars through semantic clustering."""
self._log_agent_operation("Discovering content pillars")
@@ -108,9 +149,61 @@ class ContentGuardianAgent(SIFBaseAgent):
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
super().__init__(intelligence_service, user_id, agent_type="content_guardian")
self.sif_service = sif_service
# Lazy initialization of SIF service if not provided
if self.sif_service is None and SIF_AVAILABLE:
try:
self.sif_service = SIFIntegrationService(user_id)
logger.info(f"[{self.__class__.__name__}] Lazily initialized SIFIntegrationService")
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] Failed to lazily initialize SIF service: {e}")
async def assess_content_quality(self, content: str) -> Dict[str, Any]:
"""
Assess content quality based on originality, readability, and cannibalization risks.
"""
self._log_agent_operation("Assessing content quality", content_length=len(content))
try:
# 1. Check for cannibalization
cannibalization_result = await self.check_cannibalization(content)
# 2. Check originality (if not cannibalized)
originality_score = 1.0
if not cannibalization_result.get("warning"):
originality_result = await self.verify_originality(content, None)
originality_score = originality_result.get("originality_score", 1.0)
# 3. Check Style Compliance
style_result = await self.style_enforcer(content)
style_score = style_result.get("compliance_score", 1.0)
# 4. Basic Readability (Flesch-Kincaid proxy via sentence length/word complexity)
# Simple heuristic for now
words = content.split()
sentences = content.split('.')
avg_sentence_length = len(words) / max(1, len(sentences))
readability_score = 1.0 if avg_sentence_length < 20 else max(0.5, 1.0 - (avg_sentence_length - 20) * 0.05)
# Weighted Score: Originality (40%) + Style (30%) + Readability (30%)
quality_score = (originality_score * 0.4) + (style_score * 0.3) + (readability_score * 0.3)
return {
"quality_score": quality_score,
"originality_score": originality_score,
"readability_score": readability_score,
"style_score": style_score,
"cannibalization_risk": cannibalization_result,
"style_compliance": style_result,
"is_acceptable": quality_score > 0.7 and not cannibalization_result.get("warning", False)
}
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Failed to assess content quality: {e}")
return {"error": str(e), "quality_score": 0.0}
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
"""Check if a new draft competes semantically with existing pages."""
@@ -193,25 +286,74 @@ class ContentGuardianAgent(SIFBaseAgent):
# 1. Fetch Style Guidelines from SIF if not provided
if not style_guidelines and self.sif_service:
try:
# Search for website analysis to get brand voice/style
# We assume the most relevant 'website_analysis' doc contains the guidelines
results = await self.intelligence.search("website analysis brand voice style", limit=1)
if results:
import json
res = results[0]
metadata_str = res.get('object')
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
# Use central SIF service to get robust context
seo_context = await self.sif_service.get_seo_context()
if seo_context and "error" not in seo_context:
# Extract brand voice/style from the context
# The context structure is normalized in get_seo_context
if metadata.get('type') == 'website_analysis':
report = metadata.get('full_report', {})
style_guidelines = {
"tone": report.get('brand_analysis', {}).get('brand_voice', 'neutral'),
"style_patterns": report.get('style_patterns', {}),
"writing_style": report.get('writing_style', {})
}
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF: {style_guidelines.get('tone')}")
# Note: get_seo_context returns a flattened dict.
# We need to dig into the original structure if available, or rely on what's mapped.
# However, get_seo_context maps 'seo_audit', 'sitemap_analysis', etc.
# Brand info is usually in 'brand_analysis' col of WebsiteAnalysis, which might not be fully exposed
# in the simplified get_seo_context return.
# Let's check if we can get the full object or if we need to expand get_seo_context.
# For now, we'll try to use what's there or fall back to a specific search if needed.
# Actually, looking at get_seo_context implementation:
# It returns 'seo_audit', 'crawl_result'.
# Brand analysis is often stored in WebsiteAnalysis.brand_analysis.
# We might need to extend get_seo_context or do a specific retrieval here.
# But wait! I saw get_seo_context implementation earlier:
# It retrieves the "full_report" from the SIF metadata.
# If the SIF index contains the full WebsiteAnalysis object, we are good.
# Let's try to get it from the full report if we can access it,
# but get_seo_context returns a filtered dict.
# Alternative: Use the robust retrieval logic but specifically for brand info if get_seo_context is too narrow.
# But get_seo_context logic includes "website analysis seo audit" query.
# Let's assume for now we use the same retrieval logic but locally adapted,
# OR better, trust get_seo_context to be the single point of truth.
# If get_seo_context doesn't return brand info, we should update IT, not hack here.
# But I can't update SIFIntegrationService right now without context switch.
# Let's stick to the previous manual search pattern BUT use the SIF service helper if possible.
# Actually, the previous code was:
# results = await self.intelligence.search("website analysis brand voice style", limit=1)
# Let's keep it simple and robust:
# Try to get it from SIF service if possible.
# Since get_seo_context might not return brand_voice directly, let's try to see if we can use it.
# Actually, let's use the manual search but with better error handling,
# mirroring get_seo_context's robustness (e.g. parsing).
results = await self.intelligence.search("website analysis brand voice style", limit=1)
if results:
res = results[0]
metadata_str = res.get('object')
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
if metadata.get('type') == 'website_analysis':
report = metadata.get('full_report', {})
# Support both flat and nested structures
brand_analysis = report.get('brand_analysis') or report.get('brand_voice', {})
if isinstance(brand_analysis, str):
# Handle case where it might be a JSON string
try: brand_analysis = json.loads(brand_analysis)
except: brand_analysis = {"brand_voice": brand_analysis}
style_guidelines = {
"tone": brand_analysis.get('brand_voice', 'neutral') if isinstance(brand_analysis, dict) else 'neutral',
"style_patterns": report.get('style_patterns', {}),
"writing_style": report.get('writing_style', {})
}
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF index")
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines from SIF: {e}")
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines: {e}")
issues = []
score = 1.0
@@ -246,6 +388,55 @@ class ContentGuardianAgent(SIFBaseAgent):
logger.error(f"[{self.__class__.__name__}] Style enforcement failed: {e}")
return {"error": str(e)}
async def perform_site_audit(self, website_url: str, limit: int = 10) -> Dict[str, Any]:
"""
Perform a quality audit on the user's website content.
"""
self._log_agent_operation("Performing site audit", website_url=website_url)
try:
# 1. Retrieve recent content for the site from SIF
# We search for everything with the website_url in metadata
# Note: This depends on how data is indexed.
results = await self.intelligence.search(f"site:{website_url}", limit=limit)
if not results:
logger.info(f"[{self.__class__.__name__}] No content found for site audit")
return {"error": "No content found"}
audit_results = []
total_quality = 0.0
for res in results:
text = res.get('text', '')
if not text or len(text) < 100:
continue
quality = await self.assess_content_quality(text)
audit_results.append({
"id": res.get('id'),
"title": res.get('title', 'Unknown'),
"quality": quality
})
total_quality += quality.get('quality_score', 0.0)
avg_quality = total_quality / len(audit_results) if audit_results else 0.0
report = {
"website_url": website_url,
"pages_audited": len(audit_results),
"average_quality_score": avg_quality,
"details": audit_results,
"timestamp": datetime.utcnow().isoformat()
}
logger.info(f"[{self.__class__.__name__}] Site audit completed. Avg Quality: {avg_quality:.2f}")
return report
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Site audit failed: {e}")
return {"error": str(e)}
async def safety_filter(self, text: str) -> Dict[str, Any]:
"""
Tool: Flags potentially harmful, offensive, or sensitive content.
@@ -290,8 +481,8 @@ class LinkGraphAgent(SIFBaseAgent):
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
super().__init__(intelligence_service, user_id, agent_type="link_graph")
self.sif_service = sif_service
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
@@ -823,9 +1014,10 @@ class ContentStrategyAgent(BaseALwrityAgent):
Maintain the original meaning and tone.
"""
if hasattr(self.llm, "generate"):
if self.llm:
# We assume the LLM returns JSON-like text or we parse it
response = self.llm.generate(f"{system_prompt}\n\nText to rewrite:\n{content}")
response = await self._generate_llm_response(f"{system_prompt}\n\nText to rewrite:\n{content}")
# Simple parsing fallback if LLM returns raw text
if isinstance(response, str) and not response.strip().startswith("{"):
optimized_content = response
@@ -1456,34 +1648,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
"timestamp": datetime.utcnow().isoformat()
}
async def _generate_llm_response(self, prompt: str) -> str:
"""Helper to generate text using the agent's LLM"""
if not self.llm:
return "[LLM Unavailable]"
try:
# Run in executor to avoid blocking if LLM is synchronous
loop = asyncio.get_event_loop()
# Check if LLM is a txtai pipeline (callable) or has generate method
if hasattr(self.llm, "generate"):
# Some txtai pipelines use generate, some are just called
response = await loop.run_in_executor(None, lambda: self.llm.generate(prompt))
else:
# Assume callable (standard txtai pipeline)
response = await loop.run_in_executor(None, lambda: self.llm(prompt))
# Handle list output (some models return list of dicts)
if isinstance(response, list):
if response and isinstance(response[0], dict) and 'generated_text' in response[0]:
return response[0]['generated_text']
return str(response[0])
return str(response)
except Exception as e:
logger.error(f"LLM generation failed: {e}")
return "[Generation Failed]"
async def _strategy_generator_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""SEO strategy generation tool"""
audit_results = context.get("audit_results", {})
@@ -1629,8 +1794,8 @@ class SocialAmplificationAgent(BaseALwrityAgent):
Return ONLY the adapted content.
"""
if hasattr(self.llm, "generate"):
adapted_content = self.llm.generate(prompt)
if self.llm:
adapted_content = await self._generate_llm_response(prompt)
else:
adapted_content = f"[Mock {platform}]: {content[:50]}... #adapted"

View File

@@ -19,7 +19,7 @@ class TrendSurferAgent(SIFBaseAgent):
"""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
super().__init__(intelligence_service)
super().__init__(intelligence_service, user_id, agent_type="trend_surfer")
self.user_id = user_id
self.signal_detector = MarketSignalDetector(user_id)
self.trends_service = GoogleTrendsService()
@@ -148,15 +148,41 @@ class TrendSurferAgent(SIFBaseAgent):
else:
recommendation = "Create new content"
# Use LLM to generate creative angle
headline = f"Trend: {trend.description}"
angle = f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}"
try:
prompt = f"""
Analyze this market trend signal and propose a content angle:
Trend: {trend.description}
Related Topics: {', '.join(trend.related_topics)}
Impact Score: {trend.impact_score}
Recommendation: {recommendation}
Provide a catchy headline and a 1-sentence strategic angle.
Format: Headline | Angle
"""
response = await self._generate_llm_response(prompt)
if response and "|" in response:
parts = response.split('|')
headline = parts[0].strip()
angle = parts[1].strip()
elif response:
angle = response.strip()
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] LLM generation failed for opportunity: {e}")
return {
"trend_id": trend.signal_id,
"topic": trend.description,
"headline": headline,
"source": trend.source,
"urgency": trend.urgency_level.value,
"impact_score": trend.impact_score,
"current_coverage": coverage_score,
"recommendation": recommendation,
"suggested_angle": f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}",
"suggested_angle": angle,
"detected_at": trend.detected_at
}

View File

@@ -5,13 +5,76 @@ Each agent leverages TxtaiIntelligenceService for semantic operations.
"""
import traceback
import json
import asyncio
from typing import List, Dict, Any, Optional
from datetime import datetime
from loguru import logger
from .txtai_service import TxtaiIntelligenceService
from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
from services.llm_providers.main_text_generation import llm_text_gen
class SIFBaseAgent:
def __init__(self, intelligence_service: TxtaiIntelligenceService):
# Optional txtai imports
try:
from txtai.pipeline import Agent, LLM
except ImportError:
Agent = None
LLM = None
class SharedLLMWrapper:
"""Wraps the shared ALwrity LLM service to look like a txtai LLM."""
def __init__(self, user_id: str):
self.user_id = user_id
def generate(self, prompt: str, **kwargs) -> str:
"""Generate text using the shared LLM provider."""
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
# but we could map them if needed.
return llm_text_gen(prompt, user_id=self.user_id)
def __call__(self, prompt: str, **kwargs) -> str:
return self.generate(prompt, **kwargs)
class LocalLLMWrapper:
"""
Lazily loads a local LLM via txtai.
This prevents blocking server startup with heavy model loads.
"""
def __init__(self, model_path: str):
self.model_path = model_path
self._llm = None
@property
def llm(self):
if self._llm is None:
if LLM is None:
raise ImportError("txtai.pipeline.LLM is not available")
logger.info(f"Loading local LLM: {self.model_path}")
self._llm = LLM(path=self.model_path)
return self._llm
def __call__(self, prompt: str, **kwargs) -> str:
return self.llm(prompt, **kwargs)
def generate(self, prompt: str, **kwargs) -> str:
return self.llm(prompt, **kwargs)
class SIFBaseAgent(BaseALwrityAgent):
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
# Hybrid LLM Strategy:
# 1. Shared LLM for external/high-quality generation (available to all agents)
self.shared_llm = SharedLLMWrapper(user_id)
# 2. Local LLM for internal agent work (default for SIF agents)
if llm is None:
if TXTAI_AVAILABLE:
# Use Lazy Local LLM
llm = LocalLLMWrapper(model_name)
else:
# Fallback to Shared if txtai not available
llm = self.shared_llm
super().__init__(user_id, agent_type, model_name, llm)
self.intelligence = intelligence_service
def _log_agent_operation(self, operation: str, **kwargs):
@@ -20,9 +83,23 @@ class SIFBaseAgent:
if kwargs:
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
def _create_txtai_agent(self):
"""
SIF agents use the intelligence service directly, but we can expose
capabilities via a standard agent interface if needed.
"""
if not TXTAI_AVAILABLE:
return None
# Return a simple agent that can use the LLM
return Agent(llm=self.llm, tools=[])
class StrategyArchitectAgent(SIFBaseAgent):
"""Agent for discovering content pillars and identifying strategic gaps."""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
super().__init__(intelligence_service, user_id, agent_type="strategy_architect")
async def discover_pillars(self) -> List[Dict[str, Any]]:
"""Identify content pillars through semantic clustering."""
self._log_agent_operation("Discovering content pillars")
@@ -58,6 +135,61 @@ class StrategyArchitectAgent(SIFBaseAgent):
logger.error(f"[{self.__class__.__name__}] Failed to discover pillars: {e}")
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
return []
async def analyze_content_strategy(self, website_data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Analyze content strategy based on website data and semantic insights.
Args:
website_data: Dictionary containing website analysis data
Returns:
List of strategic recommendations
"""
self._log_agent_operation("Analyzing content strategy")
try:
recommendations = []
# 1. Discover existing pillars
pillars = await self.discover_pillars()
# 2. Analyze gaps based on pillars (simplified logic for now)
if not pillars:
recommendations.append({
"type": "strategy_gap",
"priority": "high",
"title": "Establish Core Content Pillars",
"description": "No clear content clusters found. Focus on defining 3-5 core topics to build authority."
})
else:
# Suggest strengthening weak pillars
for pillar in pillars:
if pillar['size'] < 3:
recommendations.append({
"type": "content_depth",
"priority": "medium",
"title": f"Strengthen Pillar {pillar['pillar_id']}",
"description": "This topic cluster has few articles. Create more content to establish authority.",
"pillar_id": pillar['pillar_id']
})
# 3. Add generic recommendations based on website data if available
if website_data:
if not website_data.get('description'):
recommendations.append({
"type": "metadata",
"priority": "high",
"title": "Missing Meta Description",
"description": "Website is missing a meta description. Add one to improve SEO CTR."
})
logger.info(f"[{self.__class__.__name__}] Generated {len(recommendations)} strategic recommendations")
return recommendations
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Failed to analyze content strategy: {e}")
return []
def _calculate_cluster_confidence(self, cluster_indices: List[int]) -> float:
"""Calculate confidence score for a cluster based on its size and coherence."""
@@ -92,10 +224,40 @@ class ContentGuardianAgent(SIFBaseAgent):
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
super().__init__(intelligence_service, user_id, agent_type="content_guardian")
self.sif_service = sif_service
async def assess_content_quality(self, website_data: Dict[str, Any]) -> Dict[str, Any]:
"""Assess overall content quality based on website data."""
self._log_agent_operation("Assessing content quality")
try:
# Extract sample text or description from website_data
text_to_analyze = website_data.get('description', '') or website_data.get('title', '')
if not text_to_analyze:
return {"score": 0.5, "reason": "No content to analyze"}
# Run style check
style_result = await self.style_enforcer(text_to_analyze)
# Run safety check
safety_result = await self.safety_filter(text_to_analyze)
# Calculate aggregate score
base_score = style_result.get('compliance_score', 0.8)
if safety_result.get('action') == 'flag_for_review':
base_score *= 0.5
return {
"score": base_score,
"style_analysis": style_result,
"safety_analysis": safety_result,
"analyzed_text_length": len(text_to_analyze)
}
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Quality assessment failed: {e}")
return {"score": 0.0, "error": str(e)}
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
"""Check if a new draft competes semantically with existing pages."""
self._log_agent_operation("Checking for semantic cannibalization", draft_length=len(new_draft))
@@ -274,8 +436,8 @@ class LinkGraphAgent(SIFBaseAgent):
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
super().__init__(intelligence_service, user_id, agent_type="link_graph")
self.sif_service = sif_service
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
@@ -479,6 +641,9 @@ class CitationExpert(SIFBaseAgent):
EVIDENCE_THRESHOLD = 0.7 # Minimum relevance score for evidence
MAX_EVIDENCE = 5 # Maximum number of evidence pieces to return
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
super().__init__(intelligence_service, user_id, agent_type="citation_expert")
async def fact_checker(self, claim: str) -> List[Dict[str, Any]]:
"""
Tool: Verifies facts against trusted research data.
@@ -542,60 +707,25 @@ class CitationExpert(SIFBaseAgent):
"claim": claim,
"status": status,
"evidence_count": len(evidence),
"top_evidence": evidence[0]['source'] if evidence else None
"top_evidence": evidence[0] if evidence else None
})
return {
"status": "verification_complete",
"total_claims": len(claims),
"status": "completed",
"verified_claims": verified_results,
"unsupported_count": len([c for c in verified_results if c['status'] == 'unsupported']),
"timestamp": datetime.utcnow().isoformat()
"verification_score": len([c for c in verified_results if c['status'] == 'supported']) / len(verified_results)
}
async def verify_facts(self, claim: str) -> List[Dict[str, Any]]:
"""Find supporting or contradicting evidence in the indexed research."""
self._log_agent_operation("Verifying facts", claim_length=len(claim))
"""Verify a single claim against intelligence data."""
results = await self.intelligence.search(claim, limit=3)
try:
if not self.intelligence.is_initialized():
logger.error(f"[{self.__class__.__name__}] Intelligence service not initialized")
return []
if not claim or len(claim.strip()) < 20:
logger.warning(f"[{self.__class__.__name__}] Claim too short for meaningful verification")
return []
results = await self.intelligence.search(claim, limit=self.MAX_EVIDENCE)
if not results:
logger.info(f"[{self.__class__.__name__}] No evidence found for claim")
return []
evidence = []
for result in results:
relevance_score = result.get('score', 0.0)
if relevance_score >= self.EVIDENCE_THRESHOLD:
evidence_piece = {
"source": result.get('id', 'unknown'),
"relevance": relevance_score,
"confidence": self._calculate_evidence_confidence(relevance_score),
"type": "supporting" if relevance_score > 0.8 else "related",
"excerpt": result.get('text', '')[:200] + "..." if len(result.get('text', '')) > 200 else result.get('text', '')
}
evidence.append(evidence_piece)
logger.debug(f"[{self.__class__.__name__}] Found evidence: {evidence_piece['source']} (score: {relevance_score:.3f})")
logger.info(f"[{self.__class__.__name__}] Found {len(evidence)} pieces of evidence for claim")
return evidence
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Failed to verify facts: {e}")
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
return []
def _calculate_evidence_confidence(self, relevance_score: float) -> float:
"""Calculate confidence score for evidence."""
# Simple confidence based on relevance score
return min(1.0, relevance_score * 1.2)
evidence = []
for result in results:
if result.get('score', 0) > self.EVIDENCE_THRESHOLD:
evidence.append({
"text": result.get('text'),
"source": result.get('id'),
"confidence": result.get('score')
})
return evidence

View File

@@ -938,14 +938,14 @@ class SIFIntegrationService:
# Strategic recommendations (lazy initialization to avoid circular imports)
if not self.strategy_agent:
from .sif_agents import StrategyArchitectAgent
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service)
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service, user_id=self.user_id)
recommendations = await self.strategy_agent.analyze_content_strategy(website_data)
insights["strategic_recommendations"] = recommendations
# Content quality assessment (lazy initialization to avoid circular imports)
if not self.guardian_agent:
from .sif_agents import ContentGuardianAgent
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, sif_service=self)
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, user_id=self.user_id, sif_service=self)
quality_score = await self.guardian_agent.assess_content_quality(website_data)
insights["content_quality"] = quality_score

View File

@@ -33,7 +33,13 @@ class TxtaiIntelligenceService:
self._initialized = False
self.enable_caching = enable_caching
self.cache_manager = semantic_cache_manager if enable_caching else None
self._initialize_embeddings()
# Lazy initialization - do not initialize embeddings on startup
# self._initialize_embeddings()
def _ensure_initialized(self):
"""Lazy initialization helper."""
if not self._initialized:
self._initialize_embeddings()
def _initialize_embeddings(self):
"""Initialize txtai embeddings with local storage support and comprehensive error handling."""
@@ -106,6 +112,7 @@ class TxtaiIntelligenceService:
Args:
items: List of (id, text, metadata) tuples.
"""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot index content - service not initialized for user {self.user_id}")
return
@@ -145,6 +152,7 @@ class TxtaiIntelligenceService:
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""Perform semantic search with intelligent caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot perform search - service not initialized for user {self.user_id}")
return []
@@ -186,6 +194,7 @@ class TxtaiIntelligenceService:
async def get_similarity(self, text1: str, text2: str) -> float:
"""Get semantic similarity between two texts with caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot calculate similarity - service not initialized for user {self.user_id}")
return 0.0
@@ -234,6 +243,7 @@ class TxtaiIntelligenceService:
async def cluster(self, min_score: float = 0.5) -> List[List[int]]:
"""Cluster indexed content to find semantic pillars using graph-based clustering with caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot cluster content - service not initialized for user {self.user_id}")
return []
@@ -358,6 +368,7 @@ class TxtaiIntelligenceService:
async def classify(self, text: str, labels: List[str]) -> List[Tuple[str, float]]:
"""Classify text using zero-shot classification."""
self._ensure_initialized()
if not self._initialized or not Labels:
logger.error(f"Cannot classify text - service not initialized or Labels not available for user {self.user_id}")
return []

View File

@@ -297,7 +297,7 @@ def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
return _convert(schema)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None, user_id: str = None):
"""
Generate structured JSON response using Google's Gemini Pro model.
@@ -312,6 +312,7 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
top_k (int): Top-k sampling parameter
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
system_prompt (str, optional): System instruction for the model
user_id (str, optional): User ID for usage tracking.
Returns:
dict: Parsed JSON response matching the provided schema
@@ -468,6 +469,25 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
if response.parsed is not None:
logger.info("Using response.parsed for structured output")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(response.parsed)
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return response.parsed
else:
logger.warning("Response.parsed is None, falling back to text parsing")
@@ -500,6 +520,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
parsed_text = json.loads(cleaned_text)
logger.info("Successfully parsed text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=cleaned_text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse text as JSON: {e}")
@@ -521,6 +557,26 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
fixed_json = re.sub(r',\s*]', ']', fixed_json)
parsed_text = json.loads(fixed_json)
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(parsed_text) if parsed_text else ""
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5 # Approximation
)
logger.info(f"✅ Tracked structured JSON usage for user {user_id}")
except Exception as e:
logger.error(f"Failed to track usage: {e}")
logger.info("Successfully parsed cleaned JSON")
return parsed_text
except Exception as fix_error:
@@ -537,6 +593,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
import json
parsed_text = json.loads(part.text)
logger.info("Successfully parsed candidate text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=part.text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse candidate text as JSON: {e}")

View File

@@ -4,6 +4,7 @@ import io
import os
from typing import Optional
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
@@ -14,7 +15,10 @@ logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.
Implements robust error handling and retries for production stability.
"""
SUPPORTED_MODELS = {
"ideogram-v3-turbo": {
@@ -54,6 +58,28 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
list(self.SUPPORTED_MODELS.keys()))
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10),
retry=retry_if_exception_type((RuntimeError, IOError)),
reraise=True
)
def _call_api_with_retry(self, method, **kwargs):
"""Execute API call with retry logic.
Args:
method: Callable API method
**kwargs: Arguments for the method
Returns:
API response
"""
try:
return method(**kwargs)
except Exception as e:
logger.warning(f"WaveSpeed API call failed (retrying): {str(e)}")
raise
def _validate_options(self, options: ImageGenerationOptions) -> None:
"""Validate generation options.
@@ -117,7 +143,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
# Call WaveSpeed API (using generic image generation method)
# This will need to be adjusted based on actual WaveSpeed client implementation
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
# Adjust based on actual WaveSpeed API response format
@@ -167,7 +193,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):
@@ -216,7 +242,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):

View File

@@ -107,11 +107,13 @@ def generate_audio(
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
try:
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import UsageSummary, APIProvider
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
@@ -194,7 +196,11 @@ def generate_audio(
if audio_bytes:
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[audio_gen] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -383,12 +389,14 @@ def clone_voice(
voice_clone_cost = 0.5
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -432,7 +440,11 @@ def clone_voice(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[clone_voice] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -570,12 +582,14 @@ def qwen3_voice_clone(
char_count = len(text)
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -615,7 +629,11 @@ def qwen3_voice_clone(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[qwen3_voice_clone] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -691,6 +709,7 @@ def qwen3_voice_clone(
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
├─ Calls: {current_calls_before}{new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
@@ -724,3 +743,373 @@ def qwen3_voice_clone(
},
)
def qwen3_voice_design(
text: str,
voice_description: str,
*,
language: str = "auto",
user_id: Optional[str] = None,
) -> VoiceCloneResult:
try:
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
if not text or not isinstance(text, str) or len(text.strip()) == 0:
raise ValueError("Text is required and cannot be empty")
text = text.strip()
if not voice_description or not isinstance(voice_description, str) or len(voice_description.strip()) == 0:
raise ValueError("Voice description is required")
voice_description = voice_description.strip()
char_count = len(text)
# Pricing logic similar to TTS/Clone
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.AUDIO,
tokens_requested=char_count,
actual_provider_name="wavespeed",
)
if not can_proceed:
raise HTTPException(
status_code=429,
detail={
"error": message,
"message": message,
"provider": "wavespeed",
"usage_info": usage_info if usage_info else {},
},
)
finally:
db.close()
except HTTPException:
raise
except Exception as sub_error:
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
import time
start_time = time.time()
client = WaveSpeedClient()
preview_audio_bytes = client.voice_design(
text=text,
voice_description=voice_description,
language=language
)
response_time = time.time() - start_time
# Track usage
try:
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[qwen3_voice_design] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
from sqlalchemy import text as sql_text
from services.subscription.provider_detection import detect_actual_provider
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(user_id=user_id, billing_period=current_period)
db_track.add(summary)
db_track.flush()
current_calls_before = getattr(summary, "audio_calls", 0) or 0
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
new_calls = current_calls_before + 1
new_cost = current_cost_before + float(estimated_cost)
update_query = sql_text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
"new_calls": new_calls,
"new_cost": new_cost,
"user_id": user_id,
"period": current_period
})
summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
actual_provider = detect_actual_provider(
provider_enum=APIProvider.AUDIO,
model_name="wavespeed-ai/qwen3-tts/voice-design",
endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
)
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.AUDIO,
endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
method="POST",
model_used="wavespeed-ai/qwen3-tts/voice-design",
actual_provider_name=actual_provider,
tokens_input=char_count,
tokens_output=0,
tokens_total=char_count,
cost_input=0.0,
cost_output=0.0,
cost_total=float(estimated_cost),
response_time=response_time,
status_code=200,
request_size=len(text) + len(voice_description),
response_size=len(preview_audio_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
db_track.commit()
print(f"""
[SUBSCRIPTION] Qwen3 Voice Design
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-design
├─ Calls: {current_calls_before}{new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[qwen3_voice_design] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[qwen3_voice_design] ❌ Failed to track usage: {usage_error}", exc_info=True)
return VoiceCloneResult(
preview_audio_bytes=preview_audio_bytes,
provider="wavespeed",
model="wavespeed-ai/qwen3-tts/voice-design",
custom_voice_id="", # No persistent ID for design usually, unless we save it
file_size=len(preview_audio_bytes),
)
except HTTPException:
raise
except RuntimeError:
raise
except Exception as e:
logger.error(f"[qwen3_voice_design] Error designing voice: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": "Qwen3 voice design failed",
"message": str(e),
},
)
def cosyvoice_voice_clone(
audio_bytes: bytes,
text: str,
*,
reference_text: Optional[str] = None,
audio_mime_type: Optional[str] = None,
user_id: Optional[str] = None,
) -> VoiceCloneResult:
try:
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
if not audio_bytes or not isinstance(audio_bytes, (bytes, bytearray)) or len(audio_bytes) == 0:
raise ValueError("Audio is required and cannot be empty")
if len(audio_bytes) > 15 * 1024 * 1024:
raise ValueError("Audio file too large. Maximum is 15MB.")
if not text or not isinstance(text, str) or len(text.strip()) == 0:
raise ValueError("Text is required and cannot be empty")
text = text.strip()
if len(text) > 4000:
raise ValueError("Text too long. Please keep it under 4000 characters.")
char_count = len(text)
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.AUDIO,
tokens_requested=char_count,
actual_provider_name="wavespeed",
)
if not can_proceed:
raise HTTPException(
status_code=429,
detail={
"error": message,
"message": message,
"provider": "wavespeed",
"usage_info": usage_info if usage_info else {},
},
)
finally:
db.close()
except HTTPException:
raise
except Exception as sub_error:
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
import time
start_time = time.time()
client = WaveSpeedClient()
preview_audio_bytes = client.cosyvoice_voice_clone(
audio_bytes=bytes(audio_bytes),
text=text,
audio_mime_type=audio_mime_type or "audio/wav",
reference_text=reference_text,
)
response_time = time.time() - start_time
if preview_audio_bytes:
try:
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[cosyvoice_voice_clone] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
from sqlalchemy import text as sql_text
from services.subscription.provider_detection import detect_actual_provider
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(user_id=user_id, billing_period=current_period)
db_track.add(summary)
db_track.flush()
current_calls_before = getattr(summary, "audio_calls", 0) or 0
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
new_calls = current_calls_before + 1
new_cost = current_cost_before + float(estimated_cost)
update_query = sql_text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
"new_calls": new_calls,
"new_cost": new_cost,
"user_id": user_id,
"period": current_period
})
summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
actual_provider = detect_actual_provider(
provider_enum=APIProvider.AUDIO,
model_name="wavespeed-ai/cosyvoice-tts/voice-clone",
endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
)
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.AUDIO,
endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
method="POST",
model_used="wavespeed-ai/cosyvoice-tts/voice-clone",
actual_provider_name=actual_provider,
tokens_input=char_count,
tokens_output=0,
tokens_total=char_count,
cost_input=0.0,
cost_output=0.0,
cost_total=float(estimated_cost),
response_time=response_time,
status_code=200,
request_size=len(audio_bytes) + len(text.encode("utf-8")),
response_size=len(preview_audio_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
db_track.commit()
print(f"""
[SUBSCRIPTION] CosyVoice Voice Clone
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/cosyvoice-tts/voice-clone
├─ Calls: {current_calls_before}{new_calls}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[cosyvoice_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[cosyvoice_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)
return VoiceCloneResult(
preview_audio_bytes=preview_audio_bytes,
provider="wavespeed",
model="wavespeed-ai/cosyvoice-tts/voice-clone",
custom_voice_id="",
file_size=len(preview_audio_bytes),
)
except HTTPException:
raise
except RuntimeError:
raise
except Exception as e:
logger.error(f"[cosyvoice_voice_clone] Error cloning voice: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": "CosyVoice voice cloning failed",
"message": str(e),
},
)

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
import os
import io
import base64
import logging
from typing import Optional, Dict, Any
from PIL import Image
@@ -9,6 +11,9 @@ from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from .image_generation.base import ImageEditOptions
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from utils.logger_utils import get_service_logger
try:
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
"HF_IMAGE_EDIT_MODEL",
"Qwen/Qwen-Image-Edit",
"WAVESPEED_IMAGE_EDIT_MODEL",
"qwen-edit-plus",
)
def _select_provider(explicit: Optional[str]) -> str:
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
"""
Select the appropriate image editing provider.
Priority:
1. Explicitly requested provider
2. WaveSpeed (if API key available) - Preferred for quality/speed
3. Hugging Face (fallback)
"""
if explicit:
return explicit
# Default to huggingface for image editing (best support for image-to-image)
return explicit.lower()
# Check for WaveSpeed API key first (Preferred provider)
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Default to huggingface if WaveSpeed not available
return "huggingface"
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
"""Get InferenceClient for the specified provider."""
"""Get the client for the specified provider."""
if provider_name == "wavespeed":
return WaveSpeedEditProvider(api_key=api_key)
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
# Use fal-ai provider for fast inference
# Use fal-ai provider for fast inference via HF Inference API
return InferenceClient(provider="fal-ai", api_key=api_key)
raise ValueError(f"Unknown image editing provider: {provider_name}")
@@ -86,6 +106,8 @@ def edit_image(
from fastapi import HTTPException
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
# Note: get_db() is a generator, so we need to use next() to get the session
# and ensure we close it in the finally block
db = next(get_db())
try:
pricing_service = PricingService(db)
@@ -99,6 +121,9 @@ def edit_image(
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
except Exception as e:
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
finally:
db.close()
else:
@@ -119,6 +144,69 @@ def edit_image(
# Get provider client
client = _get_provider_client(provider_name, opts.get("api_key"))
if provider_name == "wavespeed":
# Handle WaveSpeed provider
try:
# Convert inputs to base64 for WaveSpeed
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
mask_b64 = None
if mask_bytes:
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
# Determine operation type based on prompt/mask
operation = "general_edit" # Default
if not prompt and mask_b64:
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
elif prompt and not mask_b64:
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
elif opts.get("operation"):
operation = opts.get("operation")
edit_options = ImageEditOptions(
image_base64=image_b64,
prompt=prompt.strip(),
operation=operation,
mask_base64=mask_b64,
model=model,
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
extra=opts
)
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
result = client.edit(edit_options)
# TRACK USAGE after successful WaveSpeed call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (WaveSpeed default: $0.02)
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model=result.model or model,
operation_type="image-editing",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata=result.metadata,
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return result
except Exception as e:
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
# Hugging Face (Fallback)
# Prepare parameters for image-to-image
params: Dict[str, Any] = {}
if opts.get("guidance_scale") is not None:
@@ -170,6 +258,29 @@ def edit_image(
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
# TRACK USAGE after successful HF call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (HF/Fal-ai default: $0.05)
estimated_cost = 0.05
_track_image_operation_usage(
user_id=user_id,
provider="huggingface",
model=model,
operation_type="image-editing",
result_bytes=edited_image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata={"provider": "fal-ai"},
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return ImageGenerationResult(
image_bytes=edited_image_bytes,
width=edited_image.width,

View File

@@ -5,6 +5,7 @@ import sys
import base64
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from .image_generation import (
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str:
if explicit:
return explicit
# User requested WaveSpeed as default provider
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Fallback to huggingface to enable a path if configured
return "huggingface"
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
}
except Exception as e:
logger.error(f"Error in generate_image_with_provider: {e}")
# Propagate specific error message if available
error_detail = str(e)
if "402" in error_detail or "Payment Required" in error_detail:
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
return {
"success": False,
"error": str(e)
"error": error_detail
}
import time
from services.database import get_session_for_user
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
"""
Enhance image prompt using LLM.
Placeholder implementation.
Enhance image prompt using WaveSpeed's specialized prompt optimizer.
Restructures and enriches prompts for visual clarity and cinematic detail.
Uses Step 2 (Website Analysis) and Step 3 (Competitor Analysis) context if available.
"""
return prompt
start_time = time.time()
try:
from services.wavespeed.client import WaveSpeedClient
# 1. Pre-flight Validation
if user_id:
_validate_image_operation(
user_id=user_id,
operation_type="prompt-enhancement",
num_operations=1,
log_prefix="[Prompt Enhancement]"
)
# 2. Fetch Context from Step 2 & 3
context_instruction = ""
if user_id:
try:
db_session = get_session_for_user(user_id)
try:
# Get Onboarding Session
session = db_session.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
if session:
# Step 2: Website Analysis
website_analysis = db_session.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session.id
).first()
if website_analysis:
# Handle potential JSON or dict types
brand_voice = website_analysis.brand_analysis
style = website_analysis.style_guidelines
target_audience = website_analysis.target_audience
context_instruction += "\n\nCONTEXT FROM WEBSITE ANALYSIS:\n"
if target_audience:
context_instruction += f"Target Audience: {target_audience}\n"
if brand_voice and isinstance(brand_voice, dict):
context_instruction += f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"
if style and isinstance(style, dict):
context_instruction += f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"
# Step 3: Competitor Analysis (Limit to top 3)
competitors = db_session.query(CompetitorAnalysis).filter(
CompetitorAnalysis.session_id == session.id
).limit(3).all()
if competitors:
context_instruction += "\nCOMPETITOR VISUAL INSIGHTS:\n"
for comp in competitors:
if comp.analysis_data and isinstance(comp.analysis_data, dict):
comp_title = comp.analysis_data.get('title', 'Competitor')
# Try to extract visual/content insights if available
highlights = comp.analysis_data.get('highlights', [])
if highlights:
context_instruction += f"- {comp_title}: {', '.join(highlights[:2])}\n"
finally:
db_session.close()
except Exception as db_ex:
logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")
# Combine prompt with context
full_input_text = prompt
if context_instruction:
logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
# We append context as instruction for the optimizer
full_input_text = f"Original Request: {prompt}\n\n{context_instruction}\n\nTask: Generate a hyper-personalized, detailed image generation prompt based on the Original Request and the provided Context. Ensure the visual style aligns with the Brand Voice and Visual Style."
else:
logger.info(f"Enhancing prompt for user {user_id} (no context found)")
# 3. Call WaveSpeed
client = WaveSpeedClient()
# Use 'image' mode for avatar/image generation workflows
# Use 'photographic' style as requested for avatars
optimized_prompt = client.optimize_prompt(
text=full_input_text,
mode="image",
style="photographic",
enable_sync_mode=True,
timeout=30
)
# 4. Track Usage
if user_id:
duration = time.time() - start_time
# Track as 0 cost for now unless we have specific pricing for prompt opt
# But we track it as an operation
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model="wavespeed-prompt-opt",
operation_type="prompt-enhancement",
result_bytes=b"", # No image
cost=0.0,
prompt=prompt,
endpoint="/enhance-prompt",
metadata={"duration": duration, "context_added": bool(context_instruction)},
log_prefix="[Prompt Enhancement]",
response_time=duration
)
return optimized_prompt
except Exception as e:
logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
# Fallback to original prompt on failure
return prompt
async def generate_image_variation(
@@ -760,13 +886,123 @@ async def generate_image_variation(
**kwargs
) -> Dict[str, Any]:
"""
Generate variation of an existing image.
Placeholder implementation.
Generate variation of an existing image using image-to-image editing.
Wrapper for step4_asset_routes.
"""
return {
"success": False,
"error": "Not implemented yet"
}
try:
# Handle image input (bytes, file, or base64)
image_bytes = None
if isinstance(image, bytes):
image_bytes = image
elif hasattr(image, "read"):
image_bytes = await image.read()
elif isinstance(image, str):
# Assume base64 or path
if os.path.exists(image):
with open(image, "rb") as f:
image_bytes = f.read()
else:
# Try base64 decode
try:
if "base64," in image:
image = image.split("base64,")[1]
image_bytes = base64.b64decode(image)
except:
pass
if not image_bytes:
return {"success": False, "error": "Invalid image input"}
# Convert to base64 for internal function
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
# Use generate_image_edit with "variation" intent
# For variation, we typically use general_edit with specific prompt
result = await run_in_threadpool(
generate_image_edit,
image_base64=image_base64,
prompt=prompt,
operation="general_edit",
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
options=kwargs,
user_id=user_id
)
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
return {
"success": True,
"image_base64": result_base64,
"metadata": result.metadata
}
except Exception as e:
logger.error(f"Error in generate_image_variation: {e}")
return {
"success": False,
"error": str(e)
}
async def generate_image_enhance(
image: Any,
user_id: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""
Enhance/Upscale an existing image.
Wrapper for step4_asset_routes.
"""
try:
# Handle image input
image_bytes = None
if isinstance(image, bytes):
image_bytes = image
elif hasattr(image, "read"):
image_bytes = await image.read()
elif isinstance(image, str):
if os.path.exists(image):
with open(image, "rb") as f:
image_bytes = f.read()
else:
try:
if "base64," in image:
image = image.split("base64,")[1]
image_bytes = base64.b64decode(image)
except:
pass
if not image_bytes:
return {"success": False, "error": "Invalid image input"}
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
# Use generate_image_edit with "enhance" intent
# Use high-res model like nano-banana-pro-edit-ultra
result = await run_in_threadpool(
generate_image_edit,
image_base64=image_base64,
prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
operation="general_edit",
model="nano-banana-pro-edit-ultra",
options={**kwargs, "resolution": "4k"},
user_id=user_id
)
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
return {
"success": True,
"image_base64": result_base64,
"metadata": result.metadata
}
except Exception as e:
logger.error(f"Error in generate_image_enhance: {e}")
return {
"success": False,
"error": str(e)
}

View File

@@ -260,335 +260,23 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
# Using raw SQL UPDATE ensures the change is persisted
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens (after any safety capping)
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this call
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
cost_total = 0.0
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
if cost_total > 0:
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
# Keep ORM object in sync for logging/debugging
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
else:
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
# Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
new_total_calls = old_total_calls + 1
new_total_tokens = old_total_tokens + tokens_total
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
# Ideally we should track start_time at beginning of function
duration = 0.5
track_agent_usage_sync(
user_id=user_id,
model_name=model,
prompt=prompt,
response_text=response_text,
duration=duration
)
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
@@ -661,208 +349,18 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log (from raw SQL query)
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens after any safety capping
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this fallback call
cost_total = 0.0
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=fallback_model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
if cost_total > 0:
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
# Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log (limits already retrieved above)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync(
user_id=user_id,
model_name=fallback_model,
prompt=prompt,
response_text=response_text,
duration=0.5 # Approximate duration
)
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)

View File

@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
pass
def _track_video_operation_usage(
user_id: str,
provider: str,
model: str,
operation_type: str,
result_bytes: bytes,
cost: float,
prompt: Optional[str] = None,
endpoint: str = "/video-generation",
metadata: Optional[Dict[str, Any]] = None,
log_prefix: str = "[Video Generation]",
response_time: float = 0.0
) -> Dict[str, Any]:
"""
Reusable usage tracking helper for all video operations.
Args:
user_id: User ID for tracking
provider: Provider name
model: Model name used
operation_type: Type of operation (for logging)
result_bytes: Generated video bytes
cost: Cost of the operation
prompt: Optional prompt text
endpoint: API endpoint path
metadata: Optional additional metadata
log_prefix: Logging prefix
response_time: API response time
Returns:
Dictionary with tracking information
"""
try:
from services.database import get_session_for_user
db_track = get_session_for_user(user_id)
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get current values before update
current_calls_before = getattr(summary, "video_calls", 0) or 0
current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0
# Update video calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET video_calls = :new_calls,
video_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
request_size = len(prompt.encode("utf-8")) if prompt else 0
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.WAVESPEED, # Default for video
endpoint=endpoint,
method="POST",
model_used=model or "unknown",
actual_provider_name=provider,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost,
response_time=response_time,
status_code=200,
request_size=request_size,
response_size=len(result_bytes) if result_bytes else 0,
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
# Get limits for display
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
video_limit_display = video_limit if (video_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
image_edit_limit_display = image_edit_limit if (image_edit_limit > 0 or tier != 'enterprise') else ''
db_track.commit()
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
# UNIFIED SUBSCRIPTION LOG
operation_name = operation_type.replace("-", " ").title()
print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {video_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit_display}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
return {
"current_calls": new_calls,
"cost": cost,
"total_cost": new_cost,
}
except Exception as track_error:
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
db_track.rollback()
return {}
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
return {}
def _get_api_key(provider: str) -> Optional[str]:
try:
manager = APIKeyManager()
@@ -500,156 +666,74 @@ async def ai_video_generate(
raise
finally:
db.close()
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
# Progress callback: Initial submission
if progress_callback:
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
# Generate video based on operation type
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
# Track response time for video generation
# Track response time
import time
from datetime import datetime
start_time = time.time()
# Execute operation based on type
result = {}
try:
if operation_type == "text-to-video":
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
**kwargs,
)
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
result_dict = {
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
result = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10, # Default cost, will be calculated in track_video_usage
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280, # Default, actual may vary
"height": 720, # Default, actual may vary
"metadata": {},
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
"provider": "huggingface",
"cost": 0.0, # HuggingFace inference is free/low cost
}
elif provider == "wavespeed":
# WaveSpeed text-to-video - use unified service
result_dict = await _generate_text_to_video_wavespeed(
result = await _generate_text_to_video_wavespeed(
prompt=prompt,
progress_callback=progress_callback,
**kwargs,
**kwargs
)
elif provider == "gemini":
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
elif provider == "openai":
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
else:
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
raise ValueError(f"Unknown provider for text-to-video: {provider}")
elif operation_type == "image-to-video":
if provider == "wavespeed":
# Progress callback: Starting generation
if progress_callback:
progress_callback(20.0, "Video generation in progress...")
# Handle async call from sync context
# Since ai_video_generate is sync, we need to run async function
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context - use ThreadPoolExecutor to run in new event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run,
_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
)
)
result_dict = future.result()
else:
# Event loop exists but not running - use it
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
except RuntimeError:
# No event loop exists, create a new one
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
video_bytes = result_dict["video_bytes"]
model_name = result_dict.get("model_name", model_name)
# Progress callback: Processing result
if progress_callback:
progress_callback(90.0, "Processing video result...")
result = await _generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or "",
progress_callback=progress_callback,
**kwargs
)
else:
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
raise ValueError(f"Unknown provider for image-to-video: {provider}")
# Track usage (same pattern as text generation)
# Use cost from result_dict if available, otherwise calculate
response_time = time.time() - start_time
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=result_dict.get("prompt", prompt or ""),
video_bytes=video_bytes,
cost_override=cost_override,
response_time=response_time,
)
# Progress callback: Complete
if progress_callback:
progress_callback(100.0, "Video generation complete!")
return result_dict
except HTTPException:
# Re-raise HTTPExceptions (e.g., from validation or API errors)
raise
# TRACK USAGE after successful API call
video_bytes = result.get("video_bytes")
if user_id and video_bytes:
_track_video_operation_usage(
user_id=user_id,
provider=result.get("provider", provider),
model=result.get("model_name", kwargs.get("model", "unknown")),
operation_type=operation_type,
result_bytes=video_bytes,
cost=result.get("cost", 0.0),
prompt=prompt,
endpoint="/video-generation",
metadata=result.get("metadata"),
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
response_time=response_time
)
return result
except Exception as e:
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": str(e)})
# Log failure but don't track usage (no cost incurred)
logger.error(f"[video_gen] Generation failed: {str(e)}")
raise
def _get_default_model(operation_type: str, provider: str) -> str:

View File

@@ -46,6 +46,9 @@ class CorePersonaService:
# Get schema for structured response
persona_schema = self.prompt_builder.get_persona_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
try:
# Generate structured response using Gemini
response = gemini_structured_json_response(
@@ -53,7 +56,8 @@ class CorePersonaService:
schema=persona_schema,
temperature=0.2, # Low temperature for consistent analysis
max_tokens=8192,
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona."
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona.",
user_id=user_id
)
if "error" in response:
@@ -103,13 +107,17 @@ class CorePersonaService:
# Get platform-specific schema
platform_schema = self.prompt_builder.get_platform_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
try:
response = gemini_structured_json_response(
prompt=prompt,
schema=platform_schema,
temperature=0.2,
max_tokens=4096,
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization."
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization.",
user_id=user_id
)
return response

View File

@@ -62,6 +62,9 @@ class FacebookPersonaService:
# Get Facebook-specific schema
schema = self._get_enhanced_facebook_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
# Generate structured response using Gemini with optimized prompts
response = gemini_structured_json_response(
@@ -69,7 +72,8 @@ class FacebookPersonaService:
schema=schema,
temperature=0.2,
max_tokens=4096,
system_prompt=system_prompt
system_prompt=system_prompt,
user_id=user_id
)
if not response or "error" in response:

View File

@@ -54,13 +54,17 @@ class LinkedInPersonaService:
# Get LinkedIn-specific schema
schema = self.schemas.get_enhanced_linkedin_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
# Generate structured response using Gemini with optimized prompts
response = gemini_structured_json_response(
prompt=prompt,
schema=schema,
temperature=0.2,
max_tokens=4096,
system_prompt=system_prompt
system_prompt=system_prompt,
user_id=user_id
)
if "error" in response:

View File

@@ -56,6 +56,17 @@ async def check_and_execute_due_tasks(scheduler: 'TaskScheduler'):
continue
try:
# Check onboarding status first
# Skip users who haven't completed onboarding to prevent premature agent initialization
from services.onboarding.progress_service import OnboardingProgressService
onboarding_service = OnboardingProgressService()
status = onboarding_service.get_onboarding_status(user_id)
if not status.get("is_completed", False):
# Skip logging for inactive users to reduce noise, unless debugging
# logger.debug(f"[Scheduler Check] Skipping user {user_id} - Onboarding incomplete")
continue
# Check active strategies for this user (for interval adjustment)
try:
from services.active_strategy_service import ActiveStrategyService

View File

@@ -67,6 +67,27 @@ class SIFIndexingExecutor(TaskExecutor):
# 2. Sync User Website Content (Deep Crawl / Snapshot)
content_synced = await sif_service.sync_user_website_content(website_url)
# 3. Trigger Content Guardian Audit (Background Analysis)
# This ensures the agent runs immediately after new data is indexed
guardian_report = None
if content_synced:
try:
from services.intelligence.agents.specialized_agents import ContentGuardianAgent
# Re-use the intelligence service from sif_service
guardian_agent = ContentGuardianAgent(
intelligence_service=sif_service.intelligence_service,
user_id=user_id,
sif_service=sif_service
)
logger.info("Triggering Content Guardian Site Audit...")
guardian_report = await guardian_agent.perform_site_audit(website_url)
# Persist the audit report (optional, or rely on logs/alerts)
# For now, we just include it in the task result
except Exception as e:
logger.error(f"Failed to run Content Guardian audit: {e}")
# Determine overall success
# We consider it a success if at least one operation worked, or if both were attempted without error
# But ideally, content sync is the heavy lifter.
@@ -91,6 +112,7 @@ class SIFIndexingExecutor(TaskExecutor):
task_log.result_data = {
"metadata_synced": metadata_synced,
"content_synced": content_synced,
"guardian_report": guardian_report,
"website_url": website_url
}
task_log.execution_time_ms = int((time.time() - start_time) * 1000)

View File

@@ -29,9 +29,10 @@ def load_due_sif_indexing_tasks(db: Session, user_id: str = None) -> List[SIFInd
query = db.query(SIFIndexingTask).filter(
or_(
SIFIndexingTask.status == "pending",
SIFIndexingTask.status == "active",
SIFIndexingTask.status == "failed" # Retry failed tasks
),
SIFIndexingTask.next_run_at <= datetime.utcnow()
SIFIndexingTask.next_execution <= datetime.utcnow()
)
if user_id:

View File

@@ -199,6 +199,24 @@ class PricingService:
"cost_per_input_token": 0.0, # No additional token cost for grounding
"cost_per_output_token": 0.0, # No additional token cost for grounding
"description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
},
# Alwrity Voice Cloning - Qwen3
{
"provider": APIProvider.AUDIO,
"model_name": "alwrity-ai/qwen3-tts/voice-clone",
"cost_per_request": 0.10,
"cost_per_input_token": 0.00001,
"cost_per_output_token": 0.0,
"description": "Alwrity Qwen3 Voice Clone (Efficient)"
},
# Alwrity Voice Cloning - CosyVoice
{
"provider": APIProvider.AUDIO,
"model_name": "alwrity-ai/cosyvoice/voice-clone",
"cost_per_request": 0.15,
"cost_per_input_token": 0.00001,
"cost_per_output_token": 0.0,
"description": "Alwrity CosyVoice Clone (High Fidelity)"
}
]
@@ -402,11 +420,19 @@ class PricingService:
{
"provider": APIProvider.AUDIO,
"model_name": "wavespeed-ai/qwen3-tts/voice-clone",
"cost_per_request": 0.0,
"cost_per_input_token": 0.0,
"cost_per_request": 0.005,
"cost_per_input_token": 0.00005,
"cost_per_output_token": 0.0,
"description": "Qwen3-TTS Voice Clone via WaveSpeed (cost depends on text length)"
},
{
"provider": APIProvider.AUDIO,
"model_name": "wavespeed-ai/cosyvoice-tts/voice-clone",
"cost_per_request": 0.005,
"cost_per_input_token": 0.00005,
"cost_per_output_token": 0.0,
"description": "CosyVoice-TTS Voice Clone via WaveSpeed (cost depends on text length)"
},
{
"provider": APIProvider.AUDIO,
"model_name": "default",
@@ -429,8 +455,9 @@ class PricingService:
if existing:
# Update existing pricing (especially for HuggingFace if env vars changed)
if pricing_data["provider"] == APIProvider.MISTRAL:
# Update HuggingFace pricing from env vars
if pricing_data["provider"] in [APIProvider.MISTRAL, APIProvider.AUDIO]:
# Update pricing
existing.cost_per_request = pricing_data.get("cost_per_request", 0.0)
existing.cost_per_input_token = pricing_data["cost_per_input_token"]
existing.cost_per_output_token = pricing_data["cost_per_output_token"]
existing.description = pricing_data["description"]

View File

@@ -490,6 +490,32 @@ class UsageTrackingService:
'cost': image_edit_cost
}
# WaveSpeed (aggregated across Video, Audio, Image, Image Edit)
# Query APIUsageLog directly to get accurate WaveSpeed-specific usage
wavespeed_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.billing_period == billing_period,
APIUsageLog.actual_provider_name == "wavespeed"
).all()
if wavespeed_logs:
wavespeed_calls = len(wavespeed_logs)
wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs)
wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs)
provider_breakdown['wavespeed'] = {
'calls': wavespeed_calls,
'tokens': wavespeed_tokens,
'cost': wavespeed_cost
}
logger.info(f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}")
else:
provider_breakdown['wavespeed'] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
# Search APIs
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0

View File

@@ -12,6 +12,7 @@ from loguru import logger
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
from services.video_studio.hunyuan_avatar_adapter import HunyuanAvatarService
from utils.logger_utils import get_service_logger
from services.llm_providers.main_video_generation import _track_video_operation_usage
logger = get_service_logger("video_studio.avatar")
@@ -58,6 +59,30 @@ class AvatarStudioService:
f"[AvatarStudio] Creating talking avatar: user={user_id}, resolution={resolution}, model={model}"
)
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_video_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException:
# Re-raise immediately - don't proceed with API call
logger.error(f"[AvatarStudio] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
import time
start_time = time.time()
try:
if model == "hunyuan-avatar":
# Use Hunyuan Avatar (doesn't support mask_image)
@@ -82,12 +107,32 @@ class AvatarStudioService:
user_id=user_id,
)
response_time = time.time() - start_time
logger.info(
f"[AvatarStudio] ✅ Talking avatar created: "
f"model={model}, resolution={resolution}, duration={result.get('duration', 0)}s, "
f"cost=${result.get('cost', 0):.2f}"
)
# TRACK USAGE after successful API call
# Use video_bytes if available, otherwise check if result itself is bytes (unlikely, dict expected)
video_bytes = result.get("video_bytes")
if user_id and video_bytes:
_track_video_operation_usage(
user_id=user_id,
provider=model, # Use model name as provider/actual_provider for now
model=model,
operation_type="talking-avatar",
result_bytes=video_bytes,
cost=result.get("cost", 0.0),
prompt=prompt,
endpoint="/avatar-generation",
metadata=result,
log_prefix="[Avatar Generation]",
response_time=response_time
)
return result
except HTTPException:

View File

@@ -324,6 +324,39 @@ class WaveSpeedClient:
timeout=timeout,
)
def voice_design(
self,
text: str,
voice_description: str,
language: str = "auto",
timeout: int = 180,
) -> bytes:
return self.speech.voice_design(
text=text,
voice_description=voice_description,
language=language,
timeout=timeout,
)
def cosyvoice_voice_clone(
self,
audio_bytes: bytes,
text: str,
*,
model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
audio_mime_type: str = "audio/wav",
reference_text: Optional[str] = None,
timeout: int = 180,
) -> bytes:
return self.speech.cosyvoice_voice_clone(
audio_bytes=audio_bytes,
text=text,
model=model,
audio_mime_type=audio_mime_type,
reference_text=reference_text,
timeout=timeout,
)
def generate_text_video(
self,
prompt: str,

View File

@@ -146,14 +146,44 @@ class PromptGenerator:
if isinstance(first_output, str):
if first_output.startswith("http://") or first_output.startswith("https://"):
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
url_response = requests.get(first_output, timeout=timeout)
if url_response.status_code == 200:
return url_response.text.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
# Use stream=True to avoid downloading large files into memory
try:
with requests.get(first_output, timeout=timeout, stream=True) as url_response:
if url_response.status_code == 200:
# Check Content-Length if available
content_length = url_response.headers.get("Content-Length")
if content_length and int(content_length) > 1024 * 1024: # 1MB limit for prompts
logger.error(f"[WaveSpeed] Optimized prompt URL content too large: {content_length} bytes")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer returned a file that is too large",
)
# Read content with limit
content = ""
for chunk in url_response.iter_content(chunk_size=8192, decode_unicode=True):
if chunk:
content += chunk
if len(content) > 1024 * 1024: # Hard limit 1MB
logger.error("[WaveSpeed] Optimized prompt URL content exceeded 1MB limit during download")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer returned a file that is too large",
)
return content.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
)
except requests.RequestException as e:
logger.error(f"[WaveSpeed] Error fetching prompt from URL: {str(e)}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
detail=f"Error fetching optimized prompt: {str(e)}",
)
else:
# It's already the text

View File

@@ -181,6 +181,102 @@ class SpeechGenerator:
audio_url = self._extract_audio_url(outputs)
return self._download_audio(audio_url, timeout)
def voice_design(
self,
text: str,
voice_description: str,
language: str = "auto",
timeout: int = 180,
) -> bytes:
"""
Generate speech using Qwen3 Voice Design (text + voice description).
"""
url = f"{self.base_url}/wavespeed-ai/qwen3-tts/voice-design"
payload = {
"text": text,
"voice_description": voice_description,
"language": language
}
logger.info(f"[WaveSpeed] Voice design via {url}")
try:
response = requests.post(
url,
headers=self._get_headers(),
json=payload,
timeout=(30, 90),
)
except requests_exceptions.Timeout as e:
raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design timed out", "message": str(e)})
except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as e:
raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design connection failed", "message": str(e)})
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail={"error": "WaveSpeed Voice Design failed", "message": response.text}
)
try:
data = response.json()
# The API is async and returns a task ID or direct output depending on implementation.
# Based on user input, it returns a "data" object with "id" and we poll.
# BUT wait, the Python example provided by user shows:
# response = requests.post(url, ...)
# if response.status_code == 200: result = response.json()["data"] ...
# Then it polls /api/v3/predictions/{request_id}/result
# Let's handle the async polling logic here or in the caller.
# The user's Python example is very clear. It's an async task.
if "data" in data and "id" in data["data"]:
request_id = data["data"]["id"]
return self._poll_prediction_result(request_id, timeout=timeout)
# Fallback if it returns direct output (unlikely based on docs)
if "data" in data and "outputs" in data["data"] and data["data"]["outputs"]:
return self._download_audio(data["data"]["outputs"][0]["url"], timeout) # Assuming structure
raise ValueError(f"Unexpected response format: {data}")
except Exception as e:
logger.error(f"[WaveSpeed] Error parsing Voice Design response: {e}")
raise HTTPException(status_code=500, detail={"error": "Failed to parse Voice Design response", "message": str(e)})
def _poll_prediction_result(self, request_id: str, timeout: int = 180) -> bytes:
import time
url = f"https://api.wavespeed.ai/api/v3/predictions/{request_id}/result"
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url, headers=self._get_headers(), timeout=10)
if response.status_code == 200:
result = response.json().get("data", {})
status = result.get("status")
if status == "completed":
if result.get("outputs") and len(result["outputs"]) > 0:
audio_url = result["outputs"][0] # It's a URL string in the array
return self._download_audio(audio_url, timeout)
else:
raise ValueError("Completed task has no output URLs")
elif status == "failed":
raise ValueError(f"Task failed: {result.get('error')}")
# If processing/created, continue polling
time.sleep(1)
else:
logger.warning(f"Polling error {response.status_code}: {response.text}")
time.sleep(1)
except Exception as e:
logger.error(f"Polling exception: {e}")
time.sleep(1)
raise HTTPException(status_code=504, detail="Voice Design generation timed out")
def voice_clone(
self,
audio_bytes: bytes,
@@ -320,6 +416,70 @@ class SpeechGenerator:
audio_url = self._extract_audio_url(outputs)
return self._download_audio(audio_url, timeout)
def cosyvoice_voice_clone(
self,
audio_bytes: bytes,
text: str,
*,
model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
audio_mime_type: str = "audio/wav",
reference_text: Optional[str] = None,
timeout: int = 180,
) -> bytes:
url = f"{self.base_url}/{model}"
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
mime = audio_mime_type or "audio/wav"
audio_data_url = f"data:{mime};base64,{audio_b64}"
payload = {
"audio": audio_data_url,
"text": text,
}
if reference_text:
payload["reference_text"] = reference_text
logger.info(f"[WaveSpeed] CosyVoice voice clone via {url}")
try:
response = requests.post(
url,
headers=self._get_headers(),
json=payload,
timeout=(30, 90),
)
except requests_exceptions.Timeout as e:
raise HTTPException(status_code=504, detail={"error": "WaveSpeed CosyVoice voice clone timed out", "message": str(e)})
except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as e:
raise HTTPException(status_code=504, detail={"error": "WaveSpeed CosyVoice voice clone connection failed", "message": str(e)})
if response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed CosyVoice voice clone failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
outputs = data.get("outputs") or []
status = data.get("status")
prediction_id = data.get("id")
if not outputs and prediction_id and status in {"created", "processing"}:
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=timeout, interval_seconds=0.8)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed CosyVoice voice clone returned no outputs")
audio_url = self._extract_audio_url(outputs)
return self._download_audio(audio_url, timeout)
def _extract_audio_url(self, outputs: list) -> str:
"""Extract audio URL from outputs."""
if not isinstance(outputs, list) or len(outputs) == 0:

View File

@@ -90,9 +90,56 @@ def bootstrap_linguistic_models():
return True
def bootstrap_local_llm_models():
"""
Bootstrap Local LLM models (Qwen) for SIF Agents.
This ensures the model is cached locally before the server starts,
preventing large downloads during runtime.
"""
import os
verbose = os.getenv("ALWRITY_VERBOSE", "false").lower() == "true"
# Model to pre-download
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# Using Qwen2.5-1.5B as it's more efficient for laptop CPU than 4B,
# but still capable for agent routing/clustering.
# If user specifically asked for Qwen3-4B, we can use that, but 1.5B is much faster.
# User said "local qwen model", 4B might be heavy. Let's stick to what was in code: "Qwen/Qwen3-4B-Instruct-2507"
# Actually, the code had "Qwen/Qwen3-4B-Instruct-2507" which seems like a specific fine-tune or typo.
# Let's use a standard efficient one: "Qwen/Qwen2.5-3B-Instruct" or "Qwen/Qwen2.5-1.5B-Instruct".
# Given "optimized for cpu-laptop", 1.5B or 3B is best.
# Let's use the one referenced in the code if valid, otherwise Qwen2.5-3B.
# The code had: "Qwen/Qwen3-4B-Instruct-2507". I suspect this is a placeholder or internal model.
# I will use "Qwen/Qwen2.5-3B-Instruct" as a safe, modern, powerful laptop-friendly default.
target_model = "Qwen/Qwen2.5-3B-Instruct"
if verbose:
print(f"🔍 Checking local LLM model '{target_model}'...")
try:
from huggingface_hub import snapshot_download
try:
# This checks cache and downloads if missing
snapshot_download(repo_id=target_model, repo_type="model")
if verbose:
print(f" ✅ Local LLM '{target_model}' available")
except Exception as e:
if verbose:
print(f" ⚠️ Failed to download/check local LLM: {e}")
print(" SIF agents may try to download it at runtime.")
return False
except ImportError:
if verbose:
print(" ⚠️ huggingface_hub not installed - skipping LLM bootstrap")
return True
# Bootstrap linguistic models BEFORE any imports that might need them
if __name__ == "__main__":
bootstrap_linguistic_models()
bootstrap_local_llm_models()
# NOW import modular utilities (after bootstrap)
from alwrity_utils import (

View File

@@ -114,8 +114,19 @@ def save_asset_to_library(
try:
source_module_enum = AssetSource(source_module.lower())
except ValueError:
logger.warning(f"Invalid source module: {source_module}, defaulting to 'story_writer'")
source_module_enum = AssetSource.STORY_WRITER
logger.warning(f"Invalid source module: {source_module}, attempting fallback based on asset type")
# Smart fallback based on asset type
if asset_type_enum == AssetType.IMAGE:
source_module_enum = AssetSource.MAIN_IMAGE_GENERATION
elif asset_type_enum == AssetType.AUDIO:
source_module_enum = AssetSource.MAIN_AUDIO_GENERATION
elif asset_type_enum == AssetType.VIDEO:
source_module_enum = AssetSource.MAIN_VIDEO_GENERATION
else:
source_module_enum = AssetSource.MAIN_TEXT_GENERATION
logger.info(f"Fallback source module: {source_module_enum.value}")
# Sanitize filename (remove path traversal attempts)
filename = re.sub(r'[^\w\s\-_\.]', '', filename.split('/')[-1])
@@ -151,6 +162,25 @@ def save_asset_to_library(
)
logger.info(f"✅ Asset saved to library: {asset.id} ({asset_type} from {source_module})")
# Trigger SIF Indexing for all new assets (Text, Image, etc.)
try:
from models.website_analysis_monitoring_models import SIFIndexingTask
from datetime import datetime
# Check if a SIF Indexing task exists for this user
existing_task = db.query(SIFIndexingTask).filter(SIFIndexingTask.user_id == user_id).first()
if existing_task:
logger.info(f"Triggering SIF Indexing task for user {user_id} due to new {asset_type} asset")
existing_task.next_execution = datetime.utcnow() # Run immediately
existing_task.status = "pending" # Ensure it gets picked up
db.add(existing_task)
# Note: Commit depends on the caller's transaction management
else:
logger.debug(f"No SIF Indexing task found for user {user_id} - skipping trigger")
except Exception as e:
logger.warning(f"Failed to trigger SIF Indexing task in asset_tracker: {e}")
return asset.id
except Exception as e:

View File

@@ -88,6 +88,14 @@ def save_file_safely(
Tuple of (file_path, error_message). file_path is None on error.
"""
try:
# Handle max_file_size if it comes as string (e.g. from env vars)
if isinstance(max_file_size, str):
try:
max_file_size = int(max_file_size)
except ValueError:
# Fallback to default if conversion fails
max_file_size = 100 * 1024 * 1024
# Validate file size
if len(content) > max_file_size:
return None, f"File size {len(content)} exceeds maximum {max_file_size}"

View File

@@ -0,0 +1,116 @@
"""
Media Utility Functions
Centralized helper functions for loading and managing media assets across modules.
Promotes reuse between Podcast, YouTube, and other media-heavy modules.
"""
import logging
from pathlib import Path
from typing import Optional, List
from urllib.parse import urlparse
# Configure logging
logger = logging.getLogger(__name__)
# Base Directories
# backend/utils/media_utils.py -> parents[2] = backend/.. = root
ROOT_DIR = Path(__file__).resolve().parents[2]
DATA_MEDIA_DIR = ROOT_DIR / "data" / "media"
# Module-specific directories
YOUTUBE_AVATARS_DIR = DATA_MEDIA_DIR / "youtube_avatars"
YOUTUBE_IMAGES_DIR = DATA_MEDIA_DIR / "youtube_images"
PODCAST_IMAGES_DIR = DATA_MEDIA_DIR / "podcast_images"
PODCAST_AVATARS_DIR = PODCAST_IMAGES_DIR / "avatars"
# Ensure directories exist
for directory in [YOUTUBE_AVATARS_DIR, YOUTUBE_IMAGES_DIR, PODCAST_IMAGES_DIR, PODCAST_AVATARS_DIR]:
directory.mkdir(parents=True, exist_ok=True)
def resolve_media_path(media_url_or_path: str) -> Optional[Path]:
"""
Resolve a media URL or filename to a concrete file path on disk.
Handles cross-module lookups (e.g. checking podcast avatars if not found in youtube).
Args:
media_url_or_path: URL path (e.g. /api/youtube/avatars/foo.png) or filename
Returns:
Path object if found, None otherwise
"""
if not media_url_or_path:
return None
try:
# Extract filename from URL/path
if "/" in media_url_or_path or "\\" in media_url_or_path:
# It's a URL or path
parsed = urlparse(media_url_or_path)
path = parsed.path if parsed.scheme else media_url_or_path
filename = path.split("/")[-1].split("?")[0]
else:
# It's just a filename
filename = media_url_or_path.split("?")[0]
if not filename:
return None
# Define search paths in order of likelihood
# We search all avatar/image directories
search_paths: List[Path] = [
YOUTUBE_AVATARS_DIR / filename,
PODCAST_AVATARS_DIR / filename,
YOUTUBE_IMAGES_DIR / filename,
PODCAST_IMAGES_DIR / filename,
# Fallback for nested podcast images (if they exist directly in podcast_images)
PODCAST_IMAGES_DIR / "avatars" / filename
]
# Check specific module based on URL prefix if present
if "/api/youtube/" in media_url_or_path:
# Prioritize YouTube paths
pass # Already first in list
elif "/api/podcast/" in media_url_or_path:
# Prioritize Podcast paths
search_paths = [
PODCAST_AVATARS_DIR / filename,
PODCAST_IMAGES_DIR / filename,
YOUTUBE_AVATARS_DIR / filename,
YOUTUBE_IMAGES_DIR / filename
]
# Iterate and find first existing file
for path in search_paths:
if path.exists() and path.is_file():
logger.debug(f"[MediaUtils] Resolved {media_url_or_path} to {path}")
return path
logger.warning(f"[MediaUtils] Could not resolve media path for: {media_url_or_path}")
return None
except Exception as e:
logger.error(f"[MediaUtils] Error resolving media path: {e}")
return None
def load_media_bytes(media_url_or_path: str) -> Optional[bytes]:
"""
Load media bytes from a URL or path with cross-module fallback.
Args:
media_url_or_path: URL path or filename
Returns:
File bytes if found, None otherwise
"""
path = resolve_media_path(media_url_or_path)
if path:
try:
return path.read_bytes()
except Exception as e:
logger.error(f"[MediaUtils] Error reading file {path}: {e}")
return None
return None

View File

@@ -150,6 +150,29 @@ def save_and_track_text_content(
if asset_id:
logger.info(f"✅ Text asset saved to library: ID={asset_id}, filename={filename}")
# Trigger SIF Content Guardian Indexing
try:
from models.website_analysis_monitoring_models import SIFIndexingTask
from datetime import datetime
# Use the existing DB session
# Check if a SIF Indexing task exists for this user
existing_task = db.query(SIFIndexingTask).filter(SIFIndexingTask.user_id == user_id).first()
if existing_task:
logger.info(f"Triggering SIF Indexing task for user {user_id} due to new content")
existing_task.next_execution = datetime.utcnow() # Run immediately
existing_task.status = "pending" # Ensure it gets picked up
db.add(existing_task)
# We don't force commit here as the session might be managed by the caller
# But if the caller commits, this change will be included.
# If the caller uses autocommit=False and commits later, this is fine.
# Most API endpoints commit at the end.
else:
logger.debug(f"No SIF Indexing task found for user {user_id} - skipping trigger")
except Exception as e:
logger.warning(f"Failed to trigger SIF Indexing task: {e}")
else:
logger.warning(f"Asset tracking returned None for {filename}")

64
debug_usage.py Normal file
View File

@@ -0,0 +1,64 @@
import sys
import os
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
# Add backend to path
sys.path.append(os.path.join(os.getcwd(), 'backend'))
from services.database import Base
from models.subscription_models import APIUsageLog, UserSubscription
from services.subscription import UsageTrackingService, PricingService
# Setup DB connection
# dynamic path resolution as per codebase
DB_PATH = os.path.join(os.getcwd(), 'backend', 'data', 'alwrity.db')
# Note: The codebase might use user-specific DBs now.
# Let's check how get_db works or if we need to look at a specific user db.
# user_memories says: Database path updated to `workspace/workspace_{user_id}/db/alwrity.db` to support user isolation.
USER_ID = "user_33Gz1FPI86VDXhRY8QN4ragRFGN"
WORKSPACE_DB_PATH = os.path.join(os.getcwd(), 'workspace', f'workspace_{USER_ID}', 'db', 'alwrity.db')
print(f"Checking specific user DB at: {WORKSPACE_DB_PATH}")
if os.path.exists(WORKSPACE_DB_PATH):
db_url = f"sqlite:///{WORKSPACE_DB_PATH}"
else:
print(f"User DB not found at {WORKSPACE_DB_PATH}, falling back to main DB for check (legacy/shared mode)")
db_url = f"sqlite:///backend/data/alwrity.db"
engine = create_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
print(f"\n--- Checking Usage for User: {USER_ID} ---")
# Check API Usage Logs
logs_count = db.query(APIUsageLog).filter(APIUsageLog.user_id == USER_ID).count()
print(f"Total API Usage Logs: {logs_count}")
if logs_count > 0:
last_log = db.query(APIUsageLog).filter(APIUsageLog.user_id == USER_ID).order_by(APIUsageLog.created_at.desc()).first()
print(f"Last Activity: {last_log.created_at} - {last_log.endpoint} ({last_log.provider})")
# Check Subscription
sub = db.query(UserSubscription).filter(UserSubscription.user_id == USER_ID).first()
if sub:
print(f"Subscription: {sub.plan_type} (Status: {sub.status})")
else:
print("No subscription record found.")
# Run Service Logic
print("\n--- Running UsageTrackingService.get_user_usage_stats ---")
usage_service = UsageTrackingService(db)
stats = usage_service.get_user_usage_stats(USER_ID)
print("Stats returned:")
print(stats)
except Exception as e:
print(f"Error: {e}")
finally:
db.close()

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

View File

@@ -1,5 +1,5 @@
import React from 'react';
import { BrowserRouter as Router, Routes, Route, Navigate } from 'react-router-dom';
import { BrowserRouter as Router, Routes, Route, Navigate, useLocation } from 'react-router-dom';
import { Box, CircularProgress, Typography } from '@mui/material';
import { CopilotKit } from "@copilotkit/react-core";
import { ClerkProvider, useAuth } from '@clerk/clerk-react';
@@ -80,6 +80,92 @@ const ConditionalCopilotKit: React.FC<{ children: React.ReactNode }> = ({ childr
return <>{children}</>;
};
// Wrapper to only enable CopilotKit checks/provider when user is authenticated
// This prevents CopilotKit from running on the Landing page
const AuthenticatedCopilotWrapper: React.FC<{
children: React.ReactNode;
apiKey: string;
}> = ({ children, apiKey }) => {
const { isSignedIn } = useAuth();
const location = useLocation();
// Exclude CopilotKit from running on:
// 1. Landing page (handled by !isSignedIn)
// 2. Onboarding pages (to prevent health check timeouts)
const shouldExcludeCopilot = !isSignedIn || location.pathname.startsWith('/onboarding');
if (shouldExcludeCopilot) {
return <>{children}</>;
}
const hasKey = apiKey && apiKey.trim();
if (hasKey) {
// Enhanced error handler that updates health context
const handleCopilotKitError = (e: any) => {
console.error("CopilotKit Error:", e);
// Try to get health context if available
// We'll use a custom event to notify health context since we can't access it directly here
const errorMessage = e?.error?.message || e?.message || 'CopilotKit error occurred';
const errorType = errorMessage.toLowerCase();
// Differentiate between fatal and transient errors
const isFatalError =
errorType.includes('cors') ||
errorType.includes('ssl') ||
errorType.includes('certificate') ||
errorType.includes('403') ||
errorType.includes('forbidden') ||
errorType.includes('ERR_CERT_COMMON_NAME_INVALID');
// Dispatch event for health context to listen to
window.dispatchEvent(new CustomEvent('copilotkit-error', {
detail: {
error: e,
errorMessage,
isFatal: isFatalError,
}
}));
};
return (
<CopilotKitHealthProvider initialHealthStatus={true}>
<CopilotKitDegradedBanner />
<ErrorBoundary
context="CopilotKit"
showDetails={process.env.NODE_ENV === 'development'}
fallback={
<Box sx={{ p: 3, textAlign: 'center' }}>
<Typography variant="h6" color="warning" gutterBottom>
Chat Unavailable
</Typography>
<Typography variant="body2" color="textSecondary">
CopilotKit encountered an error. The app continues to work with manual controls.
</Typography>
</Box>
}
>
<CopilotKit
publicApiKey={apiKey}
showDevConsole={false}
onError={handleCopilotKitError}
>
{children}
</CopilotKit>
</ErrorBoundary>
</CopilotKitHealthProvider>
);
}
return (
<CopilotKitHealthProvider initialHealthStatus={false}>
<CopilotKitDegradedBanner />
{children}
</CopilotKitHealthProvider>
);
};
// Component to handle initial routing based on subscription and onboarding status
// Flow: Subscription → Onboarding → Dashboard
const InitialRouteHandler: React.FC = () => {
@@ -473,146 +559,86 @@ const App: React.FC = () => {
// Render app with or without CopilotKit based on whether we have a key
const renderApp = () => {
const appContent = (
return (
<Router>
<ConditionalCopilotKit>
<TokenInstaller />
<Routes>
<Route path="/" element={<RootRoute />} />
<Route
path="/onboarding"
element={
<ErrorBoundary context="Onboarding Wizard" showDetails>
<Wizard />
</ErrorBoundary>
}
/>
{/* Error Boundary Testing - Development Only */}
{process.env.NODE_ENV === 'development' && (
<Route path="/error-test" element={<ErrorBoundaryTest />} />
)}
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
<Route path="/seo" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
<Route path="/seo-dashboard" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
<Route path="/content-planning" element={<ProtectedRoute><ContentPlanningDashboard /></ProtectedRoute>} />
<Route path="/facebook-writer" element={<ProtectedRoute><FacebookWriter /></ProtectedRoute>} />
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
<Route path="/story-writer" element={<ProtectedRoute><StoryWriter /></ProtectedRoute>} />
<Route path="/youtube-creator" element={<ProtectedRoute><YouTubeCreator /></ProtectedRoute>} />
<Route path="/podcast-maker" element={<ProtectedRoute><PodcastDashboard /></ProtectedRoute>} />
<Route path="/image-studio" element={<ProtectedRoute><ImageStudioDashboard /></ProtectedRoute>} />
<Route path="/video-studio" element={<ProtectedRoute><VideoStudioDashboard /></ProtectedRoute>} />
<Route path="/video-studio/create" element={<ProtectedRoute><CreateVideo /></ProtectedRoute>} />
<Route path="/video-studio/avatar" element={<ProtectedRoute><AvatarVideo /></ProtectedRoute>} />
<Route path="/video-studio/enhance" element={<ProtectedRoute><EnhanceVideo /></ProtectedRoute>} />
<Route path="/video-studio/extend" element={<ProtectedRoute><ExtendVideo /></ProtectedRoute>} />
<Route path="/video-studio/edit" element={<ProtectedRoute><EditVideo /></ProtectedRoute>} />
<Route path="/video-studio/transform" element={<ProtectedRoute><TransformVideo /></ProtectedRoute>} />
<Route path="/video-studio/social" element={<ProtectedRoute><SocialVideo /></ProtectedRoute>} />
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FaceSwap /></ProtectedRoute>} />
<Route path="/video-studio/video-translate" element={<ProtectedRoute><VideoTranslate /></ProtectedRoute>} />
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><VideoBackgroundRemover /></ProtectedRoute>} />
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><AddAudioToVideo /></ProtectedRoute>} />
<Route path="/video-studio/library" element={<ProtectedRoute><LibraryVideo /></ProtectedRoute>} />
<Route path="/image-generator" element={<ProtectedRoute><CreateStudio /></ProtectedRoute>} />
<Route path="/image-editor" element={<ProtectedRoute><EditStudio /></ProtectedRoute>} />
<Route path="/image-upscale" element={<ProtectedRoute><UpscaleStudio /></ProtectedRoute>} />
<Route path="/image-control" element={<ProtectedRoute><ControlStudio /></ProtectedRoute>} />
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FaceSwapStudio /></ProtectedRoute>} />
<Route path="/image-studio/compress" element={<ProtectedRoute><CompressionStudio /></ProtectedRoute>} />
<Route path="/image-studio/processing" element={<ProtectedRoute><ImageProcessingStudio /></ProtectedRoute>} />
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><SocialOptimizer /></ProtectedRoute>} />
<Route path="/asset-library" element={<ProtectedRoute><AssetLibrary /></ProtectedRoute>} />
<Route path="/campaign-creator" element={<ProtectedRoute><ProductMarketingDashboard /></ProtectedRoute>} />
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><ProductPhotoshootStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/animation" element={<ProtectedRoute><ProductAnimationStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/video" element={<ProtectedRoute><ProductVideoStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><ProductAvatarStudio /></ProtectedRoute>} />
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
<Route path="/scheduler-dashboard" element={<ProtectedRoute><SchedulerDashboard /></ProtectedRoute>} />
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
<Route path="/pricing" element={<PricingPage />} />
<Route path="/research-test" element={<ResearchDashboard />} />
<Route path="/research-dashboard" element={<ResearchDashboard />} />
<Route path="/alwrity-researcher" element={<ResearchDashboard />} />
<Route path="/intent-research" element={<IntentResearchTest />} />
<Route path="/wix-test" element={<WixTestPage />} />
<Route path="/wix-test-direct" element={<WixTestPage />} />
<Route path="/wix/callback" element={<WixCallbackPage />} />
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
<Route path="/bing/callback" element={<BingCallbackPage />} />
<Route path="/bing-analytics-storage" element={<ProtectedRoute><BingAnalyticsStorage /></ProtectedRoute>} />
</Routes>
</ConditionalCopilotKit>
<AuthenticatedCopilotWrapper apiKey={copilotApiKey}>
<ConditionalCopilotKit>
<TokenInstaller />
<Routes>
<Route path="/" element={<RootRoute />} />
<Route
path="/onboarding"
element={
<ErrorBoundary context="Onboarding Wizard" showDetails>
<Wizard />
</ErrorBoundary>
}
/>
{/* Error Boundary Testing - Development Only */}
{process.env.NODE_ENV === 'development' && (
<Route path="/error-test" element={<ErrorBoundaryTest />} />
)}
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
<Route path="/seo" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
<Route path="/seo-dashboard" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
<Route path="/content-planning" element={<ProtectedRoute><ContentPlanningDashboard /></ProtectedRoute>} />
<Route path="/facebook-writer" element={<ProtectedRoute><FacebookWriter /></ProtectedRoute>} />
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
<Route path="/story-writer" element={<ProtectedRoute><StoryWriter /></ProtectedRoute>} />
<Route path="/youtube-creator" element={<ProtectedRoute><YouTubeCreator /></ProtectedRoute>} />
<Route path="/podcast-maker" element={<ProtectedRoute><PodcastDashboard /></ProtectedRoute>} />
<Route path="/image-studio" element={<ProtectedRoute><ImageStudioDashboard /></ProtectedRoute>} />
<Route path="/video-studio" element={<ProtectedRoute><VideoStudioDashboard /></ProtectedRoute>} />
<Route path="/video-studio/create" element={<ProtectedRoute><CreateVideo /></ProtectedRoute>} />
<Route path="/video-studio/avatar" element={<ProtectedRoute><AvatarVideo /></ProtectedRoute>} />
<Route path="/video-studio/enhance" element={<ProtectedRoute><EnhanceVideo /></ProtectedRoute>} />
<Route path="/video-studio/extend" element={<ProtectedRoute><ExtendVideo /></ProtectedRoute>} />
<Route path="/video-studio/edit" element={<ProtectedRoute><EditVideo /></ProtectedRoute>} />
<Route path="/video-studio/transform" element={<ProtectedRoute><TransformVideo /></ProtectedRoute>} />
<Route path="/video-studio/social" element={<ProtectedRoute><SocialVideo /></ProtectedRoute>} />
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FaceSwap /></ProtectedRoute>} />
<Route path="/video-studio/video-translate" element={<ProtectedRoute><VideoTranslate /></ProtectedRoute>} />
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><VideoBackgroundRemover /></ProtectedRoute>} />
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><AddAudioToVideo /></ProtectedRoute>} />
<Route path="/video-studio/library" element={<ProtectedRoute><LibraryVideo /></ProtectedRoute>} />
<Route path="/image-generator" element={<ProtectedRoute><CreateStudio /></ProtectedRoute>} />
<Route path="/image-editor" element={<ProtectedRoute><EditStudio /></ProtectedRoute>} />
<Route path="/image-upscale" element={<ProtectedRoute><UpscaleStudio /></ProtectedRoute>} />
<Route path="/image-control" element={<ProtectedRoute><ControlStudio /></ProtectedRoute>} />
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FaceSwapStudio /></ProtectedRoute>} />
<Route path="/image-studio/compress" element={<ProtectedRoute><CompressionStudio /></ProtectedRoute>} />
<Route path="/image-studio/processing" element={<ProtectedRoute><ImageProcessingStudio /></ProtectedRoute>} />
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><SocialOptimizer /></ProtectedRoute>} />
<Route path="/asset-library" element={<ProtectedRoute><AssetLibrary /></ProtectedRoute>} />
<Route path="/campaign-creator" element={<ProtectedRoute><ProductMarketingDashboard /></ProtectedRoute>} />
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><ProductPhotoshootStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/animation" element={<ProtectedRoute><ProductAnimationStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/video" element={<ProtectedRoute><ProductVideoStudio /></ProtectedRoute>} />
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><ProductAvatarStudio /></ProtectedRoute>} />
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
<Route path="/scheduler-dashboard" element={<ProtectedRoute><SchedulerDashboard /></ProtectedRoute>} />
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
<Route path="/pricing" element={<PricingPage />} />
<Route path="/research-test" element={<ResearchDashboard />} />
<Route path="/research-dashboard" element={<ResearchDashboard />} />
<Route path="/alwrity-researcher" element={<ResearchDashboard />} />
<Route path="/intent-research" element={<IntentResearchTest />} />
<Route path="/wix-test" element={<WixTestPage />} />
<Route path="/wix-test-direct" element={<WixTestPage />} />
<Route path="/wix/callback" element={<WixCallbackPage />} />
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
<Route path="/bing/callback" element={<BingCallbackPage />} />
<Route path="/bing-analytics-storage" element={<ProtectedRoute><BingAnalyticsStorage /></ProtectedRoute>} />
</Routes>
</ConditionalCopilotKit>
</AuthenticatedCopilotWrapper>
</Router>
);
// Only wrap with CopilotKit if we have a valid key
if (copilotApiKey && copilotApiKey.trim()) {
// Enhanced error handler that updates health context
const handleCopilotKitError = (e: any) => {
console.error("CopilotKit Error:", e);
// Try to get health context if available
// We'll use a custom event to notify health context since we can't access it directly here
const errorMessage = e?.error?.message || e?.message || 'CopilotKit error occurred';
const errorType = errorMessage.toLowerCase();
// Differentiate between fatal and transient errors
const isFatalError =
errorType.includes('cors') ||
errorType.includes('ssl') ||
errorType.includes('certificate') ||
errorType.includes('403') ||
errorType.includes('forbidden') ||
errorType.includes('ERR_CERT_COMMON_NAME_INVALID');
// Dispatch event for health context to listen to
window.dispatchEvent(new CustomEvent('copilotkit-error', {
detail: {
error: e,
errorMessage,
isFatal: isFatalError,
}
}));
};
return (
<ErrorBoundary
context="CopilotKit"
showDetails={process.env.NODE_ENV === 'development'}
fallback={
<Box sx={{ p: 3, textAlign: 'center' }}>
<Typography variant="h6" color="warning" gutterBottom>
Chat Unavailable
</Typography>
<Typography variant="body2" color="textSecondary">
CopilotKit encountered an error. The app continues to work with manual controls.
</Typography>
</Box>
}
>
<CopilotKit
publicApiKey={copilotApiKey}
showDevConsole={false}
onError={handleCopilotKitError}
>
{appContent}
</CopilotKit>
</ErrorBoundary>
);
}
// Return app without CopilotKit if no key available
return appContent;
};
// Determine initial health status based on whether CopilotKit key is available
const hasCopilotKitKey = copilotApiKey && copilotApiKey.trim();
return (
<ErrorBoundary
context="Application Root"
@@ -626,10 +652,7 @@ const App: React.FC = () => {
<ClerkProvider publishableKey={clerkPublishableKey}>
<SubscriptionProvider>
<OnboardingProvider>
<CopilotKitHealthProvider initialHealthStatus={!!hasCopilotKitKey}>
<CopilotKitDegradedBanner />
{renderApp()}
</CopilotKitHealthProvider>
{renderApp()}
</OnboardingProvider>
</SubscriptionProvider>
</ClerkProvider>

View File

@@ -5,6 +5,7 @@ export interface AssetResponse {
image_url?: string;
image_base64?: string;
optimized_prompt?: string;
prompt?: string;
asset_id?: number;
message?: string;
error?: string;
@@ -19,16 +20,39 @@ export interface VoiceCloneResponse {
error?: string;
}
export const getLatestBrandAvatar = async (): Promise<AssetResponse> => {
try {
const response = await apiClient.get('/onboarding/assets/latest-avatar');
return response.data;
} catch (error: any) {
// 404 is expected if no avatar exists
if (error.response?.status === 404) {
return { success: false, message: 'No avatar found' };
}
console.error('Failed to fetch latest avatar:', error);
return {
success: false,
error: error.response?.data?.detail || 'Failed to fetch latest avatar'
};
}
};
export const generateBrandAvatar = async (
prompt: string,
stylePreset?: string,
aspectRatio: string = "1:1"
aspectRatio: string = "1:1",
model?: string,
renderingSpeed?: string,
provider?: string
): Promise<AssetResponse> => {
try {
const response = await apiClient.post('/onboarding/assets/generate-avatar', {
prompt,
style_preset: stylePreset,
aspect_ratio: aspectRatio,
model,
rendering_speed: renderingSpeed,
provider,
user_id: "current_user" // Backend extracts actual user
});
return response.data;
@@ -61,24 +85,48 @@ export const createAvatarVariation = async (
prompt: string,
file: File
): Promise<AssetResponse> => {
// TODO: Implement backend endpoint for variation
// For now, return a mock error or handle as new generation
console.warn("createAvatarVariation not fully implemented in backend");
try {
const formData = new FormData();
formData.append('prompt', prompt);
formData.append('file', file);
formData.append('user_id', "current_user");
const response = await apiClient.post('/onboarding/assets/create-variation', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
});
return response.data;
} catch (error: any) {
console.error('Avatar variation error:', error);
return {
success: false,
error: "Feature not available yet"
success: false,
error: error.response?.data?.detail || 'Failed to create avatar variation'
};
}
};
export const enhanceBrandAvatar = async (
file: File
): Promise<AssetResponse> => {
// TODO: Implement backend endpoint for enhancement (upscaling)
console.warn("enhanceBrandAvatar not fully implemented in backend");
try {
const formData = new FormData();
formData.append('file', file);
formData.append('user_id', "current_user");
const response = await apiClient.post('/onboarding/assets/enhance-avatar', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
});
return response.data;
} catch (error: any) {
console.error('Avatar enhancement error:', error);
return {
success: false,
error: "Feature not available yet"
success: false,
error: error.response?.data?.detail || 'Failed to enhance avatar'
};
}
};
export const setBrandAvatar = async (
@@ -96,6 +144,37 @@ export const setBrandAvatar = async (
};
};
export const getLatestVoiceClone = async (): Promise<VoiceCloneResponse> => {
try {
const response = await apiClient.get('/onboarding/assets/latest-voice-clone');
return response.data;
} catch (error: any) {
if (error.response?.status === 404) {
return { success: false, message: 'No voice clone found' };
}
console.error('Failed to fetch latest voice clone:', error);
return {
success: false,
error: error.response?.data?.detail || 'Failed to fetch latest voice clone'
};
}
};
export const setBrandVoice = async (
data: {
audio_url?: string;
custom_voice_id?: string;
voice_description?: string;
}
): Promise<AssetResponse> => {
// TODO: Implement backend endpoint to set as active voice
// For now, simulate success
return {
success: true,
message: "Voice set as active brand voice"
};
};
export interface VoiceCloneParams {
audioFile: File;
engine: 'minimax' | 'qwen3';
@@ -147,3 +226,29 @@ export const createVoiceClone = async (
};
}
};
export interface VoiceDesignParams {
text: string;
voiceDescription: string;
language?: string;
}
export const createVoiceDesign = async (
params: VoiceDesignParams
): Promise<VoiceCloneResponse> => {
try {
const response = await apiClient.post('/onboarding/assets/create-voice-design', {
text: params.text,
voice_description: params.voiceDescription,
language: params.language || 'auto',
user_id: "current_user"
});
return response.data;
} catch (error: any) {
console.error('Voice design error:', error);
return {
success: false,
error: error.response?.data?.detail || 'Failed to create voice design'
};
}
};

View File

@@ -3,6 +3,8 @@
*/
import { aiApiClient } from './client';
// Import TaskStatusResponse from blogWriterApi to ensure compatibility with usePolling
import type { TaskStatusResponse } from '../services/blogWriterApi';
const API_BASE = '/api/video-studio';
@@ -18,6 +20,17 @@ export interface PromptOptimizeResponse {
success: boolean;
}
export interface CreateAvatarVideoResponse {
task_id: string;
status: string;
message: string;
error?: string;
result?: {
video_url: string;
[key: string]: any;
};
}
/**
* Optimize a prompt using WaveSpeed prompt optimizer
*/
@@ -30,3 +43,77 @@ export async function optimizePrompt(
);
return response.data;
}
/**
* Create a talking avatar video asynchronously
* Uses dedicated Video Studio endpoint for generic avatar generation
*/
export async function createAvatarVideoAsync(
imageFile: File,
audioFile: File,
resolution: '480p' | '720p' = '720p',
model: 'infinitetalk' | 'hunyuan-avatar' = 'infinitetalk'
): Promise<CreateAvatarVideoResponse> {
const formData = new FormData();
formData.append('image', imageFile);
formData.append('audio', audioFile);
formData.append('resolution', resolution);
formData.append('model', model);
try {
const response = await aiApiClient.post<CreateAvatarVideoResponse>(
`${API_BASE}/avatar/create-async`,
formData,
{
headers: {
'Content-Type': 'multipart/form-data',
},
}
);
return response.data;
} catch (error: any) {
console.error('Error creating avatar video:', error);
throw new Error(error.response?.data?.detail || 'Failed to create avatar video');
}
}
/**
* Get the status of a video generation task
*/
export async function getVideoTaskStatus(taskId: string): Promise<CreateAvatarVideoResponse> {
try {
const response = await aiApiClient.get<CreateAvatarVideoResponse>(
`${API_BASE}/task/${taskId}`
);
return response.data;
} catch (error: any) {
console.error('Error fetching video task status:', error);
throw new Error(error.response?.data?.detail || 'Failed to fetch task status');
}
}
/**
* Poll video task status compatible with usePolling hook
*/
export async function pollVideoTaskStatus(taskId: string): Promise<TaskStatusResponse<{ video_url: string; [key: string]: any }>> {
const data = await getVideoTaskStatus(taskId);
// Map CreateAvatarVideoResponse to TaskStatusResponse
// Ensure we map 'processing' to 'running' for frontend consistency
let status: 'pending' | 'running' | 'completed' | 'failed' = 'pending';
if (data.status === 'completed') status = 'completed';
else if (data.status === 'failed') status = 'failed';
else if (data.status === 'running' || data.status === 'processing') status = 'running';
else status = 'pending';
return {
task_id: data.task_id,
status: status,
progress_messages: [], // Video Studio currently doesn't return progress messages
result: data.result,
error: data.error,
// Add default values for missing fields
created_at: new Date().toISOString(),
};
}

View File

@@ -32,7 +32,8 @@ export const WixConnectModal: React.FC<WixConnectModalProps> = ({
if (!isOpen) return;
const handler = (event: MessageEvent) => {
const trusted = [window.location.origin, 'https://littery-sonny-unscrutinisingly.ngrok-free.dev'];
const ngrokOrigin = process.env.REACT_APP_NGROK_ORIGIN || 'https://littery-sonny-unscrutinisingly.ngrok-free.dev';
const trusted = [window.location.origin, ngrokOrigin];
if (!trusted.includes(event.origin)) return;
if (!event.data || typeof event.data !== 'object') return;
@@ -91,7 +92,7 @@ export const WixConnectModal: React.FC<WixConnectModalProps> = ({
// Determine the correct origin - if using ngrok, use ngrok origin; otherwise use current origin
// This ensures consistency between where OAuth starts and where callback happens
const NGROK_ORIGIN = 'https://littery-sonny-unscrutinisingly.ngrok-free.dev';
const NGROK_ORIGIN = process.env.REACT_APP_NGROK_ORIGIN || 'https://littery-sonny-unscrutinisingly.ngrok-free.dev';
const isUsingNgrok = window.location.origin.includes('localhost') ||
window.location.origin.includes('127.0.0.1') ||
window.location.origin === NGROK_ORIGIN;

View File

@@ -165,9 +165,10 @@ const SystemStatusIndicator: React.FC<SystemStatusIndicatorProps> = ({ className
// Prime cache performance occasionally even when dashboard is closed
fetchDetailedStats();
// Refresh every 30 seconds
const interval = setInterval(fetchStatus, 30000);
const cacheInterval = setInterval(fetchDetailedStats, 60000);
// Refresh every 120 seconds
const interval = setInterval(fetchStatus, 120000);
// Refresh detailed stats much less frequently in background (5 mins)
const cacheInterval = setInterval(fetchDetailedStats, 300000);
return () => {
clearInterval(interval);
clearInterval(cacheInterval);

View File

@@ -37,7 +37,7 @@ import {
import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -30,7 +30,7 @@ import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
import { EditResultViewer } from './EditResultViewer';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -16,7 +16,7 @@ import {
} from '@mui/icons-material';
import { motion } from 'framer-motion';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
interface CostEstimate {
provider: string;

View File

@@ -54,9 +54,9 @@ import { CostEstimator } from './CostEstimator';
import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
const MotionBox = motion(Box);
const MotionPaper = motion(Paper);
const MotionCard = motion(Card);
const MotionBox = motion.create(Box);
const MotionPaper = motion.create(Paper);
const MotionCard = motion.create(Card);
// Cubic bezier easing
const easeInOut: Easing = [0.22, 0.61, 0.36, 1];

View File

@@ -31,7 +31,7 @@ import { OperationButton } from '../shared/OperationButton';
import { ImageMaskEditor } from './ImageMaskEditor';
import { ModelSelector } from './ModelSelector';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -21,7 +21,7 @@ import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
import { ModelSelector } from './ModelSelector';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -31,8 +31,8 @@ import {
} from '@mui/icons-material';
import { motion, AnimatePresence, type Variants, type Easing } from 'framer-motion';
const MotionCard = motion(Card);
const MotionBox = motion(Box);
const MotionCard = motion.create(Card);
const MotionBox = motion.create(Box);
const galleryEase: Easing = [0.4, 0, 0.2, 1];
interface ImageResult {

View File

@@ -6,7 +6,7 @@ import type { Variants } from 'framer-motion';
import DashboardHeader from '../shared/DashboardHeader';
import type { DashboardHeaderProps } from '../shared/types';
const MotionBox = motion(Box);
const MotionBox = motion.create(Box);
const sparkleVariants: Variants = {
initial: { scale: 0, rotate: 0 },

View File

@@ -34,7 +34,7 @@ import { useImageStudio, PlatformFormat } from '../../hooks/useImageStudio';
import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
const MotionPaper = motion(Paper);
const MotionPaper = motion.create(Paper);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -32,7 +32,7 @@ import {
} from '@mui/icons-material';
import { motion, AnimatePresence, type Variants, type Easing } from 'framer-motion';
const MotionCard = motion(Card);
const MotionCard = motion.create(Card);
const templateCardEase: Easing = [0.4, 0, 1, 1];
interface Template {

View File

@@ -40,8 +40,8 @@ import { ImageStudioLayout } from './ImageStudioLayout';
import { OperationButton } from '../shared/OperationButton';
import { PreflightOperation } from '../../services/billingService';
const MotionPaper = motion(Paper);
const MotionCard = motion(Card);
const MotionPaper = motion.create(Paper);
const MotionCard = motion.create(Card);
const fadeEase: Easing = [0.4, 0, 0.2, 1];
const cardVariants: Variants = {

View File

@@ -10,7 +10,7 @@ import {
alpha
} from '@mui/material';
import OptimizedImage from './OptimizedImage';
import { SignInButton } from '@clerk/clerk-react';
import { SignInButton, useClerk } from '@clerk/clerk-react';
import { RocketLaunch } from '@mui/icons-material';
import { motion } from 'framer-motion';
import { ScrambleText } from '../ScrambleText';
@@ -44,6 +44,7 @@ const ScramblingText: React.FC<{ phrases: string[]; interval?: number; duration?
const EnterpriseCTA: React.FC = () => {
const theme = useTheme();
const { openSignIn } = useClerk();
// Framer Motion variants
const fadeInUp = {
@@ -119,8 +120,8 @@ const EnterpriseCTA: React.FC = () => {
</Typography>
<Stack direction={{ xs: 'column', sm: 'row' }} spacing={3} alignItems="center">
<SignInButton mode="redirect" forceRedirectUrl="/">
<Button
onClick={() => openSignIn({ forceRedirectUrl: '/' })}
variant="contained"
size="large"
startIcon={<RocketLaunch />}
@@ -146,7 +147,6 @@ const EnterpriseCTA: React.FC = () => {
interval={3500}
/>
</Button>
</SignInButton>
<Stack alignItems={{ xs: 'center', sm: 'flex-start' }} spacing={1}>
<Typography variant="body2" color="text.secondary">

View File

@@ -10,7 +10,7 @@ import {
useTheme,
alpha
} from '@mui/material';
import { SignInButton } from '@clerk/clerk-react';
import { SignInButton, useClerk } from '@clerk/clerk-react';
import {
RocketLaunch,
Lightbulb,
@@ -62,6 +62,8 @@ const ScramblingText: React.FC<{ phrases: string[]; interval?: number; duration?
const HeroSection: React.FC = () => {
const theme = useTheme();
const { openSignIn } = useClerk();
const fadeInUp = {
hidden: { opacity: 0, y: 24 },
visible: { opacity: 1, y: 0, transition: { duration: 0.6, ease: "easeOut" as const } },
@@ -272,46 +274,43 @@ const HeroSection: React.FC = () => {
<motion.div variants={fadeInUp}>
<Box sx={{ ...glassPanelSx, px: { xs: 3, md: 5 }, py: { xs: 4, md: 6 }, maxWidth: 1000, width: '100%' }}>
<Stack spacing={4} alignItems="center">
<SignInButton mode="redirect" forceRedirectUrl="/">
<Button
variant="contained"
size="large"
startIcon={<Lightbulb />}
sx={{
py: 2.5,
px: 5,
fontSize: '1.2rem',
fontWeight: 700,
borderRadius: 3,
background: 'linear-gradient(45deg, #667eea 30%, #764ba2 90%)',
backgroundImage: `
linear-gradient(120deg, transparent 0%, rgba(255,255,255,0.3) 50%, transparent 100%),
linear-gradient(45deg, #667eea 30%, #764ba2 90%)
`,
backgroundSize: '200% 100%, 100% 100%',
backgroundPosition: '200% 0, 0 0',
boxShadow: '0 10px 40px rgba(102, 126, 234, 0.4)',
'&:hover': {
boxShadow: '0 15px 50px rgba(102, 126, 234, 0.5)',
transform: 'translateY(-3px)',
backgroundPosition: '0 0, 0 0'
},
transition: 'all 0.3s ease',
animation: 'shimmer 2.5s ease-in-out infinite',
'@keyframes shimmer': {
'0%': { backgroundPosition: '200% 0, 0 0' },
'100%': { backgroundPosition: '-200% 0, 0 0' },
},
}}
>
<ScramblingText
phrases={['ALwrity For Free - BYOK', 'Start Free Today', 'Try ALwrity Free', 'Get Started Free']}
duration={600}
delay={500}
interval={4000}
/>
</Button>
</SignInButton>
<Button
onClick={() => openSignIn({ forceRedirectUrl: '/' })}
variant="contained"
size="large"
startIcon={<Lightbulb />}
sx={{
py: 2.5,
px: 5,
fontSize: '1.2rem',
fontWeight: 700,
borderRadius: 3,
background: 'linear-gradient(45deg, #667eea 30%, #764ba2 90%)',
backgroundImage: `
linear-gradient(120deg, transparent 0%, rgba(255,255,255,0.3) 50%, transparent 100%),
linear-gradient(45deg, #667eea 30%, #764ba2 90%)
`,
backgroundSize: '200% 100%, 100% 100%',
backgroundPosition: '200% 0, 0 0',
boxShadow: '0 10px 40px rgba(102, 126, 234, 0.4)',
'&:hover': {
boxShadow: '0 15px 50px rgba(102, 126, 234, 0.5)',
transform: 'translateY(-3px)',
backgroundPosition: '0 0, 0 0'
},
transition: 'all 0.3s ease',
animation: 'shimmer 2.5s ease-in-out infinite',
'@keyframes shimmer': {
'0%': { backgroundPosition: '200% 0, 0 0' },
'100%': { backgroundPosition: '-200% 0, 0 0' }
}
}}
>
<ScramblingText
phrases={['Start Free Trial', 'Get Started Now', 'Try AI Copilot', 'Boost ROI Now']}
interval={3000}
/>
</Button>
<Typography
variant="body1"

View File

@@ -12,7 +12,7 @@ import {
alpha,
Skeleton
} from '@mui/material';
import { SignInButton } from '@clerk/clerk-react';
import { SignInButton, useClerk } from '@clerk/clerk-react';
import {
RocketLaunch,
Business,
@@ -56,6 +56,7 @@ const ScramblingText: React.FC<{ phrases: string[]; interval?: number; duration?
const IntroducingAlwrity: React.FC = () => {
const theme = useTheme();
const [imageLoaded, setImageLoaded] = useState(false);
const { openSignIn } = useClerk();
// Preload the background image
useEffect(() => {
@@ -179,8 +180,8 @@ const IntroducingAlwrity: React.FC = () => {
<motion.div variants={fadeInUp}>
<Box sx={{ mt: 4 }}>
<SignInButton mode="redirect" forceRedirectUrl="/">
<Button
onClick={() => openSignIn({ forceRedirectUrl: '/' })}
variant="contained"
size="large"
startIcon={<RocketLaunch />}
@@ -206,7 +207,6 @@ const IntroducingAlwrity: React.FC = () => {
interval={3500}
/>
</Button>
</SignInButton>
</Box>
</motion.div>
</Stack>

View File

@@ -1,10 +1,20 @@
import React, { useState, useEffect, useCallback } from 'react';
import { useUser } from '@clerk/clerk-react';
import {
Box,
Fade,
Snackbar,
Typography,
Paper
Paper,
Radio,
RadioGroup,
FormControlLabel,
FormControl,
FormLabel,
Card,
CardContent,
Alert,
Chip
} from '@mui/material';
import {
// Social Media Icons
@@ -19,7 +29,11 @@ import {
Web as WordPressIcon,
Web as WixIcon,
Google as GoogleIcon,
Analytics as AnalyticsIcon
Analytics as AnalyticsIcon,
// UI Icons
Lightbulb as LightbulbIcon,
CheckCircle as CheckCircleIcon,
Error as ErrorIcon
} from '@mui/icons-material';
// Import refactored components
@@ -28,6 +42,7 @@ import PlatformSection from './common/PlatformSection';
import BenefitsSummary from './common/BenefitsSummary';
import ComingSoonSection from './common/ComingSoonSection';
import { useWordPressOAuth } from '../../hooks/useWordPressOAuth';
import { useWixConnection } from '../../hooks/useWixConnection';
import { useBingOAuth } from '../../hooks/useBingOAuth';
import { useGSCConnection } from './common/useGSCConnection';
import { usePlatformConnections } from './common/usePlatformConnections';
@@ -37,6 +52,7 @@ import { cachedAnalyticsAPI } from '../../api/cachedAnalytics';
interface IntegrationsStepProps {
onContinue: () => void;
updateHeaderContent: (content: { title: string; description: string }) => void;
onValidationChange?: (isValid: boolean) => void;
}
interface IntegrationPlatform {
@@ -52,7 +68,8 @@ interface IntegrationPlatform {
isEnabled: boolean;
}
const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateHeaderContent }) => {
const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateHeaderContent, onValidationChange }) => {
const { user } = useUser();
const [email, setEmail] = useState<string>('');
// Use custom hooks
@@ -60,13 +77,11 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
// Invalidate analytics cache when platform connections change
const invalidateAnalyticsCache = useCallback(() => {
console.log('🔄 IntegrationsStep: Invalidating analytics cache due to connection change');
cachedAnalyticsAPI.invalidateAll();
}, []);
// Force refresh analytics data (bypass cache)
const forceRefreshAnalytics = useCallback(async () => {
console.log('🔄 IntegrationsStep: Force refreshing analytics data (bypassing cache)');
try {
// Clear all cache first
cachedAnalyticsAPI.clearCache();
@@ -77,9 +92,8 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
// Force refresh analytics data
await cachedAnalyticsAPI.forceRefreshAnalyticsData(['bing', 'gsc']);
console.log('✅ IntegrationsStep: Analytics data force refreshed successfully');
} catch (error) {
console.error('IntegrationsStep: Error force refreshing analytics:', error);
console.error('IntegrationsStep: Error force refreshing analytics:', error);
}
}, []);
const { isLoading, showToast, setShowToast, toastMessage, handleConnect } = usePlatformConnections();
@@ -89,7 +103,6 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
// Bing OAuth hook
const { connected: bingConnected, sites: bingSites, connect: connectBing } = useBingOAuth();
console.log('Bing OAuth hook initialized:', { bingConnected, connectBing: typeof connectBing });
// Initialize integrations data
const [integrations] = useState<IntegrationPlatform[]>([
@@ -231,59 +244,30 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
// Handle WordPress connection status changes
useEffect(() => {
console.log('IntegrationsStep: WordPress status changed:', {
wordpressConnected,
wordpressSitesCount: wordpressSites.length,
connectedPlatforms,
currentPlatforms: connectedPlatforms
});
if (wordpressConnected && wordpressSites.length > 0) {
// WordPress is connected, add to connected platforms
if (!connectedPlatforms.includes('wordpress')) {
console.log('IntegrationsStep: Adding WordPress to connected platforms');
setConnectedPlatforms([...connectedPlatforms, 'wordpress']);
console.log('WordPress connection detected:', wordpressSites);
invalidateAnalyticsCache();
} else {
console.log('IntegrationsStep: WordPress already in connected platforms');
}
} else if (!wordpressConnected && connectedPlatforms.includes('wordpress')) {
// WordPress is disconnected, remove from connected platforms
console.log('IntegrationsStep: Removing WordPress from connected platforms');
setConnectedPlatforms(connectedPlatforms.filter(platform => platform !== 'wordpress'));
console.log('WordPress disconnection detected');
invalidateAnalyticsCache();
} else {
console.log('IntegrationsStep: No WordPress status change needed');
}
}, [wordpressConnected, wordpressSites, connectedPlatforms, setConnectedPlatforms, invalidateAnalyticsCache]);
// Handle Bing connection status changes
useEffect(() => {
console.log('IntegrationsStep: Bing status changed:', {
bingConnected,
bingSitesCount: bingSites.length,
connectedPlatforms,
currentPlatforms: connectedPlatforms
});
if (bingConnected && bingSites.length > 0) {
if (!connectedPlatforms.includes('bing')) {
console.log('IntegrationsStep: Adding Bing to connected platforms');
setConnectedPlatforms([...connectedPlatforms, 'bing']);
console.log('Bing connection detected:', bingSites);
invalidateAnalyticsCache();
} else {
console.log('IntegrationsStep: Bing already in connected platforms');
}
} else if (!bingConnected && connectedPlatforms.includes('bing')) {
console.log('IntegrationsStep: Removing Bing from connected platforms');
setConnectedPlatforms(connectedPlatforms.filter(platform => platform !== 'bing'));
console.log('Bing disconnection detected');
invalidateAnalyticsCache();
} else {
console.log('IntegrationsStep: No Bing status change needed');
}
}, [bingConnected, bingSites, connectedPlatforms, setConnectedPlatforms, invalidateAnalyticsCache]);
@@ -299,7 +283,6 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
setConnectedPlatforms([...connectedPlatforms, 'wordpress']);
// Remove query parameters from URL
window.history.replaceState({}, document.title, window.location.pathname);
console.log('WordPress OAuth connection successful:', blogUrl);
} else if (error) {
// WordPress OAuth failed
console.error('WordPress OAuth error:', error);
@@ -311,75 +294,28 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
// Get user email from Clerk
useEffect(() => {
const getUserEmail = () => {
if (typeof window !== 'undefined') {
const clerkUser = (window as any).__clerk_user;
if (clerkUser?.emailAddresses?.[0]?.emailAddress) {
return clerkUser.emailAddresses[0].emailAddress;
}
const clerkSession = localStorage.getItem('__clerk_session');
if (clerkSession) {
try {
const sessionData = JSON.parse(clerkSession);
if (sessionData?.user?.emailAddresses?.[0]?.emailAddress) {
return sessionData.user.emailAddresses[0].emailAddress;
}
} catch (e) {
// Ignore parsing errors
}
}
const userData = localStorage.getItem('user_data');
if (userData) {
try {
const data = JSON.parse(userData);
if (data.email) return data.email;
} catch (e) {
// Ignore parsing errors
}
}
const currentUserEmail = 'ajay.calsoft@gmail.com';
if (currentUserEmail && currentUserEmail.includes('@')) {
return currentUserEmail;
}
}
if (user) {
const primaryEmail = user.primaryEmailAddress?.emailAddress;
const firstEmail = user.emailAddresses?.[0]?.emailAddress;
const resolvedEmail = primaryEmail || firstEmail || '';
return 'user@example.com';
};
const userEmail = getUserEmail();
setEmail(userEmail);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
if (resolvedEmail) {
setEmail(resolvedEmail);
}
}
}, [user]);
const handlePlatformConnect = async (platformId: string) => {
console.log('🚀 INTEGRATIONS_STEP: handlePlatformConnect called with platformId:', platformId);
console.log('🚀 INTEGRATIONS_STEP: platformId type:', typeof platformId);
console.log('🚀 INTEGRATIONS_STEP: platformId length:', platformId.length);
console.log('🚀 INTEGRATIONS_STEP: platformId === "bing":', platformId === 'bing');
console.log('🚀 INTEGRATIONS_STEP: platformId === "gsc":', platformId === 'gsc');
console.log('🚀 INTEGRATIONS_STEP: connectBing function type:', typeof connectBing);
console.log('🚀 INTEGRATIONS_STEP: connectBing function:', connectBing);
console.log('🚀 INTEGRATIONS_STEP: Stack trace:', new Error().stack);
if (platformId === 'gsc') {
console.log('🚀 INTEGRATIONS_STEP: Handling GSC connection');
await handleGSCConnect();
} else if (platformId === 'bing') {
console.log('🚀 INTEGRATIONS_STEP: Handling Bing connection - about to call connectBing');
// Use the Bing OAuth hook for connection
try {
console.log('🚀 INTEGRATIONS_STEP: Calling connectBing()...');
await connectBing();
console.log('🚀 INTEGRATIONS_STEP: Bing connection initiated successfully');
} catch (error) {
console.error('🚀 INTEGRATIONS_STEP: Bing connection failed:', error);
console.error('Bing connection failed:', error);
}
} else {
console.log('🚀 INTEGRATIONS_STEP: Handling other platform connection:', platformId);
console.log('🚀 INTEGRATIONS_STEP: This should NOT happen for Bing!');
await handleConnect(platformId);
}
};
@@ -390,6 +326,59 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
const socialPlatforms = integrations.filter(p => p.category === 'social');
// Primary Site Selection State
const [primarySite, setPrimarySite] = useState<string>('');
// Get sites from hooks for the memo
const { sites: wixSites, connected: wixConnected } = useWixConnection();
const availableSites = React.useMemo(() => {
const sites: { url: string; source: string; name: string }[] = [];
if (wixConnected && wixSites.length > 0) {
sites.push(...wixSites.map(s => ({
url: s.blog_url,
source: 'Wix',
name: 'Wix Site'
})));
}
if (wordpressConnected && wordpressSites.length > 0) {
sites.push(...wordpressSites.map(s => ({
url: s.blog_url,
source: 'WordPress',
name: 'WordPress Site'
})));
}
return sites;
}, [wixConnected, wixSites, wordpressConnected, wordpressSites]);
// Default to first site
useEffect(() => {
if (availableSites.length > 0 && !primarySite) {
setPrimarySite(availableSites[0].url);
}
}, [availableSites, primarySite]);
// Save primary site when selected
useEffect(() => {
if (primarySite) {
localStorage.setItem('primary_website', primarySite);
}
}, [primarySite]);
// Validation Effect
useEffect(() => {
if (onValidationChange) {
// Valid if:
// 1. No sites available (user can proceed without site)
// 2. Sites available AND primarySite selected
const isValid = availableSites.length === 0 || !!primarySite;
onValidationChange(isValid);
}
}, [availableSites.length, primarySite, onValidationChange]);
return (
<Box sx={{ width: '100%', maxWidth: '100%', p: { xs: 1, sm: 2, md: 3 } }}>
{/* Email Address Section */}
@@ -404,7 +393,7 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
platforms={websitePlatforms}
connectedPlatforms={connectedPlatforms}
gscSites={null}
isLoading={isLoading}
isLoading={isLoading}
onConnect={handlePlatformConnect}
onDisconnect={(platformId) => {
setConnectedPlatforms(connectedPlatforms.filter(p => p !== platformId));
@@ -414,6 +403,118 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
</div>
</Fade>
{/* Primary Site Selection */}
<Fade in timeout={900}>
<Box sx={{ mt: 3 }}>
<Paper
elevation={2}
sx={{
p: 3,
borderRadius: 2,
background: 'linear-gradient(135deg, #f8fafc 0%, #ffffff 100%)',
border: '1px solid',
borderColor: primarySite ? '#86efac' : '#e2e8f0'
}}
>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2, justifyContent: 'space-between' }}>
<Box sx={{ display: 'flex', alignItems: 'center' }}>
<Box
sx={{
width: 40,
height: 40,
borderRadius: '50%',
bgcolor: primarySite ? '#dcfce7' : '#f1f5f9',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
mr: 2
}}
>
<LightbulbIcon sx={{ color: primarySite ? '#22c55e' : '#94a3b8' }} />
</Box>
<Box>
<Typography variant="h6" sx={{ fontWeight: 600, color: '#1e293b' }}>
Primary Website Selection
</Typography>
<Typography variant="body2" sx={{ color: '#64748b' }}>
Select your primary website for content publishing
</Typography>
</Box>
</Box>
{/* Green/Red Indicator */}
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
<Box
sx={{
width: 12,
height: 12,
borderRadius: '50%',
bgcolor: primarySite ? '#22c55e' : '#ef4444',
boxShadow: primarySite ? '0 0 0 4px #dcfce7' : '0 0 0 4px #fee2e2'
}}
/>
<Typography variant="caption" sx={{ fontWeight: 600, color: primarySite ? '#15803d' : '#b91c1c' }}>
{primarySite ? 'Primary Set' : 'Selection Required'}
</Typography>
</Box>
</Box>
{availableSites.length > 0 ? (
<FormControl component="fieldset" sx={{ width: '100%', mt: 1 }}>
<RadioGroup
value={primarySite}
onChange={(e) => setPrimarySite(e.target.value)}
>
{availableSites.map((site, index) => (
<Card
key={index}
variant="outlined"
sx={{
mb: 1.5,
borderColor: primarySite === site.url ? '#22c55e' : '#e2e8f0',
bgcolor: primarySite === site.url ? '#f0fdf4' : '#ffffff',
transition: 'all 0.2s',
'&:hover': { borderColor: '#22c55e' }
}}
>
<CardContent sx={{ p: '12px !important', '&:last-child': { pb: '12px !important' } }}>
<FormControlLabel
value={site.url}
control={<Radio size="small" sx={{ color: primarySite === site.url ? '#22c55e' : undefined, '&.Mui-checked': { color: '#22c55e' } }} />}
label={
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1.5 }}>
<Typography variant="body2" sx={{ fontWeight: 600, color: '#334155' }}>
{site.url ? site.url.replace(/^https?:\/\//, '') : 'No URL'}
</Typography>
<Chip
label={site.source}
size="small"
sx={{
height: 20,
fontSize: '0.65rem',
fontWeight: 600,
bgcolor: site.source === 'Wix' ? '#000000' : '#21759b',
color: '#ffffff'
}}
/>
</Box>
}
sx={{ width: '100%', m: 0 }}
/>
</CardContent>
</Card>
))}
</RadioGroup>
</FormControl>
) : (
<Alert severity="warning" sx={{ mt: 1, borderRadius: 2 }}>
No connected websites found. Please connect Wix or WordPress to continue.
</Alert>
)}
</Paper>
</Box>
</Fade>
{/* Analytics Platforms */}
<Fade in timeout={1000}>
<div>
@@ -453,16 +554,14 @@ const IntegrationsStep: React.FC<IntegrationsStepProps> = ({ onContinue, updateH
</Typography>
<PlatformAnalytics
platforms={connectedPlatforms}
platforms={connectedPlatforms.filter(p => ['gsc', 'bing'].includes(p))}
showSummary={true}
refreshInterval={0}
onDataLoaded={(data: any) => {
console.log('Analytics data loaded:', data);
refreshInterval={connectedPlatforms.some(p => ['gsc', 'bing'].includes(p)) ? 300000 : 0} // 5 minutes, only if connected
onDataLoaded={(data) => {
// Data loaded silently
}}
onRefreshReady={(refreshFn) => {
console.log('🔄 PlatformAnalytics refresh function ready');
// Store the refresh function for potential use
(window as any).refreshAnalytics = refreshFn;
// Store refresh function if needed
}}
/>
</Paper>

View File

@@ -26,10 +26,12 @@ import {
interface ComingSoonSectionProps {
contentCalendar?: any[];
onTestPersona?: () => void;
}
export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
contentCalendar = []
contentCalendar = [],
onTestPersona
}) => {
const [openModal, setOpenModal] = useState(false);
const [selectedFeature, setSelectedFeature] = useState<string | null>(null);
@@ -40,8 +42,8 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
title: 'Test Your Persona',
description: 'Generate content with different personas to see the difference',
icon: <PsychologyIcon />,
status: 'Coming Soon',
color: '#3b82f6',
status: 'Available',
color: '#10b981', // Green for available
details: [
'Compare content generated with and without your persona',
'Test Brand, Blog, and LinkedIn brand voices side-by-side',
@@ -90,15 +92,23 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
return (
<>
<Box sx={{ mt: 4, mb: 2 }}>
<Typography variant="h4" sx={{ fontWeight: 700, color: '#1e293b', mb: 1.5 }}>
🚀 Coming Soon
<Box sx={{ mt: 6, mb: 4 }}>
<Typography
variant="h6"
sx={{
mb: 3,
fontWeight: 700,
background: 'linear-gradient(45deg, #1e293b 30%, #334155 90%)',
WebkitBackgroundClip: 'text',
WebkitTextFillColor: 'transparent',
display: 'flex',
alignItems: 'center',
gap: 1
}}
>
🚀 Advanced Features & Roadmap
</Typography>
<Typography variant="body1" sx={{ color: '#64748b', mb: 4, fontSize: '1.1rem' }}>
Exciting features in development to make your AI writing even more powerful
</Typography>
<Grid container spacing={2}>
<Grid container spacing={3}>
{features.map((feature) => (
<Grid item xs={12} md={4} key={feature.id}>
<Card
@@ -118,7 +128,13 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
}
}
}}
onClick={() => handleFeatureClick(feature.id)}
onClick={() => {
if (feature.id === 'test-persona' && onTestPersona) {
onTestPersona();
} else {
handleFeatureClick(feature.id);
}
}}
>
<CardContent sx={{ p: 3 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
@@ -164,24 +180,25 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
</Typography>
<Button
variant="outlined"
variant={feature.id === 'test-persona' ? 'contained' : 'outlined'}
size="medium"
sx={{
borderColor: feature.color,
color: feature.color,
color: feature.id === 'test-persona' ? '#ffffff' : feature.color,
backgroundColor: feature.id === 'test-persona' ? feature.color : 'transparent',
fontWeight: 600,
px: 3,
py: 1,
borderRadius: 2,
textTransform: 'none',
'&:hover': {
backgroundColor: `${feature.color}15`,
backgroundColor: feature.id === 'test-persona' ? `${feature.color}cc` : `${feature.color}15`,
borderColor: feature.color,
transform: 'translateY(-1px)'
}
}}
>
Learn More
{feature.id === 'test-persona' ? 'Try Now' : 'Learn More'}
</Button>
</CardContent>
</Card>
@@ -318,7 +335,14 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
Close
</Button>
<Button
onClick={() => setOpenModal(false)}
onClick={() => {
if (selectedFeatureData?.id === 'test-persona' && onTestPersona) {
onTestPersona();
setOpenModal(false);
} else {
setOpenModal(false);
}
}}
variant="contained"
sx={{
backgroundColor: selectedFeatureData?.color || '#3b82f6',
@@ -328,7 +352,7 @@ export const ComingSoonSection: React.FC<ComingSoonSectionProps> = ({
}
}}
>
Notify Me When Ready
{selectedFeatureData?.id === 'test-persona' ? 'Try Now' : 'Notify Me When Ready'}
</Button>
</DialogActions>
</Dialog>

View File

@@ -16,11 +16,13 @@ import {
InfoOutlined,
Psychology as PsychologyIcon,
AutoAwesome as AutoAwesomeIcon,
Assessment as AssessmentIcon
Assessment as AssessmentIcon,
Lightbulb
} from '@mui/icons-material';
import {
getPersonalizationConfigurationOptions,
} from '../../api/componentLogic';
import { getLatestBrandAvatar, getLatestVoiceClone } from '../../api/brandAssets';
import { usePersonaPolling } from '../../hooks/usePersonaPolling';
import { apiClient } from '../../api/client';
import { type GenerationStep } from './PersonaStep/PersonaGenerationProgress';
@@ -31,11 +33,13 @@ import { PersonaLoadingState } from './PersonaStep/PersonaLoadingState';
import { ComingSoonSection } from './PersonaStep/ComingSoonSection';
import { BrandAvatarStudio } from './PersonalizationStep/components/BrandAvatarStudio';
import { VoiceAvatarPlaceholder } from './PersonalizationStep/components/VoiceAvatarPlaceholder';
import { TestPersonaModal } from './PersonalizationStep/components/TestPersonaModal';
interface PersonalizationStepProps {
onContinue: (data?: any) => void;
updateHeaderContent: (content: { title: string; description: string }) => void;
onValidationChange?: (isValid: boolean) => void;
onDataChange?: (data: any) => void;
onboardingData?: {
websiteAnalysis?: any;
competitorResearch?: any;
@@ -66,6 +70,7 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
onContinue,
updateHeaderContent,
onValidationChange,
onDataChange,
onboardingData = {},
stepData
}) => {
@@ -92,6 +97,123 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
const [hasCheckedCache, setHasCheckedCache] = useState(false);
const [configurationOptions, setConfigurationOptions] = useState<any>(null);
// Asset Status State
const [brandAvatarSet, setBrandAvatarSet] = useState(false);
const [voiceCloneSet, setVoiceCloneSet] = useState(false);
const [avatarUrl, setAvatarUrl] = useState<string>('');
const [voiceUrl, setVoiceUrl] = useState<string>('');
const [introVideoUrl, setIntroVideoUrl] = useState<string>('');
// Modal State
const [showTestPersonaModal, setShowTestPersonaModal] = useState(false);
const [hasTriggeredModal, setHasTriggeredModal] = useState(false);
const checkAssetStatus = useCallback(async () => {
try {
const avatarResp = await getLatestBrandAvatar();
let isAvatarSet = avatarResp.success;
let avatarDisplayUrl = '';
if (avatarResp.success) {
// Prefer base64 if available (immediate), else URL
avatarDisplayUrl = avatarResp.image_base64
? (avatarResp.image_base64.startsWith('data:') ? avatarResp.image_base64 : `data:image/png;base64,${avatarResp.image_base64}`)
: avatarResp.image_url || '';
} else {
// Fallback to local storage
try {
const localAvatar = localStorage.getItem('brand_avatar_selection');
if (localAvatar) {
const parsed = JSON.parse(localAvatar);
if (parsed.set) {
isAvatarSet = true;
// Try to recover image from Studio storage
const studioImage = localStorage.getItem('brand_avatar_result');
if (studioImage) {
avatarDisplayUrl = studioImage.startsWith('http') ? studioImage :
(studioImage.startsWith('data:') ? studioImage : `data:image/png;base64,${studioImage}`);
}
}
}
} catch (e) {}
}
setBrandAvatarSet(isAvatarSet);
if (avatarDisplayUrl) setAvatarUrl(avatarDisplayUrl);
const voiceResp = await getLatestVoiceClone();
let isVoiceSet = voiceResp.success;
let voiceDisplayUrl = '';
if (voiceResp.success && voiceResp.preview_audio_url) {
voiceDisplayUrl = voiceResp.preview_audio_url;
} else {
// Fallback to local storage
try {
const localVoice = localStorage.getItem('brand_voice_selection');
if (localVoice) {
const parsed = JSON.parse(localVoice);
if (parsed.set) {
isVoiceSet = true;
// Try to recover audio from Studio storage
const studioVoice = localStorage.getItem('voice_clone_result_url');
if (studioVoice) {
voiceDisplayUrl = studioVoice;
}
}
}
} catch (e) {}
}
setVoiceCloneSet(isVoiceSet);
if (voiceDisplayUrl) setVoiceUrl(voiceDisplayUrl);
} catch (e) {
console.error("Failed to check asset status", e);
}
}, []);
useEffect(() => {
checkAssetStatus();
}, [checkAssetStatus]);
// Sync data to parent Wizard
useEffect(() => {
if (onDataChange) {
const personaData = {
corePersona,
platformPersonas,
qualityMetrics,
selectedPlatforms,
brandAvatar: {
set: brandAvatarSet,
url: avatarUrl
},
voiceClone: {
set: voiceCloneSet,
url: voiceUrl
},
introVideo: {
set: !!introVideoUrl,
url: introVideoUrl
},
stepType: 'personalization',
completedAt: new Date().toISOString()
};
onDataChange(personaData);
}
}, [
corePersona,
platformPersonas,
qualityMetrics,
selectedPlatforms,
brandAvatarSet,
avatarUrl,
voiceCloneSet,
voiceUrl,
introVideoUrl,
onDataChange
]);
// Generation steps (Ported from PersonaStep)
const generationSteps: GenerationStep[] = [
{
@@ -264,22 +386,27 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
useEffect(() => {
if (initRef.current) return;
initRef.current = true;
initialize();
async function loadConfigurationOptions() {
const initSequence = async () => {
// Set initial header
updateHeaderContent({
title: 'Define Your Brand Persona',
description: 'Go beyond text. Define how your brand sounds, looks, and speaks. Configure your brand voice, generate an AI avatar, and prepare for voice cloning.'
});
// Load configuration options first (lightweight)
try {
const options = await getPersonalizationConfigurationOptions();
setConfigurationOptions(options.options);
} catch (e) {
console.error('Failed to load configuration options:', e);
}
}
loadConfigurationOptions();
updateHeaderContent({
title: 'Define Your Brand Persona',
description: 'Go beyond text. Define how your brand sounds, looks, and speaks. Configure your brand voice, generate an AI avatar, and prepare for voice cloning.'
});
// Then initialize persona generation (potentially heavy)
await initialize();
};
initSequence();
}, [updateHeaderContent, initialize]);
const handleRegenerate = () => {
@@ -292,6 +419,10 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
const handleContinue = useCallback(() => {
if (corePersona && platformPersonas && qualityMetrics) {
if (!brandAvatarSet || !voiceCloneSet) {
setError('Please generate and set your Brand Avatar and Voice Clone before continuing.');
return;
}
const personaData = {
corePersona,
platformPersonas,
@@ -304,15 +435,22 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
} else {
setError('Missing persona data. Please generate your brand voice first.');
}
}, [corePersona, platformPersonas, qualityMetrics, selectedPlatforms, onContinue]);
}, [corePersona, platformPersonas, qualityMetrics, selectedPlatforms, onContinue, brandAvatarSet, voiceCloneSet]);
useEffect(() => {
const hasValidData = !!(corePersona && platformPersonas && Object.keys(platformPersonas).length > 0 && qualityMetrics);
const isComplete = !isGenerating && hasValidData && generationStep === 'preview';
const isComplete = !isGenerating && hasValidData && generationStep === 'preview' && brandAvatarSet && voiceCloneSet;
if (onValidationChange) {
onValidationChange(isComplete);
}
}, [corePersona, platformPersonas, qualityMetrics, isGenerating, generationStep, onValidationChange]);
// Trigger Test Persona Modal when all requirements are met
if (isComplete && !hasTriggeredModal && !showTestPersonaModal) {
setHasTriggeredModal(true);
setShowTestPersonaModal(true);
}
}, [corePersona, platformPersonas, qualityMetrics, isGenerating, generationStep, onValidationChange, brandAvatarSet, voiceCloneSet, hasTriggeredModal, showTestPersonaModal]);
if (!configurationOptions) {
return (
@@ -394,7 +532,23 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
}
}}
>
{tab.label}
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
{tab.label}
<Lightbulb
sx={{
fontSize: 18,
color: (
(tab.id === 'text' && corePersona) ||
(tab.id === 'image' && brandAvatarSet) ||
(tab.id === 'audio' && voiceCloneSet)
)
? (activeTab === tab.id ? '#A7F3D0' : '#10B981') // Light green on active, Green on inactive
: (activeTab === tab.id ? '#FCA5A5' : '#EF4444'), // Light red on active, Red on inactive
filter: 'drop-shadow(0 0 2px currentColor)',
transition: 'color 0.3s ease'
}}
/>
</Box>
</Button>
</Tooltip>
))}
@@ -433,16 +587,28 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
handleRegenerate={handleRegenerate}
/>
<ComingSoonSection />
<ComingSoonSection onTestPersona={() => setShowTestPersonaModal(true)} />
</Box>
)}
{activeTab === 'image' && (
<BrandAvatarStudio domainName={domainName} />
<BrandAvatarStudio
domainName={domainName}
onAvatarSet={() => {
setBrandAvatarSet(true);
checkAssetStatus();
}}
/>
)}
{activeTab === 'audio' && (
<VoiceAvatarPlaceholder domainName={domainName} />
<VoiceAvatarPlaceholder
domainName={domainName}
onVoiceSet={() => {
setVoiceCloneSet(true);
checkAssetStatus();
}}
/>
)}
</Box>
@@ -453,7 +619,7 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
<InfoOutlined color="action" fontSize="small" />
<Typography variant="caption" color="text.secondary">
Changes to Brand Identity are required to continue. Avatar and Voice are optional.
All steps (Identity, Avatar, and Voice) are required to complete your brand personalization.
</Typography>
</Box>
@@ -461,24 +627,20 @@ const PersonalizationStep: React.FC<PersonalizationStepProps> = ({
{error && <Alert severity="error" sx={{ mb: 2 }}>{error}</Alert>}
{success && <Alert severity="success" sx={{ mb: 2 }}>{success}</Alert>}
<Button
variant="contained"
color="primary"
onClick={handleContinue}
disabled={loading}
sx={{
px: 6,
py: 1.5,
borderRadius: 2,
fontWeight: 'bold',
boxShadow: '0 4px 14px 0 rgba(0,118,255,0.39)'
}}
>
{loading ? 'Saving Settings...' : 'Save & Continue'}
</Button>
{/* 'Save & Continue' button removed as per requirements.
Navigation is now handled by the main Wizard button (2). */}
</Box>
</Box>
)}
{/* Test Persona Modal */}
<TestPersonaModal
open={showTestPersonaModal}
onClose={() => setShowTestPersonaModal(false)}
avatarUrl={avatarUrl}
voiceUrl={voiceUrl}
onVideoGenerated={(url) => setIntroVideoUrl(url || '')}
/>
</Box>
);
};

Some files were not shown because too many files have changed in this diff Show More