"""Utility functions for Stability AI operations."""
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import mimetypes
|
|
import os
|
|
from typing import Dict, Any, Optional, List, Union, Tuple
|
|
from PIL import Image, ImageStat
|
|
import numpy as np
|
|
from fastapi import UploadFile, HTTPException
|
|
import aiofiles
|
|
import asyncio
|
|
from datetime import datetime
|
|
import hashlib
|
|
|
|
|
|
class ImageValidator:
    """Validator for image files and parameters."""

    @staticmethod
    def validate_image_file(file: UploadFile) -> Dict[str, Any]:
        """Validate an uploaded image file.

        Args:
            file: Uploaded file

        Returns:
            Validation result with file info

        Raises:
            HTTPException: 400 when the content type or extension is not an image.
        """
        content_type = file.content_type
        if not (content_type and content_type.startswith('image/')):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Reject filenames whose extension is not a known image type.
        allowed_extensions = ['.jpg', '.jpeg', '.png', '.webp']
        if file.filename:
            suffix = '.' + file.filename.rsplit('.', 1)[-1].lower()
            if suffix not in allowed_extensions:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file format. Allowed: {allowed_extensions}"
                )

        return {
            "filename": file.filename,
            "content_type": file.content_type,
            "is_valid": True,
        }

    @staticmethod
    async def analyze_image_content(content: bytes) -> Dict[str, Any]:
        """Analyze image content and characteristics.

        Args:
            content: Image bytes

        Returns:
            Image analysis results (geometry, color stats, quality assessment)

        Raises:
            HTTPException: 400 when the bytes cannot be decoded or analyzed.
        """
        try:
            image = Image.open(io.BytesIO(content))
            width, height = image.size

            # Geometry and container info.
            result: Dict[str, Any] = {
                "format": image.format,
                "mode": image.mode,
                "size": image.size,
                "width": width,
                "height": height,
                "total_pixels": width * height,
                "aspect_ratio": round(width / height, 3),
                "file_size": len(content),
                "has_alpha": image.mode in ("RGBA", "LA") or "transparency" in image.info,
            }

            # Color statistics are only meaningful for RGB-like images.
            if image.mode in ("RGB", "RGBA"):
                rgb = image.convert("RGB")
                stats = ImageStat.Stat(rgb)
                result["brightness"] = round(sum(stats.mean) / 3, 2)
                result["color_variance"] = round(sum(stats.stddev) / 3, 2)
                result["dominant_colors"] = _extract_dominant_colors(rgb)

            result["quality_assessment"] = _assess_image_quality(image)
            return result

        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error analyzing image: {str(e)}")

    @staticmethod
    def validate_dimensions(width: int, height: int, operation: str) -> None:
        """Validate image dimensions against the limits for an operation.

        Args:
            width: Image width
            height: Image height
            operation: Operation type (key into IMAGE_LIMITS; falls back to "generate")

        Raises:
            HTTPException: 400 when any configured limit is violated.
        """
        from config.stability_config import IMAGE_LIMITS

        limits = IMAGE_LIMITS.get(operation, IMAGE_LIMITS["generate"])
        pixel_count = width * height

        if "min_pixels" in limits and pixel_count < limits["min_pixels"]:
            raise HTTPException(
                status_code=400,
                detail=f"Image must have at least {limits['min_pixels']} pixels for {operation}"
            )

        if "max_pixels" in limits and pixel_count > limits["max_pixels"]:
            raise HTTPException(
                status_code=400,
                detail=f"Image must have at most {limits['max_pixels']} pixels for {operation}"
            )

        if "min_dimension" in limits:
            min_dim = limits["min_dimension"]
            # Both width and height must individually satisfy the minimum.
            if min(width, height) < min_dim:
                raise HTTPException(
                    status_code=400,
                    detail=f"Both dimensions must be at least {min_dim} pixels for {operation}"
                )
|
|
|
|
|
|
class AudioValidator:
    """Validator for audio files and parameters."""

    @staticmethod
    def validate_audio_file(file: UploadFile) -> Dict[str, Any]:
        """Validate an uploaded audio file.

        Args:
            file: Uploaded file

        Returns:
            Validation result with file info

        Raises:
            HTTPException: 400 when the content type or extension is not supported audio.
        """
        if not file.content_type or not file.content_type.startswith('audio/'):
            raise HTTPException(status_code=400, detail="File must be an audio file")

        # Only formats the downstream API accepts.
        allowed_extensions = ['.mp3', '.wav']
        name = file.filename
        if name:
            suffix = '.' + name.rsplit('.', 1)[-1].lower()
            if suffix not in allowed_extensions:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported audio format. Allowed: {allowed_extensions}"
                )

        return {
            "filename": file.filename,
            "content_type": file.content_type,
            "is_valid": True,
        }

    @staticmethod
    async def analyze_audio_content(content: bytes) -> Dict[str, Any]:
        """Analyze audio content and characteristics.

        Args:
            content: Audio bytes

        Returns:
            Audio analysis results (currently only the byte size)

        Raises:
            HTTPException: 400 when analysis fails.
        """
        try:
            # Only the size is available without a dedicated audio library;
            # libraries such as librosa or pydub could additionally report
            # duration, sample rate, channel count, etc.
            return {
                "file_size": len(content),
                "format": "unknown"  # Would need audio library to detect
            }

        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error analyzing audio: {str(e)}")
|
|
|
|
|
|
class PromptOptimizer:
    """Optimizer for text prompts."""

    @staticmethod
    def analyze_prompt(prompt: str) -> Dict[str, Any]:
        """Analyze prompt structure and content.

        Args:
            prompt: Text prompt

        Returns:
            Prompt analysis (length, counts, descriptor flags, complexity)
        """
        tokens = prompt.split()
        sentences = [chunk for chunk in prompt.split('.') if chunk.strip()]

        return {
            "length": len(prompt),
            "word_count": len(tokens),
            "sentence_count": len(sentences),
            "has_style_descriptors": _has_style_descriptors(prompt),
            "has_quality_terms": _has_quality_terms(prompt),
            "has_technical_terms": _has_technical_terms(prompt),
            "complexity_score": _calculate_complexity_score(prompt),
        }

    @staticmethod
    def optimize_prompt(
        prompt: str,
        target_model: str = "ultra",
        target_style: Optional[str] = None,
        quality_level: str = "high"
    ) -> Dict[str, Any]:
        """Optimize prompt for better results.

        Args:
            prompt: Original prompt
            target_model: Target model ("ultra" or "core")
            target_style: Target style to append when the prompt lacks one
            quality_level: Desired quality level ("high" adds quality terms)

        Returns:
            Optimization results with the rewritten prompt and applied changes
        """
        applied: List[str] = []
        result = prompt.strip()

        # Append a style hint when requested and not already present.
        if target_style and not _has_style_descriptors(prompt):
            result += f", {target_style} style"
            applied.append(f"Added style: {target_style}")

        # Append quality enhancers for high-quality requests lacking them.
        if quality_level == "high" and not _has_quality_terms(prompt):
            result += ", high quality, detailed, sharp"
            applied.append("Added quality enhancers")

        word_total = len(prompt.split())
        if target_model == "ultra":
            # Ultra benefits from richer prompts; pad very short ones.
            if word_total < 10:
                result += ", professional photography, detailed composition"
                applied.append("Added detail for Ultra model")
        elif target_model == "core":
            # Core works best with concise prompts; advise only, don't edit.
            if word_total > 30:
                applied.append("Consider shortening prompt for Core model")

        return {
            "original_prompt": prompt,
            "optimized_prompt": result,
            "optimizations_applied": applied,
            "improvement_estimate": len(applied) * 15  # Rough percentage
        }

    @staticmethod
    def generate_negative_prompt(
        prompt: str,
        style: Optional[str] = None
    ) -> str:
        """Generate appropriate negative prompt.

        Args:
            prompt: Original prompt
            style: Target style

        Returns:
            Suggested negative prompt
        """
        parts = ["blurry, low quality, distorted, deformed, pixelated"]

        # Style-specific negatives steer the model away from the wrong medium.
        if style:
            lowered = style.lower()
            if "photographic" in lowered:
                parts.append("cartoon, anime, illustration")
            elif "anime" in lowered:
                parts.append("realistic, photographic")
            elif "art" in lowered:
                parts.append("photograph, realistic")

        # Human subjects commonly need anatomy-related negatives.
        if "person" in prompt.lower() or "human" in prompt.lower():
            parts.append("extra limbs, malformed hands, duplicate")

        return ", ".join(parts)
|
|
|
|
|
|
class FileManager:
    """Manager for file operations and caching."""

    @staticmethod
    async def save_result(
        content: bytes,
        filename: str,
        operation: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Save generation result to file.

        Args:
            content: File content
            filename: Original filename, appended to the generated unique name
                so the extension and a human-readable hint are preserved
            operation: Operation type (used as a subdirectory)
            metadata: Optional metadata, written alongside as ``<path>.json``

        Returns:
            Path of the saved file
        """
        # Directory layout: generated_content/<operation>/<YYYY/MM/DD>/
        base_dir = "generated_content"
        date_dir = os.path.join(base_dir, operation, datetime.now().strftime("%Y/%m/%d"))
        os.makedirs(date_dir, exist_ok=True)

        # Unique name: time + short content hash + original filename.
        # BUG FIX: the `filename` parameter was previously ignored and a
        # literal "(unknown)" placeholder was embedded in every saved name.
        timestamp = datetime.now().strftime("%H%M%S")
        file_hash = hashlib.md5(content).hexdigest()[:8]
        # basename() strips any client-supplied directory components.
        safe_name = os.path.basename(filename) if filename else "result"
        unique_filename = f"{timestamp}_{file_hash}_{safe_name}"

        file_path = os.path.join(date_dir, unique_filename)

        # Write the payload asynchronously.
        async with aiofiles.open(file_path, 'wb') as f:
            await f.write(content)

        # Persist metadata next to the artifact when provided.
        if metadata:
            metadata_path = file_path + ".json"
            async with aiofiles.open(metadata_path, 'w') as f:
                await f.write(json.dumps(metadata, indent=2))

        return file_path

    @staticmethod
    def generate_cache_key(operation: str, parameters: Dict[str, Any]) -> str:
        """Generate a deterministic cache key for an operation + parameters.

        Args:
            operation: Operation type
            parameters: Operation parameters (must be JSON-serializable)

        Returns:
            Hex SHA-256 digest usable as a cache key
        """
        # sort_keys makes the serialization (and thus the key) deterministic
        # regardless of dict insertion order.
        key_data = f"{operation}:{json.dumps(parameters, sort_keys=True)}"
        return hashlib.sha256(key_data.encode()).hexdigest()
|
|
|
|
|
|
class ResponseFormatter:
    """Formatter for API responses."""

    @staticmethod
    def _build_payload(
        key: str,
        content: bytes,
        fmt: str,
        metadata: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        # Shared shape for every media response: base64 payload under `key`,
        # plus format, byte size, UTC timestamp and optional metadata.
        payload: Dict[str, Any] = {
            key: base64.b64encode(content).decode(),
            "format": fmt,
            "size": len(content),
            "timestamp": datetime.utcnow().isoformat()
        }
        if metadata:
            payload["metadata"] = metadata
        return payload

    @staticmethod
    def format_image_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Format image response with metadata.

        Args:
            content: Image content
            output_format: Output format
            metadata: Optional metadata

        Returns:
            Formatted response with base64 image, format, size and timestamp
        """
        return ResponseFormatter._build_payload("image", content, output_format, metadata)

    @staticmethod
    def format_audio_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Format audio response with metadata.

        Args:
            content: Audio content
            output_format: Output format
            metadata: Optional metadata

        Returns:
            Formatted response with base64 audio, format, size and timestamp
        """
        return ResponseFormatter._build_payload("audio", content, output_format, metadata)

    @staticmethod
    def format_3d_response(
        content: bytes,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Format 3D model response with metadata.

        Args:
            content: 3D model content (GLB)
            metadata: Optional metadata

        Returns:
            Formatted response with base64 model, "glb" format, size and timestamp
        """
        return ResponseFormatter._build_payload("model", content, "glb", metadata)
|
|
|
|
|
|
class ParameterValidator:
    """Validator for operation parameters."""

    # Inclusive upper bound accepted by the Stability API for seeds.
    _MAX_SEED = 4294967294

    @staticmethod
    def validate_seed(seed: Optional[int]) -> int:
        """Validate and normalize the seed parameter.

        Args:
            seed: Seed value (None means "random", normalized to 0)

        Returns:
            Valid seed value

        Raises:
            HTTPException: 400 when the seed is not an int in range.
        """
        if seed is None:
            return 0

        is_valid = isinstance(seed, int) and 0 <= seed <= ParameterValidator._MAX_SEED
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail="Seed must be an integer between 0 and 4294967294"
            )
        return seed

    @staticmethod
    def validate_strength(strength: Optional[float], operation: str) -> Optional[float]:
        """Validate the strength parameter for different operations.

        Args:
            strength: Strength value (None is passed through)
            operation: Operation type

        Returns:
            Valid strength as a float, or None

        Raises:
            HTTPException: 400 when out of range for the operation.
        """
        if strength is None:
            return None

        if not isinstance(strength, (int, float)) or not 0 <= strength <= 1:
            raise HTTPException(
                status_code=400,
                detail="Strength must be a float between 0 and 1"
            )

        # Audio-to-audio rejects near-zero strengths that would be a no-op.
        if operation == "audio_to_audio" and strength < 0.01:
            raise HTTPException(
                status_code=400,
                detail="Minimum strength for audio-to-audio is 0.01"
            )

        return float(strength)

    @staticmethod
    def validate_creativity(creativity: Optional[float], operation: str) -> Optional[float]:
        """Validate the creativity parameter.

        Args:
            creativity: Creativity value (None is passed through)
            operation: Operation type

        Returns:
            Valid creativity as a float, or None

        Raises:
            HTTPException: 400 when out of the operation's range.
        """
        if creativity is None:
            return None

        # Per-operation valid ranges; unknown operations accept [0, 1].
        ranges = {
            "upscale": (0.1, 0.5),
            "outpaint": (0, 1),
            "conservative_upscale": (0.2, 0.5)
        }
        low, high = ranges.get(operation, (0, 1))

        if not isinstance(creativity, (int, float)) or not low <= creativity <= high:
            raise HTTPException(
                status_code=400,
                detail=f"Creativity for {operation} must be between {low} and {high}"
            )

        return float(creativity)
|
|
|
|
|
|
class WorkflowManager:
    """Manager for complex workflows and pipelines."""

    @staticmethod
    def validate_workflow(workflow: List[Dict[str, Any]]) -> List[str]:
        """Validate workflow steps.

        Args:
            workflow: List of workflow steps

        Returns:
            List of human-readable validation errors (empty when valid)
        """
        errors = []
        supported_operations = [
            "generate_ultra", "generate_core", "generate_sd3",
            "upscale_fast", "upscale_conservative", "upscale_creative",
            "inpaint", "outpaint", "erase", "search_and_replace",
            "control_sketch", "control_structure", "control_style"
        ]

        for i, step in enumerate(workflow):
            if "operation" not in step:
                errors.append(f"Step {i+1}: Missing 'operation' field")
                continue

            operation = step["operation"]
            if operation not in supported_operations:
                errors.append(f"Step {i+1}: Unsupported operation '{operation}'")

            # Generate operations create the base image, so they must be first.
            # (Fixed: the original condition tested `i > 0` twice.)
            if i > 0 and operation.startswith("generate_"):
                errors.append(f"Step {i+1}: Generate operations should be first in workflow")

        return errors

    @staticmethod
    def optimize_workflow(workflow: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Optimize workflow for better performance.

        Deduplicates repeated non-generate operations and reorders steps so
        generation runs first, then control, edits, and upscaling last.

        Args:
            workflow: Original workflow (not modified)

        Returns:
            Optimized workflow
        """
        # Remove redundant operations; generate_* steps are never deduplicated
        # because repeated generations are intentional.
        operations_seen = set()
        filtered_workflow: List[Dict[str, Any]] = []

        for step in workflow:
            # Fixed: use .get() so malformed steps (already reported by
            # validate_workflow) no longer raise KeyError here.
            operation = step.get("operation", "")
            if operation.startswith("generate_") or operation not in operations_seen:
                filtered_workflow.append(step)
                operations_seen.add(operation)

        # Stable sort by operation-prefix priority: generation, then control,
        # edits, upscaling; unknown operations keep relative order at the end.
        order_priority = {
            "generate": 0,
            "control": 1,
            "edit": 2,
            "upscale": 3
        }

        def get_priority(step: Dict[str, Any]) -> int:
            operation = step.get("operation", "")
            for prefix, priority in order_priority.items():
                if operation.startswith(prefix):
                    return priority
            return 999

        filtered_workflow.sort(key=get_priority)

        return filtered_workflow
|
|
|
|
|
|
# ==================== HELPER FUNCTIONS ====================
|
|
|
|
def _extract_dominant_colors(img: Image.Image, num_colors: int = 5) -> List[Tuple[int, int, int]]:
    """Extract dominant colors from an image via k-means clustering.

    Args:
        img: PIL Image (any mode; converted to RGB internally)
        num_colors: Number of dominant colors to extract

    Returns:
        List of (r, g, b) tuples, one per cluster center

    """
    # Downscale for speed; dominant colors are robust to resizing.
    # BUG FIX: convert to RGB first — the reshape(-1, 3) below fails for
    # RGBA/LA/P/L inputs, which do not have exactly 3 channels.
    img_small = img.convert("RGB").resize((150, 150))

    # Flatten to an (N, 3) array of RGB pixels.
    pixels = np.asarray(img_small).reshape(-1, 3)

    # Lazy import: sklearn is only needed for this analysis path.
    from sklearn.cluster import KMeans

    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Cluster centers are float RGB; cast to plain ints so the result is
    # JSON-serializable without numpy-specific types.
    return [tuple(int(c) for c in center) for center in kmeans.cluster_centers_]
|
|
|
|
|
|
def _assess_image_quality(img: Image.Image) -> Dict[str, Any]:
    """Assess image quality metrics.

    Args:
        img: PIL Image

    Returns:
        Quality assessment with sharpness, noise and overall scores

    """
    # Quality heuristics operate on the luminance channel only.
    gray_array = np.array(img.convert('L'))

    # Sharpness proxy: variance of the intensity gradient, capped at 100.
    # NOTE(review): originally labelled "Laplacian variance", but np.gradient
    # is a first-order difference, not a true Laplacian — confirm intent.
    sharpness_score = min(100, np.var(np.gradient(gray_array)) / 100)

    # Noise proxy: global standard deviation of intensities.
    noise_level = np.std(gray_array)

    # Blend both metrics into a single 0-100 quality figure.
    overall_score = (sharpness_score + max(0, 100 - noise_level)) / 2

    return {
        "sharpness_score": round(sharpness_score, 2),
        "noise_level": round(noise_level, 2),
        "overall_score": round(overall_score, 2),
        "needs_enhancement": overall_score < 70
    }
|
|
|
|
|
|
def _has_style_descriptors(prompt: str) -> bool:
|
|
"""Check if prompt contains style descriptors."""
|
|
style_keywords = [
|
|
"photorealistic", "realistic", "anime", "cartoon", "digital art",
|
|
"oil painting", "watercolor", "sketch", "illustration", "3d render",
|
|
"cinematic", "artistic", "professional"
|
|
]
|
|
return any(keyword in prompt.lower() for keyword in style_keywords)
|
|
|
|
|
|
def _has_quality_terms(prompt: str) -> bool:
|
|
"""Check if prompt contains quality terms."""
|
|
quality_keywords = [
|
|
"high quality", "detailed", "sharp", "crisp", "clear",
|
|
"professional", "masterpiece", "award winning"
|
|
]
|
|
return any(keyword in prompt.lower() for keyword in quality_keywords)
|
|
|
|
|
|
def _has_technical_terms(prompt: str) -> bool:
|
|
"""Check if prompt contains technical photography terms."""
|
|
technical_keywords = [
|
|
"bokeh", "depth of field", "macro", "wide angle", "telephoto",
|
|
"iso", "aperture", "shutter speed", "lighting", "composition"
|
|
]
|
|
return any(keyword in prompt.lower() for keyword in technical_keywords)
|
|
|
|
|
|
def _calculate_complexity_score(prompt: str) -> float:
    """Calculate prompt complexity score.

    Args:
        prompt: Text prompt

    Returns:
        Complexity score (0-100)

    """
    # Start from word count: 2 points per word, capped at 50.
    score = min(len(prompt.split()) * 2, 50)

    # Bonus points for descriptive richness.
    if _has_style_descriptors(prompt):
        score += 15
    if _has_quality_terms(prompt):
        score += 10
    if _has_technical_terms(prompt):
        score += 15

    # Bonus for explicit visual-detail keywords.
    lowered = prompt.lower()
    if any(term in lowered for term in ("color", "lighting", "composition")):
        score += 10

    # Clamp the final score to the documented 0-100 range.
    return min(score, 100)
|
|
|
|
|
|
def create_batch_manifest(
    operation: str,
    files: List[UploadFile],
    parameters: Dict[str, Any]
) -> Dict[str, Any]:
    """Create manifest for batch processing.

    Args:
        operation: Operation type
        files: List of files to process
        parameters: Operation parameters

    Returns:
        Batch manifest with id, file listing and rough time/cost estimates

    """
    now = datetime.utcnow()
    file_entries = [{"filename": f.filename, "size": f.size} for f in files]

    return {
        "batch_id": f"batch_{now.strftime('%Y%m%d_%H%M%S')}",
        "operation": operation,
        "file_count": len(files),
        "files": file_entries,
        "parameters": parameters,
        "created_at": now.isoformat(),
        "estimated_duration": len(files) * 30,  # 30 seconds per file estimate
        "estimated_cost": len(files) * _get_operation_cost(operation)
    }
|
|
|
|
|
|
def _get_operation_cost(operation: str) -> float:
    """Get estimated cost for operation.

    Args:
        operation: Operation type

    Returns:
        Estimated cost in credits

    """
    from config.stability_config import MODEL_PRICING

    # Map the operation-name prefix onto a pricing table category;
    # unknown operations fall through to a flat default.
    if operation.startswith("generate_"):
        return MODEL_PRICING["generate"].get("core", 3)  # Default to core
    if operation.startswith("upscale_"):
        tier = operation.replace("upscale_", "")
        return MODEL_PRICING["upscale"].get(tier, 5)
    if operation.startswith("control_"):
        return MODEL_PRICING["control"].get("sketch", 5)  # Default
    return 5  # Default cost
|
|
|
|
|
|
def validate_file_size(file: UploadFile, max_size: int = 10 * 1024 * 1024) -> None:
    """Validate that an uploaded file does not exceed a size limit.

    Args:
        file: Uploaded file
        max_size: Maximum allowed size in bytes (default 10 MiB)

    Raises:
        HTTPException: 413 when the reported size exceeds max_size.
    """
    size = file.size
    # A missing or zero size is accepted; only a known oversize is rejected.
    if size and size > max_size:
        raise HTTPException(
            status_code=413,
            detail=f"File size ({size} bytes) exceeds maximum allowed size ({max_size} bytes)"
        )
|
|
|
|
|
|
async def convert_image_format(content: bytes, target_format: str) -> bytes:
    """Convert image to target format.

    Args:
        content: Image content
        target_format: Target format (jpeg/jpg, png, webp)

    Returns:
        Converted image bytes

    Raises:
        HTTPException: 400 when decoding or conversion fails.
    """
    try:
        img = Image.open(io.BytesIO(content))

        # Normalize the format name: PIL expects "JPEG", not "JPG".
        # BUG FIX: "jpg" previously upper-cased to the invalid format "JPG".
        fmt = target_format.upper()
        if fmt == "JPG":
            fmt = "JPEG"

        # JPEG has no alpha channel: flatten RGBA/LA and palette images.
        # BUG FIX: mode "P" was not handled and cannot be written as JPEG.
        if fmt == "JPEG" and img.mode in ("RGBA", "LA", "P"):
            img = img.convert("RGB")

        output = io.BytesIO()
        img.save(output, format=fmt)
        return output.getvalue()

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error converting image: {str(e)}")
|
|
|
|
|
|
def estimate_processing_time(
    operation: str,
    file_size: int,
    complexity: Optional[Dict[str, Any]] = None
) -> float:
    """Estimate processing time for operation.

    Args:
        operation: Operation type
        file_size: File size in bytes
        complexity: Optional complexity metrics

    Returns:
        Estimated time in seconds, rounded to one decimal place

    """
    # Baseline seconds per operation; unknown operations assume 10s.
    baseline_seconds = {
        "generate_ultra": 15,
        "generate_core": 5,
        "generate_sd3": 10,
        "upscale_fast": 2,
        "upscale_conservative": 30,
        "upscale_creative": 60,
        "inpaint": 10,
        "outpaint": 15,
        "control_sketch": 8,
        "control_structure": 8,
        "control_style": 10,
        "3d_fast": 10,
        "3d_point_aware": 20,
        "audio_text": 30,
        "audio_transform": 45
    }

    estimate = baseline_seconds.get(operation, 10)

    # Scale linearly with file size in MB, never dropping below baseline.
    megabytes = file_size / (1024 * 1024)
    estimate = estimate * max(1, megabytes)

    # Highly complex prompts (score > 80) take roughly 50% longer.
    if complexity and complexity.get("complexity_score", 0) > 80:
        estimate *= 1.5

    return round(estimate, 1)