# Files
# ALwrity/backend/api/images.py
#
# 319 lines
# 15 KiB
# Python

from __future__ import annotations
import base64
import os
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import APIProvider, UsageSummary
# Router shared by every endpoint in this module: paths are mounted under
# /api/images and grouped under the "images" OpenAPI tag.
router = APIRouter(tags=["images"], prefix="/api/images")
# Service-scoped logger for this module's endpoints.
logger = get_service_logger("api.images")
class ImageGenerateRequest(BaseModel):
    # Request body for POST /api/images/generate.

    # Text description of the image to produce.
    prompt: str
    # Optional description of content the image should avoid.
    negative_prompt: Optional[str] = Field(default=None)
    # Optional backend choice; only the three listed names pass validation.
    # None leaves the choice to generate_image.
    provider: Optional[str] = Field(default=None, pattern="^(gemini|huggingface|stability)$")
    # Optional provider-specific model identifier.
    model: Optional[str] = Field(default=None)
    # Output size in pixels; each side defaults to 1024 and must lie in [64, 2048].
    width: Optional[int] = Field(1024, ge=64, le=2048)
    height: Optional[int] = Field(1024, ge=64, le=2048)
    # Optional sampling knobs forwarded unchanged to the generator.
    guidance_scale: Optional[float] = Field(default=None)
    steps: Optional[int] = Field(default=None)
    seed: Optional[int] = Field(default=None)
class ImageGenerateResponse(BaseModel):
    # Response body for POST /api/images/generate.

    # Defaults to True; the model carries no error details of its own.
    success: bool = Field(default=True)
    # Generated image bytes, base64-encoded as text.
    image_base64: str
    # Dimensions of the produced image, in pixels.
    width: int
    height: int
    # Backend that actually produced the image.
    provider: str
    # Model and seed reported back by the generator, when available.
    model: Optional[str] = Field(default=None)
    seed: Optional[int] = Field(default=None)
def _generate_image_with_retry(req: ImageGenerateRequest, user_id: str) -> tuple[Any, str]:
    """Call ``generate_image`` with one provider-fallback retry.

    The first attempt honours ``req.provider``. If it fails and a provider was
    pinned, a second attempt is made with the provider cleared so the
    generation facade can choose one itself.

    Args:
        req: Validated generation request (not mutated).
        user_id: Authenticated user id, forwarded for validation inside
            ``generate_image``.

    Returns:
        ``(result, image_b64)``: the object returned by ``generate_image`` and
        its ``image_bytes`` base64-encoded as UTF-8 text.

    Raises:
        HTTPException: re-raised immediately, without retrying.
        Exception: the last generation error once all attempts have failed.
    """
    # Use a local so the validated request object is left untouched.
    provider = req.provider
    last_error: Optional[Exception] = None
    for attempt in range(2):  # initial attempt + one retry
        try:
            result = generate_image(
                prompt=req.prompt,
                options={
                    "negative_prompt": req.negative_prompt,
                    "provider": provider,
                    "model": req.model,
                    "width": req.width,
                    "height": req.height,
                    "guidance_scale": req.guidance_scale,
                    "steps": req.steps,
                    "seed": req.seed,
                },
                user_id=user_id,  # Pass user_id for validation inside generate_image
            )
            # Encode inside the retry scope: a result without usable bytes
            # counts as a failed attempt, as before.
            image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
            return result, image_b64
        except HTTPException:
            # An explicit HTTP error is deliberate (presumably a subscription or
            # validation rejection from generate_image -- confirm). Retrying with
            # another provider cannot fix it, and the client must see its status.
            raise
        except Exception as inner:
            last_error = inner
            logger.error(f"Image generation attempt {attempt+1} failed: {inner}")
            # On first failure, try provider auto-remap by clearing provider to let facade decide
            if attempt == 0 and provider:
                provider = None
                continue
            break
    raise last_error or RuntimeError("Unknown image generation error")


def _track_image_usage(user_id: str, result: Any) -> None:
    """Best-effort: count one image generation against the user's usage summary.

    Increments ``stability_calls`` and ``total_calls`` on the ``UsageSummary``
    row for the current billing period, creating the row if it is missing.
    Then it commits and prints a one-shot subscription summary. Every failure is
    logged and swallowed, so tracking can never fail the image request.

    Args:
        user_id: Authenticated user id.
        result: Object returned by ``generate_image``; only ``provider`` and
            ``model`` are read, for the summary print.
    """
    logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
    try:
        db_track = next(get_db())
        try:
            # Get or create usage summary
            pricing = PricingService(db_track)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
            logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
            summary = db_track.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()
            if not summary:
                logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
                summary = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                db_track.add(summary)
                db_track.flush()  # Ensure summary is persisted before updating

            # Note: all image generation is counted under the STABILITY
            # counter, regardless of the provider that actually served it.
            current_calls_before = getattr(summary, "stability_calls", 0) or 0
            new_calls = current_calls_before + 1
            setattr(summary, "stability_calls", new_calls)
            logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")

            old_total_calls = summary.total_calls or 0
            summary.total_calls = old_total_calls + 1
            logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")

            db_track.commit()
            logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")

            # Plan details only feed the summary print below. They are read after
            # the commit, so a lookup failure (or a plan without a 'limits' map)
            # can no longer roll back the usage update.
            plan_name, tier, call_limit = 'unknown', 'unknown', 0
            try:
                limits = pricing.get_user_limits(user_id)
                if limits:
                    plan_name = limits.get('plan_name', 'unknown')
                    tier = limits.get('tier', 'unknown')
                    call_limit = (limits.get('limits') or {}).get("stability_calls", 0) or 0
            except Exception as limits_error:
                logger.warning(f"[images.generate] Could not load plan details for usage log: {limits_error}")

            # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
            print(
                "\n"
                "[SUBSCRIPTION] Image Generation\n"
                f"├─ User: {user_id}\n"
                f"├─ Plan: {plan_name} ({tier})\n"
                "├─ Provider: stability\n"
                f"├─ Actual Provider: {result.provider}\n"
                f"├─ Model: {result.model or 'default'}\n"
                f"├─ Calls: {current_calls_before} -> {new_calls} / {call_limit if call_limit > 0 else ''}\n"
                "└─ Status: ✅ Allowed & Tracked\n"
            )
        except Exception as track_error:
            logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
            db_track.rollback()
        finally:
            db_track.close()
    except Exception as usage_error:
        # Non-blocking: log error but don't fail the request
        logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)


@router.post("/generate", response_model=ImageGenerateResponse)
def generate(
    req: ImageGenerateRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageGenerateResponse:
    """Generate an image for the authenticated user, with subscription checking.

    Generation is delegated to ``generate_image`` with one provider-fallback
    retry. Usage is tracked exactly once per successful generation, outside
    the retry loop; tracking problems are logged and never fail the request.

    Args:
        req: Validated generation parameters.
        current_user: Authenticated user injected by ``get_current_user``;
            its ``'id'`` entry is used as the user id.

    Returns:
        ImageGenerateResponse with the image base64-encoded.

    Raises:
        HTTPException: 401 when the user or its id is missing. Any
            ``HTTPException`` raised during generation is propagated unchanged.
            Every other failure becomes a 500 with a generic, client-safe message.
    """
    # Authentication is checked before the catch-all below so it keeps its 401 status.
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")
    user_id = str(current_user.get('id', ''))
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    try:
        result, image_b64 = _generate_image_with_retry(req, user_id)
        # Best-effort and self-contained: never raises into this handler.
        _track_image_usage(user_id, result)
        return ImageGenerateResponse(
            image_base64=image_b64,
            width=result.width,
            height=result.height,
            provider=result.provider,
            model=result.model,
            seed=result.seed,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Image generation failed: {e}")
        # Provide a clean, actionable message to the client
        raise HTTPException(
            status_code=500,
            detail="Image generation service is temporarily unavailable or the connection was reset. Please try again."
        ) from e
class PromptSuggestion(BaseModel):
    # One suggested text-to-image prompt returned by /api/images/suggest-prompts.

    # The suggested positive prompt (the only required field).
    prompt: str
    # Optional things the image should avoid.
    negative_prompt: Optional[str] = Field(default=None)
    # Optional suggested output size in pixels.
    width: Optional[int] = Field(default=None)
    height: Optional[int] = Field(default=None)
    # Optional short text to render on the image.
    overlay_text: Optional[str] = Field(default=None)
class ImagePromptSuggestRequest(BaseModel):
    # Request body for POST /api/images/suggest-prompts; every field is optional.

    # Target backend; only the three listed names pass validation.
    provider: Optional[str] = Field(default=None, pattern="^(gemini|huggingface|stability)$")
    title: Optional[str] = Field(default=None)
    # Free-form blog-section, research and persona context dictionaries.
    section: Optional[Dict[str, Any]] = Field(default=None)
    research: Optional[Dict[str, Any]] = Field(default=None)
    persona: Optional[Dict[str, Any]] = Field(default=None)
    # Whether suggestions may include on-image text; defaults to True.
    include_overlay: Optional[bool] = Field(default=True)
class ImagePromptSuggestResponse(BaseModel):
    # Response body for POST /api/images/suggest-prompts.

    # Required list of prompt suggestions.
    suggestions: list[PromptSuggestion] = Field(...)
def _prompt_text_items(value: Any, limit: int) -> list[str]:
    """Return up to ``limit`` entries of a list/tuple ``value`` as strings.

    A bare non-empty string counts as a single entry, instead of being split
    into characters. Any other value yields an empty list.
    """
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value[:limit]]
    if isinstance(value, str) and value:
        return [value]
    return []


def _prompt_research_keywords(research: Dict[str, Any]) -> Any:
    """Return ``research['keywords']['primary_keywords']`` (or ``['primary']``).

    Returns [] when ``research['keywords']`` is missing or not a dict.
    """
    keyword_block = research.get("keywords")
    if not isinstance(keyword_block, dict):
        return []
    return keyword_block.get("primary_keywords") or keyword_block.get("primary") or []


def _prompt_research_facts(research: Optional[Dict[str, Any]], limit: int = 3) -> list[str]:
    """Harvest up to ``limit`` concise facts from research ``key_facts``/``highlights``.

    A list yields its first items as strings; a dict yields ``"key: value"``
    pairs. Anything else yields no facts.
    """
    if not research:
        return []
    # try common shapes used in research service
    top_stats = research.get("key_facts") or research.get("highlights") or []
    if isinstance(top_stats, list):
        return [str(x) for x in top_stats[:limit]]
    if isinstance(top_stats, dict):
        return [f"{k}: {v}" for k, v in list(top_stats.items())[:limit]]
    return []


def _prompt_parse_suggestions(raw: Any) -> list[PromptSuggestion]:
    """Convert ``llm_text_gen`` output into ``PromptSuggestion`` models.

    Accepts a dict with a ``suggestions`` list (whose items may be dicts or bare
    prompt strings). A plain string response becomes a single suggestion. Items
    of any other type are skipped with a warning. Dict items that fail model
    validation still raise, as before.
    """
    data = raw if isinstance(raw, dict) else {}
    items = data.get("suggestions") or []
    # basic fallback if provider returns string
    if not items and isinstance(raw, str):
        items = [{"prompt": raw}]
    if isinstance(items, dict):
        items = [items]
    if not isinstance(items, list):
        logger.warning(f"Unexpected 'suggestions' type from LLM: {type(items).__name__}")
        items = []
    parsed: list[PromptSuggestion] = []
    for item in items:
        if isinstance(item, str):
            item = {"prompt": item}
        if not isinstance(item, dict):
            logger.warning(f"Skipping malformed prompt suggestion of type {type(item).__name__}")
            continue
        parsed.append(PromptSuggestion(**item))
    return parsed


@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(
    req: ImagePromptSuggestRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImagePromptSuggestResponse:
    """Ask the text LLM for 3-5 image prompts tailored to a blog section.

    The section, research and persona context from the request are folded into
    a single prompt. It is sent to ``llm_text_gen`` together with a JSON schema
    and the caller's user id, which ``llm_text_gen`` uses for its subscription
    check.

    Args:
        req: Section/research/persona context and the target provider.
        current_user: Authenticated user injected by ``get_current_user``;
            its ``'id'`` entry is forwarded to ``llm_text_gen``.

    Returns:
        ImagePromptSuggestResponse with the parsed suggestions.

    Raises:
        HTTPException: 401 when the user or its id is missing. Any
            ``HTTPException`` from the LLM layer is propagated unchanged. Any
            other failure becomes a 500 whose detail is the error text.
    """
    # Get user_id for llm_text_gen subscription check (required). This runs
    # before the catch-all below so authentication errors keep their 401 status.
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")
    user_id_for_llm = str(current_user.get('id', ''))
    if not user_id_for_llm:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    try:
        provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
        section = req.section or {}
        research = req.research or {}
        title = str(req.title or section.get("heading") or "").strip()
        subheads = _prompt_text_items(section.get("subheadings"), 5)
        key_points = _prompt_text_items(section.get("key_points"), 5)
        keywords = _prompt_text_items(section.get("keywords"), 8)
        if not keywords and research:
            keywords = _prompt_text_items(_prompt_research_keywords(research), 8)

        persona = req.persona or {}
        # `or` (not a .get default) so explicit None/empty values also fall back
        # instead of rendering as "None" in the prompt.
        audience = persona.get("audience") or "content creators and digital marketers"
        industry = persona.get("industry") or research.get("domain") or "your industry"
        tone = persona.get("tone") or "professional, trustworthy"

        schema = {
            "type": "object",
            "properties": {
                "suggestions": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "prompt": {"type": "string"},
                            "negative_prompt": {"type": "string"},
                            "width": {"type": "number"},
                            "height": {"type": "number"},
                            "overlay_text": {"type": "string"},
                        },
                        "required": ["prompt"]
                    },
                    "minItems": 3,
                    "maxItems": 5
                }
            },
            "required": ["suggestions"]
        }
        system = (
            "You are an expert image prompt engineer for text-to-image models. "
            "Given blog section context, craft 3-5 hyper-personalized prompts optimized for the specified provider. "
            "Return STRICT JSON matching the provided schema, no extra text."
        )
        provider_guidance = {
            "huggingface": "Photorealistic Flux 1 Krea Dev; include camera/lighting cues (e.g., 50mm, f/2.8, rim light).",
            "gemini": "Editorial, brand-safe, crisp edges, balanced lighting; avoid artifacts.",
            "stability": "SDXL coherent details, sharp focus, cinematic contrast; readable text if present."
        }.get(provider, "")
        best_practices = (
            "Best Practices: one clear focal subject; clean, uncluttered background; rule-of-thirds or center-weighted composition; "
            "text-safe margins if overlay text is included; neutral lighting if unsure; realistic skin tones; avoid busy patterns; "
            "no brand logos or watermarks; no copyrighted characters; avoid low-res, blur, noise, banding, oversaturation, over-sharpening; "
            "ensure hands and text are coherent if present; prefer 1024px+ on shortest side for quality."
        )
        facts_line = ", ".join(_prompt_research_facts(research))
        # include_overlay defaults to True; only an explicit False disables on-image text.
        overlay_hint = (
            "Include an on-image short title or fact if it improves communication; ensure clean, high-contrast safe area for text."
            if req.include_overlay is not False
            else "Do not include on-image text."
        )
        prompt = "\n".join([
            "",
            f"Provider: {provider}",
            f"Title: {title}",
            f"Subheadings: {', '.join(subheads)}",
            f"Key Points: {', '.join(key_points)}",
            f"Keywords: {', '.join(keywords)}",
            f"Research Facts: {facts_line}",
            f"Audience: {audience}",
            f"Industry: {industry}",
            f"Tone: {tone}",
            f"Craft prompts that visually reflect this exact section (not generic blog topic). {provider_guidance}",
            best_practices,
            overlay_hint,
            "Include a suitable negative_prompt where helpful. Suggest width/height when relevant (e.g., 1024x1024 or 1920x1080).",
            "If including on-image text, return it in overlay_text (short: <= 8 words).",
            "",
        ])

        raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
        return ImagePromptSuggestResponse(suggestions=_prompt_parse_suggestions(raw))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Prompt suggestion failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e