Files
ALwrity/backend/services/image_studio/control_service.py
ajaysi 77d7c0cde6 AI Image Studio Progress Review
- Added new router for content assets
- Added new service for content assets
- Added new model for content assets
- Added new utils for content assets
- Added new docs for content assets
- Added new tests for content assets
- Added new examples for content assets
- Added new guides for content assets
2025-11-23 09:21:11 +05:30

278 lines
10 KiB
Python

"""Control Studio service for AI-powered controlled image generation."""
from __future__ import annotations

import base64
import copy
import io
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional

from PIL import Image

from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.control")
# Closed set of operation identifiers accepted by Control Studio; each value
# is also a key of ControlStudioService.SUPPORTED_OPERATIONS.
ControlOperationType = Literal["sketch", "structure", "style", "style_transfer"]
@dataclass
class ControlStudioRequest:
    """Normalized request payload for Control Studio operations.

    Attributes:
        operation: Which control operation to run.
        prompt: Text prompt forwarded to the provider.
        control_image_base64: Sketch, structure, or style reference image,
            as base64 or a ``data:`` URL.
        style_image_base64: Style image; used by ``style_transfer`` only.
        negative_prompt: Optional negative prompt.
        control_strength: Used by ``sketch`` / ``structure``.
        fidelity: Used by ``style``.
        style_strength: Used by ``style_transfer``.
        composition_fidelity: Used by ``style_transfer``.
        change_strength: Used by ``style_transfer``.
        aspect_ratio: Used by ``style``.
        style_preset: Optional style preset name.
        seed: Optional generation seed.
        output_format: Image format of the result; defaults to ``"png"``.
    """

    # --- required inputs ---
    operation: ControlOperationType
    prompt: str
    control_image_base64: str

    # --- optional images / prompts ---
    style_image_base64: Optional[str] = None
    negative_prompt: Optional[str] = None

    # --- per-operation tuning knobs (None means "use the service default") ---
    control_strength: Optional[float] = None
    fidelity: Optional[float] = None
    style_strength: Optional[float] = None
    composition_fidelity: Optional[float] = None
    change_strength: Optional[float] = None
    aspect_ratio: Optional[str] = None

    # --- shared generation options ---
    style_preset: Optional[str] = None
    seed: Optional[int] = None
    output_format: str = "png"
class ControlStudioService:
    """Service layer orchestrating Control Studio operations.

    Decodes the request images, dispatches to the matching
    ``StabilityAIService`` control method, and wraps the resulting image in a
    normalized response dict (base64 data URL plus width/height metadata).
    """

    # Catalogue of supported operations. ``fields`` flags which request inputs
    # each operation uses; it is exposed to callers via ``list_operations``.
    SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
        "sketch": {
            "label": "Sketch to Image",
            "description": "Transform sketches into refined images with precise control.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "structure": {
            "label": "Structure Control",
            "description": "Generate images maintaining the structure of an input image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "style": {
            "label": "Style Control",
            "description": "Generate images using style from a reference image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": False,
                "fidelity": True,
                "style_strength": False,
                "aspect_ratio": True,
            },
        },
        "style_transfer": {
            "label": "Style Transfer",
            "description": "Apply visual characteristics from a style image to a target image.",
            "provider": "stability",
            "fields": {
                "control_image": True,  # init_image
                "style_image": True,
                "control_strength": False,
                "fidelity": False,
                "style_strength": True,
                "aspect_ratio": False,
            },
        },
    }

    def __init__(self):
        logger.info("[Control Studio] Initialized control service")

    @staticmethod
    def _or_default(value: Optional[float], default: float) -> float:
        """Return *value* unless it is ``None``, in which case return *default*.

        ``value or default`` is deliberately avoided: ``0.0`` is falsy, so an
        explicitly requested zero would be silently replaced by the default.
        """
        return default if value is None else value

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Args:
            value: Plain base64 text or a ``data:<mime>;base64,<payload>`` URL.

        Returns:
            The decoded bytes, or ``None`` when *value* is empty or ``None``.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # Handle data URLs (data:image/png;base64,...): keep only the
            # part after the first comma.
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except (ValueError, TypeError, AttributeError) as exc:
            # binascii.Error is a ValueError subclass; TypeError/AttributeError
            # cover non-string input.
            logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Extract width/height metadata from image bytes.

        Returns:
            A dict with ``width`` and ``height`` as reported by PIL.
        """
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Convert raw bytes to a ``data:image/<output_format>;base64,...`` URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        Returns a deep copy so that callers mutating the result cannot corrupt
        the class-level catalogue shared by every instance.
        """
        return copy.deepcopy(self.SUPPORTED_OPERATIONS)

    def _run_preflight_validation(self, user_id: Optional[str]) -> None:
        """Run the subscription pre-flight check for *user_id*.

        When no user id is given the check is skipped with a warning. An
        ``HTTPException`` raised by the validator is logged and re-raised;
        the database session is closed in every case.
        """
        if not user_id:
            logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")
            return

        # Imported at call time, as in the original code
        # (presumably to avoid import cycles / heavy imports — TODO confirm).
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_control_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_control_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=1,
            )
            logger.info("[Control Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Control Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            db.close()

    async def process_control(
        self,
        request: ControlStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process control request and return normalized response.

        Args:
            request: The normalized Control Studio request.
            user_id: When given, a subscription pre-flight validation runs
                before any image work; when omitted it is skipped.

        Returns:
            A dict with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and
            ``metadata``.

        Raises:
            ValueError: Missing/invalid control image, unsupported operation,
                or missing style image for ``style_transfer``.
            RuntimeError: The provider response contained no image payload.
            fastapi.HTTPException: Re-raised from the pre-flight validation.
        """
        self._run_preflight_validation(user_id)

        control_image_bytes = self._decode_base64_image(request.control_image_base64)
        if not control_image_bytes:
            raise ValueError("Control image payload is required")
        style_image_bytes = self._decode_base64_image(request.style_image_base64)

        operation = request.operation
        # f-strings match the file's other log calls and do not rely on the
        # logger backend interpolating %-style arguments.
        logger.info(f"[Control Studio] Processing operation='{operation}' for user={user_id}")
        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported control operation: {operation}")

        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "sketch":
                result = await stability_service.control_sketch(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=self._or_default(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "structure":
                result = await stability_service.control_structure(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=self._or_default(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style":
                result = await stability_service.control_style(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    aspect_ratio=request.aspect_ratio or "1:1",
                    fidelity=self._or_default(request.fidelity, 0.5),
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style_transfer":
                if not style_image_bytes:
                    raise ValueError("Style image is required for style transfer")
                result = await stability_service.control_style_transfer(
                    init_image=control_image_bytes,
                    style_image=style_image_bytes,
                    prompt=request.prompt or "",
                    negative_prompt=request.negative_prompt,
                    style_strength=self._or_default(request.style_strength, 1.0),
                    composition_fidelity=self._or_default(request.composition_fidelity, 0.9),
                    change_strength=self._or_default(request.change_strength, 0.9),
                    seed=request.seed,
                    output_format=request.output_format,
                )
            else:
                # Defensive: a key present in SUPPORTED_OPERATIONS but not
                # wired to a provider call above.
                raise ValueError(f"Unsupported control operation: {operation}")

        image_bytes = self._extract_image_bytes(result)
        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
            }
        )
        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }
        logger.info(f"[Control Studio] ✅ Operation '{operation}' completed")
        return response

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw binary (``bytes``, ``bytearray``, ``memoryview``) or a
        dict whose ``artifacts`` / ``data`` / ``images`` list holds dict items
        carrying a ``base64`` or ``b64_json`` payload; the first such payload
        is decoded and returned.

        Raises:
            RuntimeError: If no image payload can be found in *result*.
        """
        if isinstance(result, (bytes, bytearray, memoryview)):
            return bytes(result)
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if not isinstance(artifact, dict):
                    continue
                encoded = artifact.get("base64") or artifact.get("b64_json")
                if encoded:
                    return base64.b64decode(encoded)
        raise RuntimeError("Unable to extract image bytes from provider response")