AI Image Studio Phase 1
This commit is contained in:
154
backend/services/image_studio/upscale_service.py
Normal file
154
backend/services/image_studio/upscale_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Dict, Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.upscale")
|
||||
|
||||
|
||||
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpscaleStudioRequest:
|
||||
image_base64: str
|
||||
mode: UpscaleMode = "auto"
|
||||
target_width: Optional[int] = None
|
||||
target_height: Optional[int] = None
|
||||
preset: Optional[str] = None # e.g., web/print/social
|
||||
prompt: Optional[str] = None # used for conservative/creative modes
|
||||
|
||||
|
||||
class UpscaleStudioService:
|
||||
"""Handles image upscaling workflows."""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Upscale Studio] Service initialized")
|
||||
|
||||
async def process_upscale(
|
||||
self,
|
||||
request: UpscaleStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_upscale_operations
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
|
||||
validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
image_bytes = self._decode_base64(request.image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Primary image is required for upscaling")
|
||||
|
||||
mode = self._resolve_mode(request)
|
||||
|
||||
async with StabilityAIService() as stability_service:
|
||||
logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)
|
||||
|
||||
params = {
|
||||
"target_width": request.target_width,
|
||||
"target_height": request.target_height,
|
||||
}
|
||||
# remove None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
if mode == "fast":
|
||||
result = await stability_service.upscale_fast(
|
||||
image=image_bytes,
|
||||
**params,
|
||||
)
|
||||
elif mode == "conservative":
|
||||
prompt = request.prompt or "High fidelity upscale preserving original details"
|
||||
result = await stability_service.upscale_conservative(
|
||||
image=image_bytes,
|
||||
prompt=prompt,
|
||||
**params,
|
||||
)
|
||||
elif mode == "creative":
|
||||
prompt = request.prompt or "Creative upscale with enhanced artistic details"
|
||||
result = await stability_service.upscale_creative(
|
||||
image=image_bytes,
|
||||
prompt=prompt,
|
||||
**params,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported upscale mode: {mode}")
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_metadata(image_bytes)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": mode,
|
||||
"image_base64": self._to_base64(image_bytes),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": {
|
||||
"preset": request.preset,
|
||||
"target_width": request.target_width,
|
||||
"target_height": request.target_height,
|
||||
"prompt": request.prompt,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: Optional[str]) -> Optional[bytes]:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _to_base64(image_bytes: bytes) -> str:
|
||||
return f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
|
||||
|
||||
@staticmethod
|
||||
def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {"width": img.width, "height": img.height}
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
raise HTTPException(status_code=502, detail="Unable to extract image from provider response")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
|
||||
if request.mode != "auto":
|
||||
return request.mode
|
||||
# simple heuristic: if target >= 3000px, use conservative, else fast
|
||||
if (request.target_width and request.target_width >= 3000) or (
|
||||
request.target_height and request.target_height >= 3000
|
||||
):
|
||||
return "conservative"
|
||||
return "fast"
|
||||
|
||||
Reference in New Issue
Block a user