- Added new router for content assets - Added new service for content assets - Added new model for content assets - Added new utils for content assets - Added new docs for content assets - Added new tests for content assets - Added new examples for content assets - Added new guides for content assets
278 lines
10 KiB
Python
278 lines
10 KiB
Python
"""Control Studio service for AI-powered controlled image generation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import io
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Literal, Optional
|
|
|
|
from PIL import Image
|
|
|
|
from services.stability_service import StabilityAIService
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
|
|
# Module-level logger for this service, obtained from the project's
# get_service_logger helper under the name "image_studio.control".
logger = get_service_logger("image_studio.control")
|
|
|
|
|
|
# Closed set of operation identifiers a Control Studio request may carry.
ControlOperationType = Literal["sketch", "structure", "style", "style_transfer"]
|
|
|
|
|
|
@dataclass
class ControlStudioRequest:
    """Normalized request payload for Control Studio operations.

    Attributes:
        operation: Which control operation to run.
        prompt: Text prompt forwarded to the provider.
        control_image_base64: Sketch, structure, or style reference image,
            encoded as base64.
        style_image_base64: Style image; used for ``style_transfer`` only.
        negative_prompt: Optional negative prompt.
        control_strength: Used for ``sketch`` / ``structure``.
        fidelity: Used for ``style``.
        style_strength: Used for ``style_transfer``.
        composition_fidelity: Used for ``style_transfer``.
        change_strength: Used for ``style_transfer``.
        aspect_ratio: Used for ``style``.
        style_preset: Optional style preset name.
        seed: Optional generation seed.
        output_format: Output image format; defaults to ``"png"``.
    """

    # Required fields.
    operation: ControlOperationType
    prompt: str
    control_image_base64: str

    # Optional images and prompts.
    style_image_base64: Optional[str] = None
    negative_prompt: Optional[str] = None

    # Per-operation tuning knobs; None means "not supplied by the caller".
    control_strength: Optional[float] = None
    fidelity: Optional[float] = None
    style_strength: Optional[float] = None
    composition_fidelity: Optional[float] = None
    change_strength: Optional[float] = None
    aspect_ratio: Optional[str] = None

    # Options shared across operations.
    style_preset: Optional[str] = None
    seed: Optional[int] = None
    output_format: str = "png"
|
|
|
|
|
|
class ControlStudioService:
    """Service layer orchestrating Control Studio operations.

    The service decodes the base64 images carried by a
    ``ControlStudioRequest``, dispatches to the matching
    ``StabilityAIService`` control call, and wraps the resulting image in a
    normalized response dictionary.
    """

    # Registry of supported operations. Each entry carries display text, the
    # provider name, and a "fields" map of which request inputs apply.
    SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
        "sketch": {
            "label": "Sketch to Image",
            "description": "Transform sketches into refined images with precise control.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "structure": {
            "label": "Structure Control",
            "description": "Generate images maintaining the structure of an input image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "style": {
            "label": "Style Control",
            "description": "Generate images using style from a reference image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": False,
                "fidelity": True,
                "style_strength": False,
                "aspect_ratio": True,
            },
        },
        "style_transfer": {
            "label": "Style Transfer",
            "description": "Apply visual characteristics from a style image to a target image.",
            "provider": "stability",
            "fields": {
                "control_image": True,  # init_image
                "style_image": True,
                "control_strength": False,
                "fidelity": False,
                "style_strength": True,
                "aspect_ratio": False,
            },
        },
    }

    def __init__(self):
        logger.info("[Control Studio] Initialized control service")

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Args:
            value: Plain base64 text, or a ``data:...;base64,<payload>`` URL.

        Returns:
            The decoded bytes, or ``None`` when ``value`` is empty or ``None``.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None

        try:
            # Handle data URLs (data:image/png;base64,...)
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value

            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Extract width/height metadata from image bytes."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Convert raw bytes to a ``data:image/<output_format>;base64,...`` URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    @staticmethod
    def _value_or_default(value: Optional[float], default: float) -> float:
        """Return ``value`` unless it is ``None``, in which case ``default``.

        Unlike ``value or default``, an explicit ``0`` / ``0.0`` supplied by
        the caller is preserved instead of being replaced by the default.
        """
        return default if value is None else value

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Run the subscription pre-flight check for ``user_id``.

        Project imports are done locally, as in the rest of this service, so
        they are only loaded when a user id is actually supplied.

        Raises:
            HTTPException: Re-raised unchanged when validation fails.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_control_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_control_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=1,
            )
            logger.info("[Control Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Control Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            # Always release the DB session, whether validation passed or not.
            db.close()

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        Returns a deep copy so callers that edit the result cannot corrupt
        the shared class-level registry.
        """
        from copy import deepcopy

        return deepcopy(self.SUPPORTED_OPERATIONS)

    async def process_control(
        self,
        request: ControlStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process control request and return normalized response.

        Args:
            request: The normalized request payload.
            user_id: When truthy, subscription pre-flight validation runs
                first; otherwise it is skipped with a warning.

        Returns:
            A dict with keys ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and
            ``metadata``.

        Raises:
            ValueError: If the control image is missing or undecodable, the
                operation is unsupported, or a style image is missing for
                ``style_transfer``.
            RuntimeError: If no image can be extracted from the provider
                response.
        """
        if user_id:
            self._run_preflight_validation(user_id)
        else:
            logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        control_image_bytes = self._decode_base64_image(request.control_image_base64)
        if not control_image_bytes:
            raise ValueError("Control image payload is required")

        style_image_bytes = self._decode_base64_image(request.style_image_base64)

        operation = request.operation
        logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)

        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported control operation: {operation}")

        # Validate inputs before a provider session is opened so that a bad
        # request never costs a session setup/teardown.
        if operation == "style_transfer" and not style_image_bytes:
            raise ValueError("Style image is required for style transfer")

        pick = self._value_or_default

        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "sketch":
                result = await stability_service.control_sketch(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "structure":
                result = await stability_service.control_structure(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style":
                result = await stability_service.control_style(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    # An empty string is not a usable ratio, so `or` is kept here.
                    aspect_ratio=request.aspect_ratio or "1:1",
                    fidelity=pick(request.fidelity, 0.5),
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style_transfer":
                result = await stability_service.control_style_transfer(
                    init_image=control_image_bytes,
                    style_image=style_image_bytes,
                    prompt=request.prompt or "",
                    negative_prompt=request.negative_prompt,
                    style_strength=pick(request.style_strength, 1.0),
                    composition_fidelity=pick(request.composition_fidelity, 0.9),
                    change_strength=pick(request.change_strength, 0.9),
                    seed=request.seed,
                    output_format=request.output_format,
                )
            else:
                # Guards against a registry entry that has no dispatch branch.
                raise ValueError(f"Unsupported control operation: {operation}")

        image_bytes = self._extract_image_bytes(result)
        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
            }
        )

        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }

        logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
        return response

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw ``bytes``, or a dict whose ``artifacts`` / ``data`` /
        ``images`` list holds dict entries with a ``base64`` or ``b64_json``
        payload; the first such payload found is decoded and returned.

        Raises:
            RuntimeError: If no image payload can be found.
        """
        if isinstance(result, bytes):
            return result

        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])

        raise RuntimeError("Unable to extract image bytes from provider response")
|
|
|