Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions

704
backend/api/images.py Normal file
View File

@@ -0,0 +1,704 @@
from __future__ import annotations
import base64
import os
import uuid
from typing import Optional, Dict, Any
from datetime import datetime
from pathlib import Path
from sqlalchemy.orm import Session
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_image_editing import edit_image
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import APIProvider, UsageSummary
from utils.asset_tracker import save_asset_to_library
from utils.file_storage import save_file_safely, generate_unique_filename, sanitize_filename
router = APIRouter(prefix="/api/images", tags=["images"])
logger = get_service_logger("api.images")
class ImageGenerateRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
provider: Optional[str] = Field(None, pattern="^(gemini|huggingface|stability)$")
model: Optional[str] = None
width: Optional[int] = Field(default=1024, ge=64, le=2048)
height: Optional[int] = Field(default=1024, ge=64, le=2048)
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
class ImageGenerateResponse(BaseModel):
success: bool = True
image_base64: str
image_url: Optional[str] = None # URL to saved image file
width: int
height: int
provider: str
model: Optional[str] = None
seed: Optional[int] = None
@router.post("/generate", response_model=ImageGenerateResponse)
def generate(
req: ImageGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> ImageGenerateResponse:
"""Generate image with subscription checking."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validation is now handled inside generate_image function
last_error: Optional[Exception] = None
result = None
for attempt in range(2): # simple single retry
try:
result = generate_image(
prompt=req.prompt,
options={
"negative_prompt": req.negative_prompt,
"provider": req.provider,
"model": req.model,
"width": req.width,
"height": req.height,
"guidance_scale": req.guidance_scale,
"steps": req.steps,
"seed": req.seed,
},
user_id=user_id, # Pass user_id for validation inside generate_image
)
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# Save image to disk and track in asset library
image_url = None
image_filename = None
image_path = None
try:
# Create output directory for image studio images
base_dir = Path(__file__).parent.parent
output_dir = base_dir / "image_studio_images"
# Generate safe filename from prompt
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
image_filename = generate_unique_filename(
prefix=f"img_{clean_prompt}",
extension=".png",
include_uuid=True
)
# Save file safely
image_path, save_error = save_file_safely(
content=result.image_bytes,
directory=output_dir,
filename=image_filename,
max_file_size=50 * 1024 * 1024 # 50MB for images
)
if image_path and not save_error:
# Generate file URL (will be served via API endpoint)
image_url = f"/api/images/image-studio/images/{image_path.name}"
logger.info(f"[images.generate] Saved image to: {image_path} ({len(result.image_bytes)} bytes)")
# Save to asset library (non-blocking)
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="image_studio",
filename=image_path.name,
file_url=image_url,
file_path=str(image_path),
file_size=len(result.image_bytes),
mime_type="image/png",
title=req.prompt[:100] if len(req.prompt) <= 100 else req.prompt[:97] + "...",
description=f"Generated image: {req.prompt[:200]}" if len(req.prompt) > 200 else req.prompt,
prompt=req.prompt,
tags=["image_studio", "generated", result.provider] if result.provider else ["image_studio", "generated"],
provider=result.provider,
model=result.model,
asset_metadata={
"width": result.width,
"height": result.height,
"seed": result.seed,
"status": "completed",
"negative_prompt": req.negative_prompt
}
)
if asset_id:
logger.info(f"[images.generate] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
else:
logger.warning(f"[images.generate] Asset tracking returned None (may have failed silently)")
except Exception as asset_error:
logger.error(f"[images.generate] Failed to save asset to library: {asset_error}", exc_info=True)
# Don't fail the request if asset tracking fails
else:
logger.warning(f"[images.generate] Failed to save image to disk: {save_error}")
# Continue without failing the request - base64 is still available
except Exception as save_error:
logger.error(f"[images.generate] Unexpected error saving image: {save_error}", exc_info=True)
# Continue without failing the request
# TRACK USAGE after successful image generation
if result:
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "stability_calls", 0) or 0
# Update provider-specific counters (stability for image generation)
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
new_calls = current_calls_before + 1
setattr(summary, "stability_calls", new_calls)
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: stability
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageGenerateResponse(
image_base64=image_b64,
image_url=image_url,
width=result.width,
height=result.height,
provider=result.provider,
model=result.model,
seed=result.seed,
)
except Exception as inner:
last_error = inner
logger.error(f"Image generation attempt {attempt+1} failed: {inner}")
# On first failure, try provider auto-remap by clearing provider to let facade decide
if attempt == 0 and req.provider:
req.provider = None
continue
break
raise last_error or RuntimeError("Unknown image generation error")
except Exception as e:
logger.error(f"Image generation failed: {e}")
# Provide a clean, actionable message to the client
raise HTTPException(
status_code=500,
detail="Image generation service is temporarily unavailable or the connection was reset. Please try again."
)
class PromptSuggestion(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
overlay_text: Optional[str] = None
class ImagePromptSuggestRequest(BaseModel):
provider: Optional[str] = Field(None, pattern="^(gemini|huggingface|stability)$")
title: Optional[str] = None
section: Optional[Dict[str, Any]] = None
research: Optional[Dict[str, Any]] = None
persona: Optional[Dict[str, Any]] = None
include_overlay: Optional[bool] = True
class ImagePromptSuggestResponse(BaseModel):
suggestions: list[PromptSuggestion]
class ImageEditRequest(BaseModel):
image_base64: str
prompt: str
provider: Optional[str] = Field(None, pattern="^(huggingface)$")
model: Optional[str] = None
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
class ImageEditResponse(BaseModel):
success: bool = True
image_base64: str
image_url: Optional[str] = None # URL to saved edited image file
width: int
height: int
provider: str
model: Optional[str] = None
seed: Optional[int] = None
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(
req: ImagePromptSuggestRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImagePromptSuggestResponse:
try:
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
section = req.section or {}
title = (req.title or section.get("heading") or "").strip()
subheads = section.get("subheadings", []) or []
key_points = section.get("key_points", []) or []
keywords = section.get("keywords", []) or []
if not keywords and req.research:
keywords = (
req.research.get("keywords", {}).get("primary_keywords")
or req.research.get("keywords", {}).get("primary")
or []
)
persona = req.persona or {}
audience = persona.get("audience", "content creators and digital marketers")
industry = persona.get("industry", req.research.get("domain") if req.research else "your industry")
tone = persona.get("tone", "professional, trustworthy")
schema = {
"type": "object",
"properties": {
"suggestions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"prompt": {"type": "string"},
"negative_prompt": {"type": "string"},
"width": {"type": "number"},
"height": {"type": "number"},
"overlay_text": {"type": "string"},
},
"required": ["prompt"]
},
"minItems": 3,
"maxItems": 5
}
},
"required": ["suggestions"]
}
system = (
"You are an expert image prompt engineer for text-to-image models. "
"Given blog section context, craft 3-5 hyper-personalized prompts optimized for the specified provider. "
"Return STRICT JSON matching the provided schema, no extra text."
)
provider_guidance = {
"huggingface": "Photorealistic Flux 1 Krea Dev; include camera/lighting cues (e.g., 50mm, f/2.8, rim light).",
"gemini": "Editorial, brand-safe, crisp edges, balanced lighting; avoid artifacts.",
"stability": "SDXL coherent details, sharp focus, cinematic contrast; readable text if present."
}.get(provider, "")
best_practices = (
"Best Practices: one clear focal subject; clean, uncluttered background; rule-of-thirds or center-weighted composition; "
"text-safe margins if overlay text is included; neutral lighting if unsure; realistic skin tones; avoid busy patterns; "
"no brand logos or watermarks; no copyrighted characters; avoid low-res, blur, noise, banding, oversaturation, over-sharpening; "
"ensure hands and text are coherent if present; prefer 1024px+ on shortest side for quality."
)
# Harvest a few concise facts from research if available
facts: list[str] = []
try:
if req.research:
# try common shapes used in research service
top_stats = req.research.get("key_facts") or req.research.get("highlights") or []
if isinstance(top_stats, list):
facts = [str(x) for x in top_stats[:3]]
elif isinstance(top_stats, dict):
facts = [f"{k}: {v}" for k, v in list(top_stats.items())[:3]]
except Exception:
facts = []
facts_line = ", ".join(facts) if facts else ""
overlay_hint = "Include an on-image short title or fact if it improves communication; ensure clean, high-contrast safe area for text." if (req.include_overlay is None or req.include_overlay) else "Do not include on-image text."
prompt = f"""
Provider: {provider}
Title: {title}
Subheadings: {', '.join(subheads[:5])}
Key Points: {', '.join(key_points[:5])}
Keywords: {', '.join([str(k) for k in keywords[:8]])}
Research Facts: {facts_line}
Audience: {audience}
Industry: {industry}
Tone: {tone}
Craft prompts that visually reflect this exact section (not generic blog topic). {provider_guidance}
{best_practices}
{overlay_hint}
Include a suitable negative_prompt where helpful. Suggest width/height when relevant (e.g., 1024x1024 or 1920x1080).
If including on-image text, return it in overlay_text (short: <= 8 words).
"""
# Get user_id for llm_text_gen subscription check (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id_for_llm = str(current_user.get('id', ''))
if not user_id_for_llm:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
data = raw if isinstance(raw, dict) else {}
suggestions = data.get("suggestions") or []
# basic fallback if provider returns string
if not suggestions and isinstance(raw, str):
suggestions = [{"prompt": raw}]
return ImagePromptSuggestResponse(suggestions=[PromptSuggestion(**s) for s in suggestions])
except Exception as e:
logger.error(f"Prompt suggestion failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/edit", response_model=ImageEditResponse)
def edit(
req: ImageEditRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> ImageEditResponse:
"""Edit image with subscription checking."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Decode base64 image
try:
input_image_bytes = base64.b64decode(req.image_base64)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image_base64: {str(e)}")
# Validation is now handled inside edit_image function
result = edit_image(
input_image_bytes=input_image_bytes,
prompt=req.prompt,
options={
"provider": req.provider,
"model": req.model,
"guidance_scale": req.guidance_scale,
"steps": req.steps,
"seed": req.seed,
},
user_id=user_id, # Pass user_id for validation inside edit_image
)
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# Save edited image to disk and track in asset library
image_url = None
image_filename = None
image_path = None
try:
# Create output directory for image studio edited images
base_dir = Path(__file__).parent.parent
output_dir = base_dir / "image_studio_images" / "edited"
# Generate safe filename from prompt
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
image_filename = generate_unique_filename(
prefix=f"edited_{clean_prompt}",
extension=".png",
include_uuid=True
)
# Save file safely
image_path, save_error = save_file_safely(
content=result.image_bytes,
directory=output_dir,
filename=image_filename,
max_file_size=50 * 1024 * 1024 # 50MB for images
)
if image_path and not save_error:
# Generate file URL
image_url = f"/api/images/image-studio/images/edited/{image_path.name}"
logger.info(f"[images.edit] Saved edited image to: {image_path} ({len(result.image_bytes)} bytes)")
# Save to asset library (non-blocking)
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="image_studio",
filename=image_path.name,
file_url=image_url,
file_path=str(image_path),
file_size=len(result.image_bytes),
mime_type="image/png",
title=f"Edited: {req.prompt[:100]}" if len(req.prompt) <= 100 else f"Edited: {req.prompt[:97]}...",
description=f"Edited image with prompt: {req.prompt[:200]}" if len(req.prompt) > 200 else f"Edited image with prompt: {req.prompt}",
prompt=req.prompt,
tags=["image_studio", "edited", result.provider] if result.provider else ["image_studio", "edited"],
provider=result.provider,
model=result.model,
asset_metadata={
"width": result.width,
"height": result.height,
"seed": result.seed,
"status": "completed",
"operation": "edit"
}
)
if asset_id:
logger.info(f"[images.edit] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
else:
logger.warning(f"[images.edit] Asset tracking returned None (may have failed silently)")
except Exception as asset_error:
logger.error(f"[images.edit] Failed to save asset to library: {asset_error}", exc_info=True)
# Don't fail the request if asset tracking fails
else:
logger.warning(f"[images.edit] Failed to save edited image to disk: {save_error}")
# Continue without failing the request - base64 is still available
except Exception as save_error:
logger.error(f"[images.edit] Unexpected error saving edited image: {save_error}", exc_info=True)
# Continue without failing the request
# TRACK USAGE after successful image editing
if result:
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.edit] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.edit] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "image_edit_calls", 0) or 0
# Update image editing counters (separate from image generation)
new_calls = current_calls_before + 1
setattr(summary, "image_edit_calls", new_calls)
logger.debug(f"[images.edit] Updated image_edit_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.edit] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get image generation stats for unified log
current_image_gen_calls = getattr(summary, "stability_calls", 0) or 0
image_gen_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
db_track.commit()
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Editing
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: image_edit
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.edit] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.edit] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageEditResponse(
image_base64=edited_image_b64,
image_url=image_url,
width=result.width,
height=result.height,
provider=result.provider,
model=result.model,
seed=result.seed,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Image editing failed: {e}", exc_info=True)
# Provide a clean, actionable message to the client
raise HTTPException(
status_code=500,
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
)
# ---------------------------
# Image Serving Endpoints
# ---------------------------
@router.get("/image-studio/images/{image_filename:path}")
async def serve_image_studio_image(
image_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Serve a generated or edited image from Image Studio."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Determine if it's an edited image or regular image
base_dir = Path(__file__).parent.parent
image_studio_dir = (base_dir / "image_studio_images").resolve()
if image_filename.startswith("edited/"):
# Remove "edited/" prefix and serve from edited directory
actual_filename = image_filename.replace("edited/", "", 1)
image_path = (image_studio_dir / "edited" / actual_filename).resolve()
base_subdir = (image_studio_dir / "edited").resolve()
else:
image_path = (image_studio_dir / image_filename).resolve()
base_subdir = image_studio_dir
# Security: Prevent directory traversal attacks
# Ensure the resolved path is within the intended directory
try:
image_path.relative_to(base_subdir)
except ValueError:
raise HTTPException(
status_code=403,
detail="Access denied: Invalid image path"
)
if not image_path.exists():
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=image_path.name
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[images] Failed to serve image: {e}")
raise HTTPException(status_code=500, detail=str(e))