Files
ALwrity/backend/services/image_studio/upscale_service.py
ajaysi ca725b77e7 refactor(phase2): add provider-aware tracking and fill missing subscription usage tracking
Changes:
1. helpers.py (_track_image_operation_usage): Map provider name to DB columns
   dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.)
   instead of hardcoding stability_calls/stability_cost.

2. upscale_service.py: Added _track_image_operation_usage() call after
   successful Stability upscale completion.

3. control_service.py: Added _track_image_operation_usage() call after
   successful Stability control operation completion.

4. edit_service.py: Added _track_image_operation_usage() call after
   successful Stability edit operation (remove_background, inpaint,
   outpaint, search_replace, search_recolor, relight).

Previously only Create Studio and Face Swap tracked usage. Now all five
studios correctly decrement subscription limits.
2026-05-14 09:11:51 +05:30

167 lines
6.0 KiB
Python

import base64
import io
from dataclasses import dataclass
from typing import Literal, Optional, Dict, Any
from fastapi import HTTPException
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Logger for this module, obtained via get_service_logger.
logger = get_service_logger("image_studio.upscale")
# Upscale modes a request may ask for. "auto" is not sent to the provider:
# UpscaleStudioService._resolve_mode turns it into "fast" or "conservative".
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]
@dataclass
class UpscaleStudioRequest:
    """Parameters describing a single upscale job.

    Only ``image_base64`` is mandatory; every other field has a default.
    """

    # Source image, either raw base64 or a ``data:`` URI (both forms are
    # accepted by UpscaleStudioService._decode_base64).
    image_base64: str

    # Requested upscale mode; "auto" lets the service choose one.
    mode: UpscaleMode = "auto"

    # Optional target dimensions in pixels; omitted from the provider call
    # when left as None.
    target_width: Optional[int] = None
    target_height: Optional[int] = None

    # Echoed back in the response metadata, e.g. web/print/social.
    preset: Optional[str] = None

    # Used for conservative/creative modes; a default prompt is substituted
    # by the service when this is empty.
    prompt: Optional[str] = None
class UpscaleStudioService:
    """Handles image upscaling workflows.

    The service decodes the caller's base64 image, picks an upscale mode,
    delegates to ``StabilityAIService`` and returns the result as a data URI
    together with its pixel dimensions.
    """

    def __init__(self):
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: "UpscaleStudioRequest",
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale the image described by *request*.

        Args:
            request: Upscale parameters (image, mode, optional target size,
                preset and prompt).
            user_id: When given, the operation is validated before the
                provider call and usage is tracked after it succeeds. When
                omitted, both steps are skipped.

        Returns:
            A dict with ``success``, the resolved ``mode``, the upscaled image
            as a data URI in ``image_base64``, its ``width`` and ``height``,
            and a ``metadata`` dict echoing ``preset``, ``target_width``,
            ``target_height`` and ``prompt`` from the request.

        Raises:
            ValueError: If the image is missing or not valid base64, or the
                resolved mode is not one of fast/conservative/creative.
            HTTPException: 502 when no image can be extracted from the
                provider response.
        """
        # Pre-flight validation: Reuse unified helper
        # Note: Using image-generation validation since upscaling uses same subscription limits
        if user_id:
            from services.llm_providers.main_image_generation import _validate_image_operation

            _validate_image_operation(
                user_id=user_id,
                operation_type="image-upscale",
                num_operations=1,
                log_prefix="[Upscale Studio]",
            )

        source_bytes = self._decode_base64(request.image_base64)
        if not source_bytes:
            raise ValueError("Primary image is required for upscaling")

        mode = self._resolve_mode(request)

        async with StabilityAIService() as stability_service:
            logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)

            # Only forward the target dimensions the caller actually set.
            params = {
                key: value
                for key, value in (
                    ("target_width", request.target_width),
                    ("target_height", request.target_height),
                )
                if value is not None
            }

            if mode == "fast":
                result = await stability_service.upscale_fast(
                    image=source_bytes,
                    **params,
                )
            elif mode == "conservative":
                prompt = request.prompt or "High fidelity upscale preserving original details"
                result = await stability_service.upscale_conservative(
                    image=source_bytes,
                    prompt=prompt,
                    **params,
                )
            elif mode == "creative":
                prompt = request.prompt or "Creative upscale with enhanced artistic details"
                result = await stability_service.upscale_creative(
                    image=source_bytes,
                    prompt=prompt,
                    **params,
                )
            else:
                raise ValueError(f"Unsupported upscale mode: {mode}")

            upscaled_bytes = self._extract_image_bytes(result)
            metadata = self._image_metadata(upscaled_bytes)

            # Track usage only after the provider call produced a usable image.
            if user_id:
                from services.llm_providers.main_image_generation import _track_image_operation_usage

                # NOTE(review): the same flat cost is recorded for every mode;
                # confirm this matches the provider's per-mode pricing.
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="stability",
                    model=f"upscale-{mode}",
                    operation_type="image-upscale",
                    result_bytes=upscaled_bytes,
                    cost=0.04,
                    endpoint="/image-studio/upscale",
                    log_prefix="[Upscale Studio]",
                )

            return {
                "success": True,
                "mode": mode,
                "image_base64": self._to_base64(upscaled_bytes),
                "width": metadata["width"],
                "height": metadata["height"],
                "metadata": {
                    "preset": request.preset,
                    "target_width": request.target_width,
                    "target_height": request.target_height,
                    "prompt": request.prompt,
                },
            }

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode a raw base64 string or ``data:`` URI into bytes.

        Returns ``None`` for an empty or missing value. Any decoding failure
        is logged and re-raised as ``ValueError``.
        """
        if not value:
            return None
        try:
            if value.startswith("data:"):
                # Drop the "data:<mime>;base64" header, keep the payload.
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _sniff_mime(image_bytes: bytes) -> str:
        """Return the MIME type implied by the leading magic bytes.

        Recognises PNG, JPEG and WebP. Anything else falls back to
        ``image/png`` so unrecognised input is labelled as it was before.
        """
        if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if image_bytes.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        # WebP is a RIFF container: "RIFF" <4-byte size> "WEBP".
        if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
            return "image/webp"
        return "image/png"

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode *image_bytes* as a data URI with a matching MIME type."""
        mime = UpscaleStudioService._sniff_mime(image_bytes)
        encoded = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:{mime};base64,{encoded}"

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Return the ``width`` and ``height`` of the encoded image."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {"width": img.width, "height": img.height}

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull the image bytes out of a provider response.

        Accepts bytes-like objects directly, or a dict whose ``artifacts``,
        ``data`` or ``images`` list contains a dict with a ``base64`` or
        ``b64_json`` entry; the first such entry wins.

        Raises:
            HTTPException: 502 when no image can be found in *result*.
        """
        if isinstance(result, (bytes, bytearray, memoryview)):
            return bytes(result)
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if not isinstance(artifact, dict):
                    continue
                encoded = artifact.get("base64") or artifact.get("b64_json")
                if encoded:
                    return base64.b64decode(encoded)
        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: "UpscaleStudioRequest") -> "UpscaleMode":
        """Return the concrete upscale mode for *request*.

        An explicit mode is returned unchanged. For ``"auto"`` a simple
        heuristic applies: a target width or height of at least 3000px
        selects ``"conservative"``, anything else ``"fast"``.
        """
        if request.mode != "auto":
            return request.mode
        wants_large_width = bool(request.target_width and request.target_width >= 3000)
        wants_large_height = bool(request.target_height and request.target_height >= 3000)
        if wants_large_width or wants_large_height:
            return "conservative"
        return "fast"