# Files
# ALwrity/backend/services/image_studio/upscale_service.py
# 2025-11-20 09:06:00 +05:30
#
# 155 lines
# 5.6 KiB
# Python

import base64
import io
from dataclasses import dataclass
from typing import Literal, Optional, Dict, Any
from fastapi import HTTPException
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Accepted values for an upscale request's `mode`; "auto" defers the choice
# to the service's size heuristic.
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]

# Shared logger for every upscale-studio message emitted by this module.
logger = get_service_logger("image_studio.upscale")
@dataclass
class UpscaleStudioRequest:
    """Parameters for one upscale job processed by ``UpscaleStudioService``."""

    # Source image: raw base64, or a "data:<mime>;base64,<payload>" URL.
    image_base64: str
    # Requested strategy; "auto" lets the service pick "fast" or "conservative"
    # based on the requested target size.
    mode: UpscaleMode = "auto"
    # Optional target size in pixels; forwarded to the provider only when set.
    target_width: Optional[int] = None
    target_height: Optional[int] = None
    # Free-form preset label (e.g. web/print/social). In this service it is only
    # echoed back in the response metadata.
    preset: Optional[str] = None
    # Guidance text for conservative/creative modes; a built-in default prompt
    # is substituted when this is empty.
    prompt: Optional[str] = None
class UpscaleStudioService:
    """Handles image upscaling workflows.

    Decodes the caller's base64 image, optionally runs subscription pre-flight
    validation, delegates the upscale to ``StabilityAIService`` and returns the
    result as a base64 ``data:`` URL together with its pixel dimensions.
    """

    def __init__(self):
        # Stateless: construction only logs.
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale ``request.image_base64`` with the resolved upscale mode.

        Args:
            request: Upscale parameters (image, mode, optional target size,
                preset and prompt).
            user_id: When truthy, subscription pre-flight validation runs for
                this user before the image is decoded or the provider is called.

        Returns:
            A dict with ``success`` (always True), the resolved ``mode``, the
            upscaled image as a base64 ``data:`` URL (``image_base64``), its
            ``width`` and ``height``, and ``metadata`` echoing the request's
            preset, target size and caller-supplied prompt.

        Raises:
            ValueError: If the image is missing or not decodable base64, or the
                resolved mode is not fast/conservative/creative.
            HTTPException: 502 if the provider response holds no usable image.
                Exceptions from the pre-flight validator or the provider client
                propagate unchanged.
        """
        if user_id:
            self._run_preflight_validation(user_id)

        source_bytes = self._decode_base64(request.image_base64)
        if not source_bytes:
            raise ValueError("Primary image is required for upscaling")

        mode = self._resolve_mode(request)

        async with StabilityAIService() as stability_service:
            logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)
            params = {
                "target_width": request.target_width,
                "target_height": request.target_height,
            }
            # Forward only the dimensions the caller actually set.
            params = {k: v for k, v in params.items() if v is not None}

            if mode == "fast":
                result = await stability_service.upscale_fast(
                    image=source_bytes,
                    **params,
                )
            elif mode == "conservative":
                prompt = request.prompt or "High fidelity upscale preserving original details"
                result = await stability_service.upscale_conservative(
                    image=source_bytes,
                    prompt=prompt,
                    **params,
                )
            elif mode == "creative":
                prompt = request.prompt or "Creative upscale with enhanced artistic details"
                result = await stability_service.upscale_creative(
                    image=source_bytes,
                    prompt=prompt,
                    **params,
                )
            else:
                raise ValueError(f"Unsupported upscale mode: {mode}")

        upscaled_bytes = self._extract_image_bytes(result)
        try:
            dimensions = self._image_metadata(upscaled_bytes)
        except OSError as exc:
            # PIL.UnidentifiedImageError subclasses OSError. Bytes that are not
            # an image are a provider-side failure, so report it as 502 like
            # _extract_image_bytes does.
            logger.error("[Upscale Studio] Provider returned unreadable image data: %s", exc)
            raise HTTPException(
                status_code=502,
                detail="Provider returned data that is not a readable image",
            ) from exc

        return {
            "success": True,
            "mode": mode,
            "image_base64": self._to_base64(upscaled_bytes),
            "width": dimensions["width"],
            "height": dimensions["height"],
            "metadata": {
                "preset": request.preset,
                "target_width": request.target_width,
                "target_height": request.target_height,
                # Caller-supplied prompt; stays None when a default prompt
                # was used for conservative/creative modes.
                "prompt": request.prompt,
            },
        }

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Run subscription pre-flight validation for ``user_id``.

        Opens a session via ``get_db()`` for the duration of the check and
        always closes it; anything the validator raises propagates.
        """
        # Function-local imports; presumably to avoid import cycles or DB setup
        # at module import time — TODO confirm before hoisting.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_upscale_operations

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
            validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
        finally:
            db.close()

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode raw base64 or a ``data:<mime>;base64,<payload>`` URL.

        Returns:
            The decoded bytes, or ``None`` when ``value`` is empty/None.

        Raises:
            ValueError: If the payload cannot be parsed or decoded.
        """
        if not value:
            return None
        try:
            # Strip first so " data:image/png;base64,..." is still treated as a
            # data URL instead of having its header decoded as base64.
            text = value.strip()
            if text.startswith("data:"):
                # The base64 payload follows the first comma of the data URL.
                _, b64data = text.split(",", 1)
            else:
                b64data = text
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode ``image_bytes`` as a ``data:image/png;base64,...`` URL.

        NOTE(review): the MIME label is always PNG regardless of the bytes'
        real format; confirm the provider only returns PNG.
        """
        return f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Return ``{"width": ..., "height": ...}`` in pixels via Pillow."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {"width": img.width, "height": img.height}

    @staticmethod
    def _decode_provider_base64(payload: Any) -> bytes:
        """Decode a base64 image payload taken from a provider response.

        Raises:
            HTTPException: 502 when the payload is not decodable. A broken
                provider payload should be reported as a provider failure, not
                as the ValueError this service uses for bad client input.
        """
        try:
            return base64.b64decode(payload)
        except (ValueError, TypeError) as exc:  # binascii.Error is a ValueError
            logger.error("[Upscale Studio] Provider returned undecodable base64 image: %s", exc)
            raise HTTPException(
                status_code=502,
                detail="Provider returned an invalid base64 image",
            ) from exc

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull the upscaled image bytes out of a provider response.

        Accepts raw ``bytes``/``bytearray``, or a dict whose first non-empty
        ``artifacts`` / ``data`` / ``images`` list holds dict items carrying a
        ``base64`` or ``b64_json`` string.

        Raises:
            HTTPException: 502 if no image is found or its payload is invalid.
        """
        if isinstance(result, (bytes, bytearray)):
            return bytes(result)
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            # Only sequences are scanned; any other shape falls through to 502.
            if isinstance(artifacts, (list, tuple)):
                for artifact in artifacts:
                    if not isinstance(artifact, dict):
                        continue
                    payload = artifact.get("base64") or artifact.get("b64_json")
                    if payload:
                        return UpscaleStudioService._decode_provider_base64(payload)
        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
        """Map ``request.mode`` to a concrete upscale mode.

        Any mode other than "auto" is returned unchanged; it is not validated
        here. For "auto", a target width or height of at least 3000 selects
        "conservative", otherwise "fast".
        """
        if request.mode != "auto":
            return request.mode
        if (request.target_width and request.target_width >= 3000) or (
            request.target_height and request.target_height >= 3000
        ):
            return "conservative"
        return "fast"