Changes: 1. helpers.py (_track_image_operation_usage): Map provider name to DB columns dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.) instead of hardcoding stability_calls/stability_cost. 2. upscale_service.py: Added _track_image_operation_usage() call after successful Stability upscale completion. 3. control_service.py: Added _track_image_operation_usage() call after successful Stability control operation completion. 4. edit_service.py: Added _track_image_operation_usage() call after successful Stability edit operation (remove_background, inpaint, outpaint, search_replace, search_recolor, relight). Previously only Create Studio and Face Swap tracked usage. Now all five studios correctly decrement subscription limits.
293 lines
11 KiB
Python
293 lines
11 KiB
Python
"""Control Studio service for AI-powered controlled image generation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import io
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Literal, Optional
|
|
|
|
from PIL import Image
|
|
|
|
from services.stability_service import StabilityAIService
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
|
|
logger = get_service_logger("image_studio.control")


# Closed set of operation identifiers accepted by ControlStudioService;
# each one is a key of ControlStudioService.SUPPORTED_OPERATIONS.
ControlOperationType = Literal["sketch", "structure", "style", "style_transfer"]
|
|
|
|
|
|
@dataclass
class ControlStudioRequest:
    """Input payload handed to ``ControlStudioService.process_control``.

    Only ``operation``, ``prompt`` and ``control_image_base64`` are required.
    The remaining fields default to ``None`` (``output_format`` defaults to
    ``"png"``); the comments below note which operations consult each one.
    Field order is part of the positional constructor and must not change.
    """

    # Which control operation to run.
    operation: ControlOperationType
    prompt: str
    # Sketch, structure, or style reference image as base64 or a data URL;
    # style_transfer uses it as the init image.
    control_image_base64: str

    # Second image, consulted by style_transfer only.
    style_image_base64: Optional[str] = None
    negative_prompt: Optional[str] = None
    # sketch / structure
    control_strength: Optional[float] = None
    # style
    fidelity: Optional[float] = None
    # style_transfer
    style_strength: Optional[float] = None
    # style_transfer
    composition_fidelity: Optional[float] = None
    # style_transfer
    change_strength: Optional[float] = None
    # style
    aspect_ratio: Optional[str] = None
    # sketch / structure / style (not forwarded for style_transfer)
    style_preset: Optional[str] = None
    seed: Optional[int] = None
    output_format: str = "png"
|
|
|
|
|
|
class ControlStudioService:
    """Service layer orchestrating Control Studio operations.

    ``process_control`` validates a ``ControlStudioRequest`` and, when a
    ``user_id`` is given, runs the subscription pre-flight check. It then
    calls the matching ``StabilityAIService`` control endpoint, records
    usage for the user, and returns a JSON-friendly response dict.
    """

    SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
        "sketch": {
            "label": "Sketch to Image",
            "description": "Transform sketches into refined images with precise control.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "structure": {
            "label": "Structure Control",
            "description": "Generate images maintaining the structure of an input image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "style": {
            "label": "Style Control",
            "description": "Generate images using style from a reference image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": False,
                "fidelity": True,
                "style_strength": False,
                "aspect_ratio": True,
            },
        },
        "style_transfer": {
            "label": "Style Transfer",
            "description": "Apply visual characteristics from a style image to a target image.",
            "provider": "stability",
            "fields": {
                "control_image": True,  # init_image
                "style_image": True,
                "control_strength": False,
                "fidelity": False,
                "style_strength": True,
                "aspect_ratio": False,
            },
        },
    }

    # Fallbacks applied only when the request leaves a parameter unset (None);
    # an explicit 0.0 from the caller is forwarded unchanged.
    _DEFAULT_CONTROL_STRENGTH = 0.7
    _DEFAULT_ASPECT_RATIO = "1:1"
    _DEFAULT_FIDELITY = 0.5
    _DEFAULT_STYLE_STRENGTH = 1.0
    _DEFAULT_COMPOSITION_FIDELITY = 0.9
    _DEFAULT_CHANGE_STRENGTH = 0.9

    # Value passed as ``cost`` to usage tracking for each successful call.
    # NOTE(review): identical for all four operations - confirm the unit and
    # per-operation pricing with the tracking helper / provider price list.
    _USAGE_COST = 0.04

    def __init__(self):
        logger.info("[Control Studio] Initialized control service")

    @staticmethod
    def _value_or_default(value: Any, default: Any) -> Any:
        """Return *value* unless it is ``None``, in which case return *default*.

        Unlike ``value or default`` this keeps legitimate falsy settings such
        as ``0.0``.
        """
        return default if value is None else value

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 string (optionally wrapped in a data URL) to bytes.

        Returns:
            The decoded bytes, or ``None`` when *value* is empty or ``None``.

        Raises:
            ValueError: if the payload cannot be split or base64-decoded.
        """
        if not value:
            return None

        try:
            # Data URLs look like "data:image/png;base64,<payload>".
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value

            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Open *image_bytes* with PIL and return its ``width`` and ``height``."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Encode raw bytes as a ``data:image/<output_format>;base64,...`` URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        Returns the class-level ``SUPPORTED_OPERATIONS`` mapping itself (not a
        copy), so callers must treat it as read-only.
        """
        return self.SUPPORTED_OPERATIONS

    async def process_control(
        self,
        request: ControlStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process a control request and return a normalized response.

        Args:
            request: The operation and its images/parameters.
            user_id: When given, pre-flight validation runs before the
                provider call and usage is tracked afterwards; when omitted,
                both are skipped.

        Returns:
            A dict with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and ``metadata``.

        Raises:
            ValueError: unsupported operation, missing or undecodable control
                image, or missing style image for ``style_transfer``.
            HTTPException: re-raised from the pre-flight validator.
            RuntimeError: if no image could be extracted from the provider
                response.
        """
        operation = request.operation

        # Validate the cheap, local parts of the request first so malformed
        # payloads never hit the database or open a provider session.
        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported control operation: {operation}")

        control_image_bytes = self._decode_base64_image(request.control_image_base64)
        if not control_image_bytes:
            raise ValueError("Control image payload is required")

        style_image_bytes = self._decode_base64_image(request.style_image_base64)
        if operation == "style_transfer" and not style_image_bytes:
            raise ValueError("Style image is required for style transfer")

        if user_id:
            self._run_preflight(user_id)
        else:
            logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)

        result = await self._invoke_stability(request, control_image_bytes, style_image_bytes)

        image_bytes = self._extract_image_bytes(result)
        metadata = self._image_bytes_to_metadata(image_bytes)
        provider = self.SUPPORTED_OPERATIONS[operation]["provider"]

        if user_id:
            self._track_usage(user_id, operation, provider, image_bytes)

        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": provider,
            }
        )

        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }

        logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
        return response

    @staticmethod
    def _run_preflight(user_id: str) -> None:
        """Run the subscription pre-flight check for one image-control call.

        The DB session obtained from ``get_db`` is always closed afterwards.

        Raises:
            HTTPException: re-raised (after logging) from
                ``validate_image_control_operations``.
        """
        # Local imports, as in the original code; presumably to keep module
        # import light or avoid cycles - TODO confirm.
        from fastapi import HTTPException
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_control_operations

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_control_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=1,
            )
            logger.info("[Control Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Control Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            db.close()

    async def _invoke_stability(
        self,
        request: ControlStudioRequest,
        control_image_bytes: bytes,
        style_image_bytes: Optional[bytes],
    ) -> Any:
        """Dispatch *request* to the matching Stability control endpoint.

        Returns the raw client result; ``_extract_image_bytes`` normalizes it.
        Callers must have validated the operation and images beforehand.
        """
        pick = self._value_or_default
        operation = request.operation

        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "sketch":
                return await stability_service.control_sketch(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, self._DEFAULT_CONTROL_STRENGTH),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            if operation == "structure":
                return await stability_service.control_structure(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, self._DEFAULT_CONTROL_STRENGTH),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            if operation == "style":
                return await stability_service.control_style(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    # `or` on purpose: an empty string is not a usable ratio.
                    aspect_ratio=request.aspect_ratio or self._DEFAULT_ASPECT_RATIO,
                    fidelity=pick(request.fidelity, self._DEFAULT_FIDELITY),
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            if operation == "style_transfer":
                return await stability_service.control_style_transfer(
                    init_image=control_image_bytes,
                    style_image=style_image_bytes,
                    prompt=request.prompt or "",
                    negative_prompt=request.negative_prompt,
                    style_strength=pick(request.style_strength, self._DEFAULT_STYLE_STRENGTH),
                    composition_fidelity=pick(
                        request.composition_fidelity, self._DEFAULT_COMPOSITION_FIDELITY
                    ),
                    change_strength=pick(request.change_strength, self._DEFAULT_CHANGE_STRENGTH),
                    seed=request.seed,
                    output_format=request.output_format,
                )

        raise ValueError(f"Unsupported control operation: {operation}")

    def _track_usage(self, user_id: str, operation: str, provider: str, image_bytes: bytes) -> None:
        """Record one successful control operation against *user_id*'s usage."""
        # Local import kept from the original code; presumably avoids an
        # import cycle - TODO confirm.
        from services.llm_providers.main_image_generation import _track_image_operation_usage

        # NOTE(review): an exception raised here propagates to the caller after
        # the provider call has already completed, discarding the generated
        # image; consider whether tracking should be best-effort.
        _track_image_operation_usage(
            user_id=user_id,
            provider=provider,
            model=f"control-{operation}",
            operation_type="image-control",
            result_bytes=image_bytes,
            cost=self._USAGE_COST,
            endpoint="/image-studio/control/process",
            log_prefix="[Control Studio]",
        )

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize a Stability client result into raw image bytes.

        Accepted shapes:
        * raw ``bytes`` / ``bytearray``;
        * a dict whose ``artifacts`` / ``data`` / ``images`` list holds
          dicts with ``base64`` or ``b64_json``, or plain base64 strings;
        * a dict with a top-level base64 ``image`` string.

        Raises:
            RuntimeError: if none of the shapes above yields image data.
        """
        if isinstance(result, (bytes, bytearray)):
            return bytes(result)

        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    encoded = artifact.get("base64") or artifact.get("b64_json")
                    if encoded:
                        return base64.b64decode(encoded)
                elif isinstance(artifact, str) and artifact:
                    return base64.b64decode(artifact)

            single_image = result.get("image")
            if isinstance(single_image, str) and single_image:
                return base64.b64decode(single_image)

        raise RuntimeError("Unable to extract image bytes from provider response")
|
|
|