Changes: 1. helpers.py (_track_image_operation_usage): Map provider name to DB columns dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.) instead of hardcoding stability_calls/stability_cost. 2. upscale_service.py: Added _track_image_operation_usage() call after successful Stability upscale completion. 3. control_service.py: Added _track_image_operation_usage() call after successful Stability control operation completion. 4. edit_service.py: Added _track_image_operation_usage() call after successful Stability edit operation (remove_background, inpaint, outpaint, search_replace, search_recolor, relight). Previously only Create Studio and Face Swap tracked usage. Now all five studios correctly decrement subscription limits.
167 lines
6.0 KiB
Python
167 lines
6.0 KiB
Python
import base64
|
|
import io
|
|
from dataclasses import dataclass
|
|
from typing import Literal, Optional, Dict, Any
|
|
|
|
from fastapi import HTTPException
|
|
from PIL import Image
|
|
|
|
from services.stability_service import StabilityAIService
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
logger = get_service_logger("image_studio.upscale")
|
|
|
|
|
|
# Upscale strategies a caller may request. "auto" is not a strategy of its
# own: UpscaleStudioService._resolve_mode replaces it with a concrete mode.
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]


@dataclass
class UpscaleStudioRequest:
    """Input payload for a single Upscale Studio operation.

    Only ``image_base64`` is mandatory; every other field is optional and
    falls back to the default shown below.
    """

    # Source image, either raw base64 or a ``data:`` URI.
    image_base64: str

    # Requested strategy; "auto" lets the service choose.
    mode: UpscaleMode = "auto"

    # Desired output dimensions; fields left as None are not forwarded to
    # the provider.
    target_width: Optional[int] = None
    target_height: Optional[int] = None

    # Preset label, e.g. web/print/social. Echoed back in the response
    # metadata.
    preset: Optional[str] = None

    # Guidance text used by the conservative/creative modes.
    prompt: Optional[str] = None
|
|
|
|
|
|
class UpscaleStudioService:
    """Handles image upscaling workflows.

    The service decodes the caller's base64 image, picks an upscale mode,
    delegates the work to ``StabilityAIService`` and returns the upscaled
    image as a data URI together with its pixel dimensions.
    """

    # Modes that map onto a provider call ("auto" is resolved before use).
    _SUPPORTED_MODES = ("fast", "conservative", "creative")

    # Fallback prompts for the prompt-driven modes when the request has none.
    _DEFAULT_PROMPTS = {
        "conservative": "High fidelity upscale preserving original details",
        "creative": "Creative upscale with enhanced artistic details",
    }

    def __init__(self):
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: "UpscaleStudioRequest",
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale the request's image and report the result.

        Args:
            request: Image payload plus mode, target size, preset and prompt.
            user_id: When given, the operation is validated before the
                provider call and usage is tracked after it; when omitted,
                both steps are skipped.

        Returns:
            A dict with ``success``, the resolved ``mode``, the upscaled
            image as ``image_base64`` (data URI), its ``width`` and
            ``height``, and a ``metadata`` dict echoing the request's
            preset, target size and prompt.

        Raises:
            ValueError: If the image is missing or not valid base64, or the
                mode is not supported.
            HTTPException: 502 when no image can be extracted from the
                provider response.
        """
        # Pre-flight validation: reuse the unified helper.
        # NOTE(review): presumably raises when the user's subscription limit
        # is exhausted -- confirm against the helper.
        if user_id:
            from services.llm_providers.main_image_generation import _validate_image_operation

            _validate_image_operation(
                user_id=user_id,
                operation_type="image-upscale",
                num_operations=1,
                log_prefix="[Upscale Studio]",
            )

        source_bytes = self._decode_base64(request.image_base64)
        if not source_bytes:
            raise ValueError("Primary image is required for upscaling")

        # Reject an unusable mode before a provider session is opened.
        mode = self._resolve_mode(request)
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported upscale mode: {mode}")

        # Only forward the dimensions the caller actually specified.
        requested_size = {
            "target_width": request.target_width,
            "target_height": request.target_height,
        }
        params = {k: v for k, v in requested_size.items() if v is not None}

        async with StabilityAIService() as stability_service:
            logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)

            if mode == "fast":
                result = await stability_service.upscale_fast(
                    image=source_bytes,
                    **params,
                )
            elif mode == "conservative":
                result = await stability_service.upscale_conservative(
                    image=source_bytes,
                    prompt=request.prompt or self._DEFAULT_PROMPTS["conservative"],
                    **params,
                )
            else:
                result = await stability_service.upscale_creative(
                    image=source_bytes,
                    prompt=request.prompt or self._DEFAULT_PROMPTS["creative"],
                    **params,
                )

            upscaled_bytes = self._extract_image_bytes(result)
            metadata = self._image_metadata(upscaled_bytes)

            # Track usage only after the provider call has succeeded.
            if user_id:
                from services.llm_providers.main_image_generation import _track_image_operation_usage

                # NOTE(review): the same flat cost is recorded for every
                # mode -- confirm this matches the provider's pricing.
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="stability",
                    model=f"upscale-{mode}",
                    operation_type="image-upscale",
                    result_bytes=upscaled_bytes,
                    cost=0.04,
                    endpoint="/image-studio/upscale",
                    log_prefix="[Upscale Studio]",
                )

            return {
                "success": True,
                "mode": mode,
                "image_base64": self._to_base64(upscaled_bytes),
                "width": metadata["width"],
                "height": metadata["height"],
                "metadata": {
                    "preset": request.preset,
                    "target_width": request.target_width,
                    "target_height": request.target_height,
                    "prompt": request.prompt,
                },
            }

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode a raw base64 string or ``data:`` URI into bytes.

        Returns ``None`` for an empty/missing value. Whitespace inside the
        payload is ignored; any other character outside the base64 alphabet
        is rejected instead of being silently dropped.

        Raises:
            ValueError: If the payload is not valid base64.
        """
        if not value:
            return None
        try:
            if value.startswith("data:"):
                # Keep only the part after the "data:<mime>;base64," header.
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            # Tolerate line-wrapped payloads, then decode strictly so a
            # corrupt payload fails here rather than yielding garbage bytes.
            compact = "".join(b64data.split())
            return base64.b64decode(compact, validate=True)
        except (ValueError, TypeError, AttributeError) as exc:
            logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _sniff_mime_type(image_bytes: bytes) -> str:
        """Guess the image MIME type from its leading signature bytes.

        Recognises PNG, JPEG and WebP; anything else is reported as
        ``image/png``, which is what the service labelled every image
        before format detection was added.
        """
        if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if image_bytes.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        # WebP is a RIFF container whose form type at offset 8 is "WEBP".
        if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
            return "image/webp"
        return "image/png"

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode image bytes as a data URI labelled with the detected MIME type."""
        mime_type = UpscaleStudioService._sniff_mime_type(image_bytes)
        encoded = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Return the pixel ``width`` and ``height`` of the encoded image."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {"width": img.width, "height": img.height}

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull the image bytes out of a provider response.

        Accepts raw bytes, or a dict whose ``artifacts``, ``data`` or
        ``images`` list holds dict entries carrying the image under a
        ``base64`` or ``b64_json`` key. The first such entry wins.

        Raises:
            HTTPException: 502 when no image is present or the embedded
                base64 cannot be decoded.
        """
        if isinstance(result, (bytes, bytearray)):
            return bytes(result)

        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if not isinstance(artifact, dict):
                    continue
                encoded = artifact.get("base64") or artifact.get("b64_json")
                if not encoded:
                    continue
                try:
                    return base64.b64decode(encoded)
                except (ValueError, TypeError) as exc:
                    # A malformed provider payload is an upstream failure,
                    # so surface it as 502 rather than an unhandled error.
                    raise HTTPException(
                        status_code=502,
                        detail="Unable to extract image from provider response",
                    ) from exc

        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: "UpscaleStudioRequest") -> "UpscaleMode":
        """Return the upscale mode to use for ``request``.

        An explicit mode is returned unchanged. For ``"auto"`` a simple
        heuristic applies: a target width or height of at least 3000px
        selects ``"conservative"``, anything else ``"fast"``.
        """
        if request.mode != "auto":
            return request.mode

        large_target_px = 3000
        wants_large_output = any(
            side and side >= large_target_px
            for side in (request.target_width, request.target_height)
        )
        return "conservative" if wants_large_output else "fast"
|
|
|