AI Image Studio Progress Review
- Added new router for content assets - Added new service for content assets - Added new model for content assets - Added new utils for content assets - Added new docs for content assets - Added new tests for content assets - Added new examples for content assets - Added new guides for content assets
This commit is contained in:
277
backend/services/image_studio/control_service.py
Normal file
277
backend/services/image_studio/control_service.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Control Studio service for AI-powered controlled image generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.control")
|
||||
|
||||
|
||||
ControlOperationType = Literal[
|
||||
"sketch",
|
||||
"structure",
|
||||
"style",
|
||||
"style_transfer",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlStudioRequest:
|
||||
"""Normalized request payload for Control Studio operations."""
|
||||
|
||||
operation: ControlOperationType
|
||||
prompt: str
|
||||
control_image_base64: str # Sketch, structure, or style reference
|
||||
style_image_base64: Optional[str] = None # For style_transfer only
|
||||
negative_prompt: Optional[str] = None
|
||||
control_strength: Optional[float] = None # For sketch/structure
|
||||
fidelity: Optional[float] = None # For style
|
||||
style_strength: Optional[float] = None # For style_transfer
|
||||
composition_fidelity: Optional[float] = None # For style_transfer
|
||||
change_strength: Optional[float] = None # For style_transfer
|
||||
aspect_ratio: Optional[str] = None # For style
|
||||
style_preset: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
output_format: str = "png"
|
||||
|
||||
|
||||
class ControlStudioService:
|
||||
"""Service layer orchestrating Control Studio operations."""
|
||||
|
||||
SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
|
||||
"sketch": {
|
||||
"label": "Sketch to Image",
|
||||
"description": "Transform sketches into refined images with precise control.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": True,
|
||||
"fidelity": False,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
"structure": {
|
||||
"label": "Structure Control",
|
||||
"description": "Generate images maintaining the structure of an input image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": True,
|
||||
"fidelity": False,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
"style": {
|
||||
"label": "Style Control",
|
||||
"description": "Generate images using style from a reference image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": False,
|
||||
"fidelity": True,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": True,
|
||||
},
|
||||
},
|
||||
"style_transfer": {
|
||||
"label": "Style Transfer",
|
||||
"description": "Apply visual characteristics from a style image to a target image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True, # init_image
|
||||
"style_image": True,
|
||||
"control_strength": False,
|
||||
"fidelity": False,
|
||||
"style_strength": True,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Control Studio] Initialized control service")
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
|
||||
"""Decode a base64 (or data URL) string to bytes."""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle data URLs (data:image/png;base64,...)
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Extract width/height metadata from image bytes."""
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
|
||||
"""Convert raw bytes to base64 data URL."""
|
||||
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:image/{output_format};base64,{b64}"
|
||||
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Expose supported operations for UI rendering."""
|
||||
return self.SUPPORTED_OPERATIONS
|
||||
|
||||
async def process_control(
|
||||
self,
|
||||
request: ControlStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process control request and return normalized response."""
|
||||
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_control_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_control_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1,
|
||||
)
|
||||
logger.info("[Control Studio] ✅ Pre-flight validation passed")
|
||||
except HTTPException:
|
||||
logger.error("[Control Studio] ❌ Pre-flight validation failed")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
control_image_bytes = self._decode_base64_image(request.control_image_base64)
|
||||
if not control_image_bytes:
|
||||
raise ValueError("Control image payload is required")
|
||||
|
||||
style_image_bytes = self._decode_base64_image(request.style_image_base64)
|
||||
|
||||
operation = request.operation
|
||||
logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)
|
||||
|
||||
if operation not in self.SUPPORTED_OPERATIONS:
|
||||
raise ValueError(f"Unsupported control operation: {operation}")
|
||||
|
||||
stability_service = StabilityAIService()
|
||||
async with stability_service:
|
||||
if operation == "sketch":
|
||||
result = await stability_service.control_sketch(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
control_strength=request.control_strength or 0.7,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "structure":
|
||||
result = await stability_service.control_structure(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
control_strength=request.control_strength or 0.7,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "style":
|
||||
result = await stability_service.control_style(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
aspect_ratio=request.aspect_ratio or "1:1",
|
||||
fidelity=request.fidelity or 0.5,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "style_transfer":
|
||||
if not style_image_bytes:
|
||||
raise ValueError("Style image is required for style transfer")
|
||||
result = await stability_service.control_style_transfer(
|
||||
init_image=control_image_bytes,
|
||||
style_image=style_image_bytes,
|
||||
prompt=request.prompt or "",
|
||||
negative_prompt=request.negative_prompt,
|
||||
style_strength=request.style_strength or 1.0,
|
||||
composition_fidelity=request.composition_fidelity or 0.9,
|
||||
change_strength=request.change_strength or 0.9,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported control operation: {operation}")
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
"style_preset": request.style_preset,
|
||||
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
|
||||
}
|
||||
)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"operation": operation,
|
||||
"provider": metadata["provider"],
|
||||
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
"""Normalize Stability responses into raw image bytes."""
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
|
||||
raise RuntimeError("Unable to extract image bytes from provider response")
|
||||
|
||||
Reference in New Issue
Block a user