AI Image Studio Phase 1
This commit is contained in:
458
backend/services/image_studio/edit_service.py
Normal file
458
backend/services/image_studio/edit_service.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""Edit Studio service for AI-powered image editing and transformations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.edit")
|
||||
|
||||
|
||||
EditOperationType = Literal[
|
||||
"remove_background",
|
||||
"inpaint",
|
||||
"outpaint",
|
||||
"search_replace",
|
||||
"search_recolor",
|
||||
"relight",
|
||||
"general_edit",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditStudioRequest:
|
||||
"""Normalized request payload for Edit Studio operations."""
|
||||
|
||||
image_base64: str
|
||||
operation: EditOperationType
|
||||
prompt: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
mask_base64: Optional[str] = None
|
||||
search_prompt: Optional[str] = None
|
||||
select_prompt: Optional[str] = None
|
||||
background_image_base64: Optional[str] = None
|
||||
lighting_image_base64: Optional[str] = None
|
||||
expand_left: Optional[int] = None
|
||||
expand_right: Optional[int] = None
|
||||
expand_up: Optional[int] = None
|
||||
expand_down: Optional[int] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
style_preset: Optional[str] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
output_format: str = "png"
|
||||
options: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class EditStudioService:
|
||||
"""Service layer orchestrating Edit Studio operations."""
|
||||
|
||||
SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
|
||||
"remove_background": {
|
||||
"label": "Remove Background",
|
||||
"description": "Isolate the main subject and remove the background.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"inpaint": {
|
||||
"label": "Inpaint & Fix",
|
||||
"description": "Edit specific regions using prompts and optional masks.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": True,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"outpaint": {
|
||||
"label": "Outpaint",
|
||||
"description": "Extend the canvas in any direction with smart fill.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": True,
|
||||
},
|
||||
},
|
||||
"search_replace": {
|
||||
"label": "Search & Replace",
|
||||
"description": "Locate objects via search prompt and replace them.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": True,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"search_recolor": {
|
||||
"label": "Search & Recolor",
|
||||
"description": "Select elements via prompt and recolor them.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": True,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"relight": {
|
||||
"label": "Replace Background & Relight",
|
||||
"description": "Swap backgrounds and relight using reference images.",
|
||||
"provider": "stability",
|
||||
"async": True,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": True,
|
||||
"lighting": True,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"general_edit": {
|
||||
"label": "Prompt-based Edit",
|
||||
"description": "Free-form editing powered by Hugging Face image-to-image models.",
|
||||
"provider": "huggingface",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Edit Studio] Initialized edit service")
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
|
||||
"""Decode a base64 (or data URL) string to bytes."""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle data URLs (data:image/png;base64,...)
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Extract width/height metadata from image bytes."""
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
|
||||
"""Convert raw bytes to base64 data URL."""
|
||||
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:image/{output_format};base64,{b64}"
|
||||
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Expose supported operations for UI rendering."""
|
||||
return self.SUPPORTED_OPERATIONS
|
||||
|
||||
async def process_edit(
|
||||
self,
|
||||
request: EditStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process edit request and return normalized response."""
|
||||
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
)
|
||||
logger.info("[Edit Studio] ✅ Pre-flight validation passed")
|
||||
except HTTPException:
|
||||
logger.error("[Edit Studio] ❌ Pre-flight validation failed")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
image_bytes = self._decode_base64_image(request.image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Primary image payload is required")
|
||||
|
||||
mask_bytes = self._decode_base64_image(request.mask_base64)
|
||||
background_bytes = self._decode_base64_image(request.background_image_base64)
|
||||
lighting_bytes = self._decode_base64_image(request.lighting_image_base64)
|
||||
|
||||
operation = request.operation
|
||||
logger.info("[Edit Studio] Processing operation='%s' for user=%s", operation, user_id)
|
||||
|
||||
if operation not in self.SUPPORTED_OPERATIONS:
|
||||
raise ValueError(f"Unsupported edit operation: {operation}")
|
||||
|
||||
if operation in {"remove_background", "inpaint", "outpaint", "search_replace", "search_recolor", "relight"}:
|
||||
image_bytes = await self._handle_stability_edit(
|
||||
operation=operation,
|
||||
request=request,
|
||||
image_bytes=image_bytes,
|
||||
mask_bytes=mask_bytes,
|
||||
background_bytes=background_bytes,
|
||||
lighting_bytes=lighting_bytes,
|
||||
)
|
||||
else:
|
||||
image_bytes = await self._handle_general_edit(
|
||||
request=request,
|
||||
image_bytes=image_bytes,
|
||||
mask_bytes=mask_bytes,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
"style_preset": request.style_preset,
|
||||
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
|
||||
}
|
||||
)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"operation": operation,
|
||||
"provider": metadata["provider"],
|
||||
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
logger.info("[Edit Studio] ✅ Operation '%s' completed", operation)
|
||||
return response
|
||||
|
||||
async def _handle_stability_edit(
|
||||
self,
|
||||
operation: EditOperationType,
|
||||
request: EditStudioRequest,
|
||||
image_bytes: bytes,
|
||||
mask_bytes: Optional[bytes],
|
||||
background_bytes: Optional[bytes],
|
||||
lighting_bytes: Optional[bytes],
|
||||
) -> bytes:
|
||||
"""Execute Stability AI edit workflows."""
|
||||
stability_service = StabilityAIService()
|
||||
|
||||
async with stability_service:
|
||||
if operation == "remove_background":
|
||||
result = await stability_service.remove_background(
|
||||
image=image_bytes,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "inpaint":
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for inpainting")
|
||||
result = await stability_service.inpaint(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
mask=mask_bytes,
|
||||
negative_prompt=request.negative_prompt,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
grow_mask=request.options.get("grow_mask", 5),
|
||||
)
|
||||
elif operation == "outpaint":
|
||||
result = await stability_service.outpaint(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
output_format=request.output_format,
|
||||
left=request.expand_left or 0,
|
||||
right=request.expand_right or 0,
|
||||
up=request.expand_up or 0,
|
||||
down=request.expand_down or 0,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "search_replace":
|
||||
if not (request.prompt and request.search_prompt):
|
||||
raise ValueError("Both prompt and search_prompt are required for search & replace")
|
||||
result = await stability_service.search_and_replace(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
search_prompt=request.search_prompt,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "search_recolor":
|
||||
if not (request.prompt and request.select_prompt):
|
||||
raise ValueError("Both prompt and select_prompt are required for search & recolor")
|
||||
result = await stability_service.search_and_recolor(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
select_prompt=request.select_prompt,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "relight":
|
||||
if not background_bytes and not lighting_bytes:
|
||||
raise ValueError("At least one reference (background or lighting) is required for relight")
|
||||
result = await stability_service.replace_background_and_relight(
|
||||
subject_image=image_bytes,
|
||||
background_reference=background_bytes,
|
||||
light_reference=lighting_bytes,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("id"):
|
||||
result = await self._poll_stability_result(
|
||||
stability_service,
|
||||
generation_id=result["id"],
|
||||
output_format=request.output_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Stability operation: {operation}")
|
||||
|
||||
return self._extract_image_bytes(result)
|
||||
|
||||
async def _handle_general_edit(
|
||||
self,
|
||||
request: EditStudioRequest,
|
||||
image_bytes: bytes,
|
||||
mask_bytes: Optional[bytes],
|
||||
user_id: Optional[str],
|
||||
) -> bytes:
|
||||
"""Execute Hugging Face powered general editing (synchronous API)."""
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for general edits")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
"""Normalize Stability responses into raw image bytes."""
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
|
||||
raise RuntimeError("Unable to extract image bytes from provider response")
|
||||
|
||||
async def _poll_stability_result(
|
||||
self,
|
||||
stability_service: StabilityAIService,
|
||||
generation_id: str,
|
||||
output_format: str,
|
||||
timeout_seconds: int = 240,
|
||||
interval_seconds: float = 2.0,
|
||||
) -> bytes:
|
||||
"""Poll Stability async endpoint until result is ready."""
|
||||
elapsed = 0.0
|
||||
while elapsed < timeout_seconds:
|
||||
result = await stability_service.get_generation_result(
|
||||
generation_id=generation_id,
|
||||
accept_type="*/*",
|
||||
)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
state = (result.get("state") or result.get("status") or "").lower()
|
||||
if state in {"succeeded", "success", "ready", "completed"}:
|
||||
return self._extract_image_bytes(result)
|
||||
if state in {"failed", "error"}:
|
||||
raise RuntimeError(f"Stability generation failed: {result}")
|
||||
|
||||
await asyncio.sleep(interval_seconds)
|
||||
elapsed += interval_seconds
|
||||
|
||||
raise RuntimeError("Timed out waiting for Stability generation result")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user