"""Stability AI service for handling API interactions.""" import aiohttp import asyncio from typing import Dict, Any, Optional, Union, Tuple, List import os from loguru import logger import json import base64 from fastapi import HTTPException, UploadFile class StabilityAIService: """Service class for interacting with Stability AI API.""" def __init__(self, api_key: Optional[str] = None): """Initialize the Stability AI service. Args: api_key: Stability AI API key. If not provided, will try to get from environment. """ self.api_key = api_key or os.getenv("STABILITY_API_KEY") if not self.api_key: raise ValueError("Stability AI API key is required. Set STABILITY_API_KEY environment variable or pass api_key parameter.") self.base_url = "https://api.stability.ai" self.session: Optional[aiohttp.ClientSession] = None async def __aenter__(self): """Async context manager entry.""" self.session = aiohttp.ClientSession() return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" if self.session: await self.session.close() def _get_headers(self, accept_type: str = "image/*") -> Dict[str, str]: """Get common headers for API requests. Args: accept_type: Accept header value Returns: Headers dictionary """ return { "Authorization": f"Bearer {self.api_key}", "Accept": accept_type, "User-Agent": "ALwrity-Backend/1.0" } async def _make_request( self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None, files: Optional[Dict[str, Any]] = None, accept_type: str = "image/*", timeout: int = 300 ) -> Union[bytes, Dict[str, Any]]: """Make HTTP request to Stability AI API. Args: method: HTTP method endpoint: API endpoint data: Form data files: File data accept_type: Accept header value timeout: Request timeout in seconds Returns: Response data (bytes for images/audio, dict for JSON) """ if not self.session: self.session = aiohttp.ClientSession() url = f"{self.base_url}{endpoint}" headers = self._get_headers(accept_type) # Remove content-type header to let aiohttp set it automatically for multipart if files: headers.pop("Content-Type", None) try: # Prepare multipart data form_data = aiohttp.FormData() # Add files if files: for key, file_data in files.items(): if isinstance(file_data, UploadFile): content = await file_data.read() form_data.add_field(key, content, filename=file_data.filename or "file", content_type=file_data.content_type) elif isinstance(file_data, bytes): form_data.add_field(key, file_data, filename="file") else: form_data.add_field(key, file_data) # Add form data if data: for key, value in data.items(): if value is not None: form_data.add_field(key, str(value)) timeout_config = aiohttp.ClientTimeout(total=timeout) async with self.session.request( method=method, url=url, headers=headers, data=form_data, timeout=timeout_config ) as response: # Handle different response types content_type = response.headers.get('Content-Type', '') if response.status == 200: if 'application/json' in content_type: return await response.json() else: return await response.read() elif response.status == 202: # Async generation started return await response.json() else: # Error response try: error_data = await response.json() logger.error(f"Stability AI API error: {error_data}") raise HTTPException( status_code=response.status, detail=error_data ) except: error_text = await response.text() logger.error(f"Stability AI API error: {error_text}") raise HTTPException( status_code=response.status, detail={"error": error_text} ) except asyncio.TimeoutError: logger.error(f"Timeout error for {endpoint}") raise HTTPException(status_code=504, detail="Request timeout") except Exception as e: logger.error(f"Request error for {endpoint}: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) async def _prepare_image_file(self, image: Union[UploadFile, bytes, str]) -> bytes: """Prepare image file for upload. Args: image: Image data in various formats Returns: Image bytes """ if isinstance(image, UploadFile): return await image.read() elif isinstance(image, bytes): return image elif isinstance(image, str): # Assume base64 encoded return base64.b64decode(image) else: raise ValueError("Unsupported image format") async def _prepare_audio_file(self, audio: Union[UploadFile, bytes, str]) -> bytes: """Prepare audio file for upload. Args: audio: Audio data in various formats Returns: Audio bytes """ if isinstance(audio, UploadFile): return await audio.read() elif isinstance(audio, bytes): return audio elif isinstance(audio, str): # Assume base64 encoded return base64.b64decode(audio) else: raise ValueError("Unsupported audio format") def _validate_image_requirements(self, width: int, height: int, min_pixels: int = 4096, max_pixels: int = 9437184): """Validate image dimension requirements. Args: width: Image width height: Image height min_pixels: Minimum pixel count max_pixels: Maximum pixel count """ total_pixels = width * height if total_pixels < min_pixels: raise ValueError(f"Image must have at least {min_pixels} pixels") if total_pixels > max_pixels: raise ValueError(f"Image must have at most {max_pixels} pixels") if width < 64 or height < 64: raise ValueError("Image dimensions must be at least 64x64 pixels") def _validate_aspect_ratio(self, width: int, height: int, min_ratio: float = 0.4, max_ratio: float = 2.5): """Validate image aspect ratio. Args: width: Image width height: Image height min_ratio: Minimum aspect ratio (1:2.5) max_ratio: Maximum aspect ratio (2.5:1) """ aspect_ratio = width / height if aspect_ratio < min_ratio or aspect_ratio > max_ratio: raise ValueError(f"Aspect ratio must be between {min_ratio}:1 and {max_ratio}:1") # ==================== GENERATE METHODS ==================== async def generate_ultra( self, prompt: str, image: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image using Stable Image Ultra. Args: prompt: Text prompt for generation image: Optional input image for image-to-image **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {} if image: files["image"] = await self._prepare_image_file(image) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/generate/ultra", data=data, files=files if files else None ) async def generate_core( self, prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image using Stable Image Core. Args: prompt: Text prompt for generation **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/generate/core", data=data ) async def generate_sd3( self, prompt: str, image: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image using Stable Diffusion 3.5. Args: prompt: Text prompt for generation image: Optional input image for image-to-image **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {} if image: files["image"] = await self._prepare_image_file(image) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/generate/sd3", data=data, files=files if files else None ) # ==================== EDIT METHODS ==================== async def erase( self, image: Union[UploadFile, bytes], mask: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Erase objects from image using mask. Args: image: Input image mask: Optional mask image **kwargs: Additional parameters Returns: Edited image bytes or JSON response """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} if mask: files["mask"] = await self._prepare_image_file(mask) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/erase", data=data, files=files ) async def inpaint( self, image: Union[UploadFile, bytes], prompt: str, mask: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Inpaint image with new content. Args: image: Input image prompt: Text prompt for inpainting mask: Optional mask image **kwargs: Additional parameters Returns: Edited image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} if mask: files["mask"] = await self._prepare_image_file(mask) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/inpaint", data=data, files=files ) async def outpaint( self, image: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Outpaint image in specified directions. Args: image: Input image **kwargs: Additional parameters including left, right, up, down Returns: Edited image bytes or JSON response """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/outpaint", data=data, files=files ) async def search_and_replace( self, image: Union[UploadFile, bytes], prompt: str, search_prompt: str, mask: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Replace objects in image using search prompt. Args: image: Input image prompt: Text prompt for replacement search_prompt: What to search for mask: Optional mask image for precise region selection **kwargs: Additional parameters Returns: Edited image bytes or JSON response """ data = {"prompt": prompt, "search_prompt": search_prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} if mask: files["mask"] = await self._prepare_image_file(mask) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/search-and-replace", data=data, files=files ) async def search_and_recolor( self, image: Union[UploadFile, bytes], prompt: str, select_prompt: str, mask: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Recolor objects in image using select prompt. Args: image: Input image prompt: Text prompt for recoloring select_prompt: What to select for recoloring mask: Optional mask image for precise region selection **kwargs: Additional parameters Returns: Edited image bytes or JSON response """ data = {"prompt": prompt, "select_prompt": select_prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} if mask: files["mask"] = await self._prepare_image_file(mask) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/search-and-recolor", data=data, files=files ) async def remove_background( self, image: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Remove background from image. Args: image: Input image **kwargs: Additional parameters Returns: Edited image bytes or JSON response """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/remove-background", data=data, files=files ) async def replace_background_and_relight( self, subject_image: Union[UploadFile, bytes], background_reference: Optional[Union[UploadFile, bytes]] = None, light_reference: Optional[Union[UploadFile, bytes]] = None, **kwargs ) -> Dict[str, Any]: """Replace background and relight image (async). Args: subject_image: Subject image background_reference: Optional background reference image light_reference: Optional light reference image **kwargs: Additional parameters Returns: Generation ID for async polling """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"subject_image": await self._prepare_image_file(subject_image)} if background_reference: files["background_reference"] = await self._prepare_image_file(background_reference) if light_reference: files["light_reference"] = await self._prepare_image_file(light_reference) return await self._make_request( method="POST", endpoint="/v2beta/stable-image/edit/replace-background-and-relight", data=data, files=files, accept_type="application/json" ) # ==================== UPSCALE METHODS ==================== async def upscale_fast( self, image: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Fast upscale image by 4x. Args: image: Input image **kwargs: Additional parameters Returns: Upscaled image bytes or JSON response """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/upscale/fast", data=data, files=files ) async def upscale_conservative( self, image: Union[UploadFile, bytes], prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Conservative upscale to 4K resolution. Args: image: Input image prompt: Text prompt for upscaling **kwargs: Additional parameters Returns: Upscaled image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/upscale/conservative", data=data, files=files ) async def upscale_creative( self, image: Union[UploadFile, bytes], prompt: str, **kwargs ) -> Dict[str, Any]: """Creative upscale to 4K resolution (async). Args: image: Input image prompt: Text prompt for upscaling **kwargs: Additional parameters Returns: Generation ID for async polling """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/upscale/creative", data=data, files=files, accept_type="application/json" ) # ==================== CONTROL METHODS ==================== async def control_sketch( self, image: Union[UploadFile, bytes], prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image from sketch with prompt. Args: image: Input sketch image prompt: Text prompt for generation **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/control/sketch", data=data, files=files ) async def control_structure( self, image: Union[UploadFile, bytes], prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image maintaining structure of input. Args: image: Input structure image prompt: Text prompt for generation **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/control/structure", data=data, files=files ) async def control_style( self, image: Union[UploadFile, bytes], prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate image using style from input image. Args: image: Input style image prompt: Text prompt for generation **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/stable-image/control/style", data=data, files=files ) async def control_style_transfer( self, init_image: Union[UploadFile, bytes], style_image: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Transfer style between images. Args: init_image: Initial image style_image: Style reference image **kwargs: Additional parameters Returns: Generated image bytes or JSON response """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = { "init_image": await self._prepare_image_file(init_image), "style_image": await self._prepare_image_file(style_image) } return await self._make_request( method="POST", endpoint="/v2beta/stable-image/control/style-transfer", data=data, files=files ) # ==================== 3D METHODS ==================== async def generate_3d_fast( self, image: Union[UploadFile, bytes], **kwargs ) -> bytes: """Generate 3D model using Stable Fast 3D. Args: image: Input image **kwargs: Additional parameters Returns: 3D model binary data (GLB format) """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/3d/stable-fast-3d", data=data, files=files, accept_type="model/gltf-binary" ) async def generate_3d_point_aware( self, image: Union[UploadFile, bytes], **kwargs ) -> bytes: """Generate 3D model using Stable Point Aware 3D. Args: image: Input image **kwargs: Additional parameters Returns: 3D model binary data (GLB format) """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"image": await self._prepare_image_file(image)} return await self._make_request( method="POST", endpoint="/v2beta/3d/stable-point-aware-3d", data=data, files=files, accept_type="model/gltf-binary" ) # ==================== AUDIO METHODS ==================== async def generate_audio_from_text( self, prompt: str, **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate audio from text prompt. Args: prompt: Text prompt for audio generation **kwargs: Additional parameters Returns: Generated audio bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) # Use empty files dict to trigger multipart form files = {"none": ""} return await self._make_request( method="POST", endpoint="/v2beta/audio/stable-audio-2/text-to-audio", data=data, files=files, accept_type="audio/*" ) async def generate_audio_from_audio( self, prompt: str, audio: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Generate audio from audio input. Args: prompt: Text prompt for audio generation audio: Input audio **kwargs: Additional parameters Returns: Generated audio bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"audio": await self._prepare_audio_file(audio)} return await self._make_request( method="POST", endpoint="/v2beta/audio/stable-audio-2/audio-to-audio", data=data, files=files, accept_type="audio/*" ) async def inpaint_audio( self, prompt: str, audio: Union[UploadFile, bytes], **kwargs ) -> Union[bytes, Dict[str, Any]]: """Inpaint audio with new content. Args: prompt: Text prompt for audio inpainting audio: Input audio **kwargs: Additional parameters Returns: Generated audio bytes or JSON response """ data = {"prompt": prompt} data.update({k: v for k, v in kwargs.items() if v is not None}) files = {"audio": await self._prepare_audio_file(audio)} return await self._make_request( method="POST", endpoint="/v2beta/audio/stable-audio-2/inpaint", data=data, files=files, accept_type="audio/*" ) # ==================== RESULTS METHODS ==================== async def get_generation_result( self, generation_id: str, accept_type: str = "*/*" ) -> Union[bytes, Dict[str, Any]]: """Get result of async generation. Args: generation_id: Generation ID from async operation accept_type: Accept header value Returns: Generation result (bytes or JSON) """ return await self._make_request( method="GET", endpoint=f"/v2beta/results/{generation_id}", accept_type=accept_type ) # ==================== V1 LEGACY METHODS ==================== async def v1_text_to_image( self, engine_id: str, text_prompts: List[Dict[str, Any]], **kwargs ) -> Dict[str, Any]: """V1 text-to-image generation. Args: engine_id: Engine ID text_prompts: Text prompts list **kwargs: Additional parameters Returns: Generation response with artifacts """ data = {"text_prompts": text_prompts} data.update({k: v for k, v in kwargs.items() if v is not None}) headers = self._get_headers("application/json") headers["Content-Type"] = "application/json" async with self.session.post( f"{self.base_url}/v1/generation/{engine_id}/text-to-image", headers=headers, json=data ) as response: if response.status == 200: return await response.json() else: error_data = await response.json() raise HTTPException(status_code=response.status, detail=error_data) async def v1_image_to_image( self, engine_id: str, init_image: Union[UploadFile, bytes], text_prompts: List[Dict[str, Any]], **kwargs ) -> Dict[str, Any]: """V1 image-to-image generation. Args: engine_id: Engine ID init_image: Initial image text_prompts: Text prompts list **kwargs: Additional parameters Returns: Generation response with artifacts """ data = {} data.update({k: v for k, v in kwargs.items() if v is not None}) # Add text prompts to form data for i, prompt in enumerate(text_prompts): data[f"text_prompts[{i}][text]"] = prompt["text"] if "weight" in prompt: data[f"text_prompts[{i}][weight]"] = prompt["weight"] files = {"init_image": await self._prepare_image_file(init_image)} return await self._make_request( method="POST", endpoint=f"/v1/generation/{engine_id}/image-to-image", data=data, files=files, accept_type="application/json" ) async def v1_masking( self, engine_id: str, init_image: Union[UploadFile, bytes], mask_image: Optional[Union[UploadFile, bytes]], text_prompts: List[Dict[str, Any]], mask_source: str, **kwargs ) -> Dict[str, Any]: """V1 image masking generation. Args: engine_id: Engine ID init_image: Initial image mask_image: Optional mask image text_prompts: Text prompts list mask_source: Mask source type **kwargs: Additional parameters Returns: Generation response with artifacts """ data = {"mask_source": mask_source} data.update({k: v for k, v in kwargs.items() if v is not None}) # Add text prompts to form data for i, prompt in enumerate(text_prompts): data[f"text_prompts[{i}][text]"] = prompt["text"] if "weight" in prompt: data[f"text_prompts[{i}][weight]"] = prompt["weight"] files = {"init_image": await self._prepare_image_file(init_image)} if mask_image: files["mask_image"] = await self._prepare_image_file(mask_image) return await self._make_request( method="POST", endpoint=f"/v1/generation/{engine_id}/image-to-image/masking", data=data, files=files, accept_type="application/json" ) # ==================== USER & ACCOUNT METHODS ==================== async def get_account_details(self) -> Dict[str, Any]: """Get account details. Returns: Account information """ headers = self._get_headers("application/json") async with self.session.get( f"{self.base_url}/v1/user/account", headers=headers ) as response: if response.status == 200: return await response.json() else: error_data = await response.json() raise HTTPException(status_code=response.status, detail=error_data) async def get_account_balance(self) -> Dict[str, Any]: """Get account balance. Returns: Account balance information """ headers = self._get_headers("application/json") async with self.session.get( f"{self.base_url}/v1/user/balance", headers=headers ) as response: if response.status == 200: return await response.json() else: error_data = await response.json() raise HTTPException(status_code=response.status, detail=error_data) async def list_engines(self) -> Dict[str, Any]: """List available engines. Returns: List of available engines """ headers = self._get_headers("application/json") async with self.session.get( f"{self.base_url}/v1/engines/list", headers=headers ) as response: if response.status == 200: return await response.json() else: error_data = await response.json() raise HTTPException(status_code=response.status, detail=error_data) # Global service instance stability_service = None async def get_stability_service() -> StabilityAIService: """Get or create Stability AI service instance. Returns: Stability AI service instance """ global stability_service if stability_service is None: stability_service = StabilityAIService() return stability_service