# Source: moreminimore-marketing/backend/utils/stability_utils.py
# Author: Kunthawat Greethong — commit c35fa52117 ("Base code"), 2026-01-08 22:39:53 +07:00
# (provenance header preserved from the original file listing)
"""Utility functions for Stability AI operations."""
import base64
import io
import json
import mimetypes
import os
from typing import Dict, Any, Optional, List, Union, Tuple
from PIL import Image, ImageStat
import numpy as np
from fastapi import UploadFile, HTTPException
import aiofiles
import asyncio
from datetime import datetime
import hashlib
class ImageValidator:
"""Validator for image files and parameters."""
@staticmethod
def validate_image_file(file: UploadFile) -> Dict[str, Any]:
"""Validate uploaded image file.
Args:
file: Uploaded file
Returns:
Validation result with file info
"""
if not file.content_type or not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
# Check file extension
allowed_extensions = ['.jpg', '.jpeg', '.png', '.webp']
if file.filename:
ext = '.' + file.filename.split('.')[-1].lower()
if ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file format. Allowed: {allowed_extensions}"
)
return {
"filename": file.filename,
"content_type": file.content_type,
"is_valid": True
}
@staticmethod
async def analyze_image_content(content: bytes) -> Dict[str, Any]:
"""Analyze image content and characteristics.
Args:
content: Image bytes
Returns:
Image analysis results
"""
try:
img = Image.open(io.BytesIO(content))
# Basic info
info = {
"format": img.format,
"mode": img.mode,
"size": img.size,
"width": img.width,
"height": img.height,
"total_pixels": img.width * img.height,
"aspect_ratio": round(img.width / img.height, 3),
"file_size": len(content),
"has_alpha": img.mode in ("RGBA", "LA") or "transparency" in img.info
}
# Color analysis
if img.mode == "RGB" or img.mode == "RGBA":
img_rgb = img.convert("RGB")
stat = ImageStat.Stat(img_rgb)
info.update({
"brightness": round(sum(stat.mean) / 3, 2),
"color_variance": round(sum(stat.stddev) / 3, 2),
"dominant_colors": _extract_dominant_colors(img_rgb)
})
# Quality assessment
info["quality_assessment"] = _assess_image_quality(img)
return info
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error analyzing image: {str(e)}")
@staticmethod
def validate_dimensions(width: int, height: int, operation: str) -> None:
"""Validate image dimensions for specific operation.
Args:
width: Image width
height: Image height
operation: Operation type
"""
from config.stability_config import IMAGE_LIMITS
limits = IMAGE_LIMITS.get(operation, IMAGE_LIMITS["generate"])
total_pixels = width * height
if "min_pixels" in limits and total_pixels < limits["min_pixels"]:
raise HTTPException(
status_code=400,
detail=f"Image must have at least {limits['min_pixels']} pixels for {operation}"
)
if "max_pixels" in limits and total_pixels > limits["max_pixels"]:
raise HTTPException(
status_code=400,
detail=f"Image must have at most {limits['max_pixels']} pixels for {operation}"
)
if "min_dimension" in limits:
min_dim = limits["min_dimension"]
if width < min_dim or height < min_dim:
raise HTTPException(
status_code=400,
detail=f"Both dimensions must be at least {min_dim} pixels for {operation}"
)
class AudioValidator:
"""Validator for audio files and parameters."""
@staticmethod
def validate_audio_file(file: UploadFile) -> Dict[str, Any]:
"""Validate uploaded audio file.
Args:
file: Uploaded file
Returns:
Validation result with file info
"""
if not file.content_type or not file.content_type.startswith('audio/'):
raise HTTPException(status_code=400, detail="File must be an audio file")
# Check file extension
allowed_extensions = ['.mp3', '.wav']
if file.filename:
ext = '.' + file.filename.split('.')[-1].lower()
if ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported audio format. Allowed: {allowed_extensions}"
)
return {
"filename": file.filename,
"content_type": file.content_type,
"is_valid": True
}
@staticmethod
async def analyze_audio_content(content: bytes) -> Dict[str, Any]:
"""Analyze audio content and characteristics.
Args:
content: Audio bytes
Returns:
Audio analysis results
"""
try:
# Basic info
info = {
"file_size": len(content),
"format": "unknown" # Would need audio library to detect
}
# For actual implementation, you'd use libraries like librosa or pydub
# to analyze audio characteristics like duration, sample rate, etc.
return info
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error analyzing audio: {str(e)}")
class PromptOptimizer:
    """Heuristics for analyzing and improving text prompts."""

    @staticmethod
    def analyze_prompt(prompt: str) -> Dict[str, Any]:
        """Summarize a prompt's structure and descriptive richness.

        Args:
            prompt: Text prompt.

        Returns:
            Dict with length/word/sentence counts, keyword-category flags
            and an overall complexity score.
        """
        words = prompt.split()
        sentences = [s for s in prompt.split('.') if s.strip()]
        return {
            "length": len(prompt),
            "word_count": len(words),
            "sentence_count": len(sentences),
            "has_style_descriptors": _has_style_descriptors(prompt),
            "has_quality_terms": _has_quality_terms(prompt),
            "has_technical_terms": _has_technical_terms(prompt),
            "complexity_score": _calculate_complexity_score(prompt)
        }

    @staticmethod
    def optimize_prompt(
        prompt: str,
        target_model: str = "ultra",
        target_style: Optional[str] = None,
        quality_level: str = "high"
    ) -> Dict[str, Any]:
        """Augment a prompt with style/quality/model-specific additions.

        Args:
            prompt: Original prompt.
            target_model: Target model ("ultra" or "core").
            target_style: Style to append when none is present.
            quality_level: "high" triggers quality-enhancer terms.

        Returns:
            Dict with the original and optimized prompt, the list of
            optimizations applied, and a rough improvement estimate.
        """
        applied = []
        result = prompt.strip()

        # Append a style only when the prompt has none of its own.
        if target_style and not _has_style_descriptors(prompt):
            result += f", {target_style} style"
            applied.append(f"Added style: {target_style}")

        if quality_level == "high" and not _has_quality_terms(prompt):
            result += ", high quality, detailed, sharp"
            applied.append("Added quality enhancers")

        word_count = len(prompt.split())
        if target_model == "ultra":
            # Ultra benefits from richer prompts; pad very short ones.
            if word_count < 10:
                result += ", professional photography, detailed composition"
                applied.append("Added detail for Ultra model")
        elif target_model == "core" and word_count > 30:
            # Core prefers concise prompts — advise, don't rewrite.
            applied.append("Consider shortening prompt for Core model")

        return {
            "original_prompt": prompt,
            "optimized_prompt": result,
            "optimizations_applied": applied,
            "improvement_estimate": len(applied) * 15  # rough percentage
        }

    @staticmethod
    def generate_negative_prompt(
        prompt: str,
        style: Optional[str] = None
    ) -> str:
        """Build a negative prompt matched to the style and content.

        Args:
            prompt: Original prompt (content keywords add negatives).
            style: Target style (adds style-opposing negatives).

        Returns:
            Comma-separated negative prompt string.
        """
        parts = ["blurry, low quality, distorted, deformed, pixelated"]

        if style:
            style_lower = style.lower()
            if "photographic" in style_lower:
                parts.append("cartoon, anime, illustration")
            elif "anime" in style_lower:
                parts.append("realistic, photographic")
            elif "art" in style_lower:
                parts.append("photograph, realistic")

        prompt_lower = prompt.lower()
        if "person" in prompt_lower or "human" in prompt_lower:
            # People are where diffusion artifacts show up most.
            parts.append("extra limbs, malformed hands, duplicate")

        return ", ".join(parts)
class FileManager:
    """Manager for file operations and caching."""

    @staticmethod
    async def save_result(
        content: bytes,
        filename: str,
        operation: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Save a generation result under generated_content/<op>/<Y/m/d>/.

        Args:
            content: File content to persist.
            filename: Original filename, appended to the unique name.
            operation: Operation type (used as a subdirectory).
            metadata: Optional metadata saved alongside as ``<file>.json``.

        Returns:
            Path of the saved file.
        """
        # Create directory structure: generated_content/<operation>/<date>
        base_dir = "generated_content"
        operation_dir = os.path.join(base_dir, operation)
        date_dir = os.path.join(operation_dir, datetime.now().strftime("%Y/%m/%d"))
        os.makedirs(date_dir, exist_ok=True)

        # Timestamp + content hash keeps names unique and collision-resistant.
        timestamp = datetime.now().strftime("%H%M%S")
        file_hash = hashlib.md5(content).hexdigest()[:8]
        # BUG FIX: the original interpolated a literal placeholder instead of
        # the caller-supplied filename, which was otherwise never used.
        unique_filename = f"{timestamp}_{file_hash}_{filename}"
        file_path = os.path.join(date_dir, unique_filename)

        # Save file asynchronously.
        async with aiofiles.open(file_path, 'wb') as f:
            await f.write(content)

        # Save metadata next to the file if provided.
        if metadata:
            metadata_path = file_path + ".json"
            async with aiofiles.open(metadata_path, 'w') as f:
                await f.write(json.dumps(metadata, indent=2))

        return file_path

    @staticmethod
    def generate_cache_key(operation: str, parameters: Dict[str, Any]) -> str:
        """Generate a deterministic cache key for an operation + parameters.

        Args:
            operation: Operation type.
            parameters: Operation parameters (must be JSON-serializable).

        Returns:
            Hex SHA-256 digest; identical inputs always map to the same key
            because parameter keys are sorted before hashing.
        """
        key_data = f"{operation}:{json.dumps(parameters, sort_keys=True)}"
        return hashlib.sha256(key_data.encode()).hexdigest()
class ResponseFormatter:
    """Formatter for API responses."""

    @staticmethod
    def _base_payload(key: str, content: bytes, fmt: str) -> Dict[str, Any]:
        """Build the base64 payload shared by all response types."""
        return {
            key: base64.b64encode(content).decode(),
            "format": fmt,
            "size": len(content),
            "timestamp": datetime.utcnow().isoformat()
        }

    @staticmethod
    def format_image_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Encode image bytes as a base64 response payload.

        Args:
            content: Image content.
            output_format: Output format label.
            metadata: Optional metadata merged into the response.

        Returns:
            Response dict with image, format, size and timestamp.
        """
        response = ResponseFormatter._base_payload("image", content, output_format)
        if metadata:
            response["metadata"] = metadata
        return response

    @staticmethod
    def format_audio_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Encode audio bytes as a base64 response payload.

        Args:
            content: Audio content.
            output_format: Output format label.
            metadata: Optional metadata merged into the response.

        Returns:
            Response dict with audio, format, size and timestamp.
        """
        response = ResponseFormatter._base_payload("audio", content, output_format)
        if metadata:
            response["metadata"] = metadata
        return response

    @staticmethod
    def format_3d_response(
        content: bytes,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Encode a GLB model as a base64 response payload.

        Args:
            content: 3D model content (GLB).
            metadata: Optional metadata merged into the response.

        Returns:
            Response dict with model, format ("glb"), size and timestamp.
        """
        response = ResponseFormatter._base_payload("model", content, "glb")
        if metadata:
            response["metadata"] = metadata
        return response
class ParameterValidator:
    """Validator for operation parameters."""

    @staticmethod
    def validate_seed(seed: Optional[int]) -> int:
        """Normalize the seed parameter; None maps to 0.

        Args:
            seed: Seed value or None.

        Returns:
            Valid seed in [0, 4294967294].

        Raises:
            HTTPException: If the seed is not an int in range.
        """
        if seed is None:
            return 0
        in_range = isinstance(seed, int) and 0 <= seed <= 4294967294
        if not in_range:
            raise HTTPException(
                status_code=400,
                detail="Seed must be an integer between 0 and 4294967294"
            )
        return seed

    @staticmethod
    def validate_strength(strength: Optional[float], operation: str) -> Optional[float]:
        """Validate the strength parameter for an operation.

        Args:
            strength: Strength value in [0, 1], or None to skip.
            operation: Operation type (audio_to_audio has a higher floor).

        Returns:
            Strength as a float, or None if not supplied.

        Raises:
            HTTPException: If the value is out of range.
        """
        if strength is None:
            return None
        if not isinstance(strength, (int, float)) or not (0 <= strength <= 1):
            raise HTTPException(
                status_code=400,
                detail="Strength must be a float between 0 and 1"
            )
        # audio-to-audio cannot meaningfully run below 0.01.
        if operation == "audio_to_audio" and strength < 0.01:
            raise HTTPException(
                status_code=400,
                detail="Minimum strength for audio-to-audio is 0.01"
            )
        return float(strength)

    @staticmethod
    def validate_creativity(creativity: Optional[float], operation: str) -> Optional[float]:
        """Validate the creativity parameter for an operation.

        Args:
            creativity: Creativity value, or None to skip.
            operation: Operation type; each has its own allowed range.

        Returns:
            Creativity as a float, or None if not supplied.

        Raises:
            HTTPException: If the value falls outside the operation's range.
        """
        if creativity is None:
            return None
        # Per-operation ranges; anything else allows the full [0, 1].
        ranges = {
            "upscale": (0.1, 0.5),
            "outpaint": (0, 1),
            "conservative_upscale": (0.2, 0.5)
        }
        min_val, max_val = ranges.get(operation, (0, 1))
        valid = isinstance(creativity, (int, float)) and min_val <= creativity <= max_val
        if not valid:
            raise HTTPException(
                status_code=400,
                detail=f"Creativity for {operation} must be between {min_val} and {max_val}"
            )
        return float(creativity)
class WorkflowManager:
    """Manager for complex workflows and pipelines."""

    @staticmethod
    def validate_workflow(workflow: List[Dict[str, Any]]) -> List[str]:
        """Validate workflow steps.

        Args:
            workflow: List of workflow steps, each a dict with an
                'operation' key.

        Returns:
            List of human-readable validation errors (empty if valid).
        """
        errors = []
        supported_operations = [
            "generate_ultra", "generate_core", "generate_sd3",
            "upscale_fast", "upscale_conservative", "upscale_creative",
            "inpaint", "outpaint", "erase", "search_and_replace",
            "control_sketch", "control_structure", "control_style"
        ]
        for i, step in enumerate(workflow):
            if "operation" not in step:
                errors.append(f"Step {i+1}: Missing 'operation' field")
                continue
            operation = step["operation"]
            if operation not in supported_operations:
                errors.append(f"Step {i+1}: Unsupported operation '{operation}'")
            # Generation must come first — it produces the image the later
            # steps consume. (Fixed: the original duplicated the `i > 0`
            # test in this condition.)
            if i > 0 and operation.startswith("generate_"):
                errors.append(f"Step {i+1}: Generate operations should be first in workflow")
        return errors

    @staticmethod
    def optimize_workflow(workflow: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Optimize workflow for better performance.

        Removes duplicate non-generate operations and reorders steps so
        generation runs first, then control, edits, and finally upscaling.

        Args:
            workflow: Original workflow (not modified).

        Returns:
            Optimized workflow.
        """
        # Drop repeats of the same operation, except generate_* steps which
        # may legitimately appear more than once.
        operations_seen = set()
        filtered_workflow = []
        for step in workflow:
            operation = step["operation"]
            if operation not in operations_seen or operation.startswith("generate_"):
                filtered_workflow.append(step)
                operations_seen.add(operation)

        # Coarse pipeline phases; unknown operations sort last (999).
        order_priority = {
            "generate": 0,
            "control": 1,
            "edit": 2,
            "upscale": 3
        }

        def get_priority(step: Dict[str, Any]) -> int:
            """Map a step to its phase priority for stable sorting."""
            operation = step["operation"]
            for key, priority in order_priority.items():
                if operation.startswith(key):
                    return priority
            return 999

        filtered_workflow.sort(key=get_priority)
        return filtered_workflow
# ==================== HELPER FUNCTIONS ====================
def _extract_dominant_colors(img: Image.Image, num_colors: int = 5) -> List[Tuple[int, int, int]]:
    """Find the *num_colors* most dominant colors via k-means clustering.

    Args:
        img: PIL image (assumes 3-channel RGB — callers convert first;
            an RGBA input would break the (-1, 3) reshape).
        num_colors: Number of cluster centers / colors to return.

    Returns:
        List of (r, g, b) tuples (cluster centers, rounded to ints).
    """
    from sklearn.cluster import KMeans

    # Downsample first: clustering cost scales with pixel count.
    thumbnail = img.resize((150, 150))
    pixels = np.array(thumbnail).reshape(-1, 3)

    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)
    centers = kmeans.cluster_centers_.astype(int)
    return [tuple(center) for center in centers]
def _assess_image_quality(img: Image.Image) -> Dict[str, Any]:
    """Compute heuristic quality metrics for an image.

    Args:
        img: PIL image (converted to grayscale internally).

    Returns:
        Dict with sharpness_score, noise_level, overall_score, and a
        needs_enhancement flag (overall score below 70).
    """
    gray_array = np.array(img.convert('L'))

    # Gradient variance as a crude sharpness proxy, capped at 100.
    sharpness_score = min(100, np.var(np.gradient(gray_array)) / 100)
    # NOTE(review): pixel std-dev measures contrast as much as noise —
    # kept as the original's proxy for "noise level".
    noise_level = np.std(gray_array)
    overall_score = (sharpness_score + max(0, 100 - noise_level)) / 2

    return {
        "sharpness_score": round(sharpness_score, 2),
        "noise_level": round(noise_level, 2),
        "overall_score": round(overall_score, 2),
        "needs_enhancement": overall_score < 70
    }
def _has_style_descriptors(prompt: str) -> bool:
"""Check if prompt contains style descriptors."""
style_keywords = [
"photorealistic", "realistic", "anime", "cartoon", "digital art",
"oil painting", "watercolor", "sketch", "illustration", "3d render",
"cinematic", "artistic", "professional"
]
return any(keyword in prompt.lower() for keyword in style_keywords)
def _has_quality_terms(prompt: str) -> bool:
"""Check if prompt contains quality terms."""
quality_keywords = [
"high quality", "detailed", "sharp", "crisp", "clear",
"professional", "masterpiece", "award winning"
]
return any(keyword in prompt.lower() for keyword in quality_keywords)
def _has_technical_terms(prompt: str) -> bool:
"""Check if prompt contains technical photography terms."""
technical_keywords = [
"bokeh", "depth of field", "macro", "wide angle", "telephoto",
"iso", "aperture", "shutter speed", "lighting", "composition"
]
return any(keyword in prompt.lower() for keyword in technical_keywords)
def _calculate_complexity_score(prompt: str) -> float:
    """Score prompt complexity on a 0-100 scale.

    Args:
        prompt: Text prompt.

    Returns:
        Score: 2 points per word (capped at 50) plus bonuses for style,
        quality, technical and compositional vocabulary, capped at 100.
    """
    score = min(len(prompt.split()) * 2, 50)

    # Bonuses for each descriptive vocabulary category present.
    if _has_style_descriptors(prompt):
        score += 15
    if _has_quality_terms(prompt):
        score += 10
    if _has_technical_terms(prompt):
        score += 15

    lowered = prompt.lower()
    if any(word in lowered for word in ("color", "lighting", "composition")):
        score += 10

    return min(score, 100)
def create_batch_manifest(
    operation: str,
    files: List[UploadFile],
    parameters: Dict[str, Any]
) -> Dict[str, Any]:
    """Build a manifest describing a batch-processing request.

    Args:
        operation: Operation type.
        files: Files queued for processing.
        parameters: Operation parameters shared by all files.

    Returns:
        Manifest dict with a timestamped batch id, file inventory, and
        rough duration (30 s/file) and cost estimates.
    """
    file_entries = [{"filename": f.filename, "size": f.size} for f in files]
    return {
        "batch_id": f"batch_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}",
        "operation": operation,
        "file_count": len(files),
        "files": file_entries,
        "parameters": parameters,
        "created_at": datetime.utcnow().isoformat(),
        "estimated_duration": len(files) * 30,  # 30 seconds per file estimate
        "estimated_cost": len(files) * _get_operation_cost(operation)
    }
def _get_operation_cost(operation: str) -> float:
    """Look up the estimated credit cost for an operation.

    Args:
        operation: Operation type.

    Returns:
        Estimated cost in credits; unknown operations default to 5.
    """
    from config.stability_config import MODEL_PRICING

    if operation.startswith("generate_"):
        # Default to core-model pricing for generation.
        return MODEL_PRICING["generate"].get("core", 3)
    if operation.startswith("upscale_"):
        upscale_type = operation.replace("upscale_", "")
        return MODEL_PRICING["upscale"].get(upscale_type, 5)
    if operation.startswith("control_"):
        # Sketch pricing stands in for all control operations.
        return MODEL_PRICING["control"].get("sketch", 5)
    # Fallback for anything unrecognized.
    return 5
def validate_file_size(file: UploadFile, max_size: int = 10 * 1024 * 1024) -> None:
"""Validate file size.
Args:
file: Uploaded file
max_size: Maximum allowed size in bytes
"""
if file.size and file.size > max_size:
raise HTTPException(
status_code=413,
detail=f"File size ({file.size} bytes) exceeds maximum allowed size ({max_size} bytes)"
)
async def convert_image_format(content: bytes, target_format: str) -> bytes:
    """Convert image bytes to *target_format*.

    Args:
        content: Source image bytes.
        target_format: Target format name, case-insensitive
            (jpeg/jpg, png, webp).

    Returns:
        Re-encoded image bytes.

    Raises:
        HTTPException: If decoding or re-encoding fails.
    """
    try:
        img = Image.open(io.BytesIO(content))
        fmt = target_format.lower()
        # Accept the common "jpg" alias — Pillow only registers "JPEG",
        # so the original raised on target_format="jpg".
        if fmt == "jpg":
            fmt = "jpeg"
        # JPEG has no alpha channel, so flatten transparent modes first.
        if fmt == "jpeg" and img.mode in ("RGBA", "LA"):
            img = img.convert("RGB")
        output = io.BytesIO()
        img.save(output, format=fmt.upper())
        return output.getvalue()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error converting image: {str(e)}")
def estimate_processing_time(
    operation: str,
    file_size: int,
    complexity: Optional[Dict[str, Any]] = None
) -> float:
    """Estimate how long an operation will take, in seconds.

    Args:
        operation: Operation type (unknown types assume a 10 s base).
        file_size: File size in bytes.
        complexity: Optional metrics; a complexity_score above 80
            adds a 50% surcharge.

    Returns:
        Estimated time in seconds, rounded to one decimal place.
    """
    # Base times by operation (in seconds).
    base_times = {
        "generate_ultra": 15,
        "generate_core": 5,
        "generate_sd3": 10,
        "upscale_fast": 2,
        "upscale_conservative": 30,
        "upscale_creative": 60,
        "inpaint": 10,
        "outpaint": 15,
        "control_sketch": 8,
        "control_structure": 8,
        "control_style": 10,
        "3d_fast": 10,
        "3d_point_aware": 20,
        "audio_text": 30,
        "audio_transform": 45
    }

    # Scale by size in MB, never dropping below the base time itself.
    size_mb_factor = max(1, file_size / (1024 * 1024))
    estimate = base_times.get(operation, 10) * size_mb_factor

    # Very complex prompts cost roughly 50% extra.
    if complexity and complexity.get("complexity_score", 0) > 80:
        estimate *= 1.5

    return round(estimate, 1)