Base code
This commit is contained in:
858
backend/utils/stability_utils.py
Normal file
858
backend/utils/stability_utils.py
Normal file
@@ -0,0 +1,858 @@
|
||||
"""Utility functions for Stability AI operations."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||
from PIL import Image, ImageStat
|
||||
import numpy as np
|
||||
from fastapi import UploadFile, HTTPException
|
||||
import aiofiles
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
|
||||
|
||||
class ImageValidator:
    """Validator for image files and parameters."""

    @staticmethod
    def validate_image_file(file: UploadFile) -> Dict[str, Any]:
        """Validate an uploaded image by its declared content type and extension.

        Only the upload's metadata is inspected; the file bytes are not read.

        Args:
            file: Uploaded file

        Returns:
            Dict with ``filename``, ``content_type`` and ``is_valid`` (always True
            when this function returns).

        Raises:
            HTTPException: 400 when the content type is missing or does not start
                with ``image/``, or when a filename is present and its extension
                is not one of the allowed ones.
        """
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Check file extension (skipped entirely when no filename was sent)
        allowed_extensions = ['.jpg', '.jpeg', '.png', '.webp']
        if file.filename:
            # rpartition returns an empty separator when the name has no dot, so a
            # bare name such as "png" is not mistaken for a ".png" extension.
            _, dot, suffix = file.filename.rpartition('.')
            ext = (dot + suffix).lower()
            if ext not in allowed_extensions:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file format. Allowed: {allowed_extensions}"
                )

        return {
            "filename": file.filename,
            "content_type": file.content_type,
            "is_valid": True
        }

    @staticmethod
    async def analyze_image_content(content: bytes) -> Dict[str, Any]:
        """Analyze image content and characteristics.

        Args:
            content: Image bytes

        Returns:
            Dict with format, mode, size, width, height, total_pixels,
            aspect_ratio, file_size, has_alpha and quality_assessment. For images
            in RGB or RGBA mode it additionally contains brightness,
            color_variance and dominant_colors.

        Raises:
            HTTPException: 400 when any step of decoding or analysis fails; the
                original exception is attached as the cause.
        """
        try:
            img = Image.open(io.BytesIO(content))

            # Basic info
            info = {
                "format": img.format,
                "mode": img.mode,
                "size": img.size,
                "width": img.width,
                "height": img.height,
                "total_pixels": img.width * img.height,
                "aspect_ratio": round(img.width / img.height, 3),
                "file_size": len(content),
                "has_alpha": img.mode in ("RGBA", "LA") or "transparency" in img.info
            }

            # Color analysis is only performed for RGB / RGBA images
            if img.mode in ("RGB", "RGBA"):
                img_rgb = img.convert("RGB")
                stat = ImageStat.Stat(img_rgb)

                info.update({
                    # Per-channel mean / stddev averaged over the three channels
                    "brightness": round(sum(stat.mean) / 3, 2),
                    "color_variance": round(sum(stat.stddev) / 3, 2),
                    "dominant_colors": _extract_dominant_colors(img_rgb)
                })

            # Quality assessment
            info["quality_assessment"] = _assess_image_quality(img)

            return info

        except Exception as e:
            # Chain the cause so the real failure stays visible in tracebacks
            raise HTTPException(status_code=400, detail=f"Error analyzing image: {str(e)}") from e

    @staticmethod
    def validate_dimensions(width: int, height: int, operation: str) -> None:
        """Validate image dimensions for specific operation.

        Limits are read from ``IMAGE_LIMITS`` in ``config.stability_config``;
        when ``operation`` has no entry, the ``"generate"`` entry is used. Each
        of ``min_pixels``, ``max_pixels`` and ``min_dimension`` is enforced only
        if present in the selected limits.

        Args:
            width: Image width
            height: Image height
            operation: Operation type

        Raises:
            HTTPException: 400 when a configured limit is violated.
        """
        from config.stability_config import IMAGE_LIMITS

        limits = IMAGE_LIMITS.get(operation, IMAGE_LIMITS["generate"])
        total_pixels = width * height

        if "min_pixels" in limits and total_pixels < limits["min_pixels"]:
            raise HTTPException(
                status_code=400,
                detail=f"Image must have at least {limits['min_pixels']} pixels for {operation}"
            )

        if "max_pixels" in limits and total_pixels > limits["max_pixels"]:
            raise HTTPException(
                status_code=400,
                detail=f"Image must have at most {limits['max_pixels']} pixels for {operation}"
            )

        if "min_dimension" in limits:
            min_dim = limits["min_dimension"]
            if width < min_dim or height < min_dim:
                raise HTTPException(
                    status_code=400,
                    detail=f"Both dimensions must be at least {min_dim} pixels for {operation}"
                )
|
||||
|
||||
|
||||
class AudioValidator:
    """Validator for audio files and parameters."""

    @staticmethod
    def validate_audio_file(file: UploadFile) -> Dict[str, Any]:
        """Validate an uploaded audio file by declared content type and extension.

        Only the upload's metadata is inspected; the file bytes are not read.

        Args:
            file: Uploaded file

        Returns:
            Dict with ``filename``, ``content_type`` and ``is_valid`` (always True
            when this function returns).

        Raises:
            HTTPException: 400 when the content type is missing or does not start
                with ``audio/``, or when a filename is present and its extension
                is not one of the allowed ones.
        """
        if not file.content_type or not file.content_type.startswith('audio/'):
            raise HTTPException(status_code=400, detail="File must be an audio file")

        # Check file extension (skipped entirely when no filename was sent)
        allowed_extensions = ['.mp3', '.wav']
        if file.filename:
            # rpartition returns an empty separator when the name has no dot, so a
            # bare name such as "mp3" is not mistaken for a ".mp3" extension.
            _, dot, suffix = file.filename.rpartition('.')
            ext = (dot + suffix).lower()
            if ext not in allowed_extensions:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported audio format. Allowed: {allowed_extensions}"
                )

        return {
            "filename": file.filename,
            "content_type": file.content_type,
            "is_valid": True
        }

    @staticmethod
    async def analyze_audio_content(content: bytes) -> Dict[str, Any]:
        """Analyze audio content and characteristics.

        Currently a placeholder: only the byte length is reported and the format
        is always ``"unknown"``.

        Args:
            content: Audio bytes

        Returns:
            Dict with ``file_size`` and ``format``.

        Raises:
            HTTPException: 400 when the content cannot be measured; the original
                exception is attached as the cause.
        """
        try:
            # Basic info
            info = {
                "file_size": len(content),
                "format": "unknown"  # Would need audio library to detect
            }

            # For actual implementation, you'd use libraries like librosa or pydub
            # to analyze audio characteristics like duration, sample rate, etc.

            return info

        except Exception as e:
            # Chain the cause so the real failure stays visible in tracebacks
            raise HTTPException(status_code=400, detail=f"Error analyzing audio: {str(e)}") from e
|
||||
|
||||
|
||||
class PromptOptimizer:
    """Optimizer for text prompts."""

    @staticmethod
    def analyze_prompt(prompt: str) -> Dict[str, Any]:
        """Summarize the structure and vocabulary of a prompt.

        Args:
            prompt: Text prompt

        Returns:
            Dict with character length, whitespace-separated word count, count of
            non-blank '.'-separated segments, three keyword flags and a
            complexity score.
        """
        non_blank_sentences = [part for part in prompt.split('.') if part.strip()]

        return {
            "length": len(prompt),
            "word_count": len(prompt.split()),
            "sentence_count": len(non_blank_sentences),
            "has_style_descriptors": _has_style_descriptors(prompt),
            "has_quality_terms": _has_quality_terms(prompt),
            "has_technical_terms": _has_technical_terms(prompt),
            "complexity_score": _calculate_complexity_score(prompt)
        }

    @staticmethod
    def optimize_prompt(
        prompt: str,
        target_model: str = "ultra",
        target_style: Optional[str] = None,
        quality_level: str = "high"
    ) -> Dict[str, Any]:
        """Append style, quality and model-specific hints to a prompt.

        All checks are made against the original ``prompt``, not against the
        text that has already been extended.

        Args:
            prompt: Original prompt
            target_model: Target model
            target_style: Target style
            quality_level: Desired quality level

        Returns:
            Dict with the original prompt, the optimized prompt, the list of
            notes describing what was done, and a rough improvement estimate
            (15 per note).
        """
        word_total = len(prompt.split())
        segments = [prompt.strip()]
        notes = []

        # Style is only appended when requested and not already described
        if target_style and not _has_style_descriptors(prompt):
            segments.append(f", {target_style} style")
            notes.append(f"Added style: {target_style}")

        # Quality enhancers for the "high" level when none are present yet
        if quality_level == "high" and not _has_quality_terms(prompt):
            segments.append(", high quality, detailed, sharp")
            notes.append("Added quality enhancers")

        # Model-specific handling: Ultra gets extra detail on short prompts,
        # Core only receives an advisory note on long prompts.
        if target_model == "ultra" and word_total < 10:
            segments.append(", professional photography, detailed composition")
            notes.append("Added detail for Ultra model")
        elif target_model == "core" and word_total > 30:
            notes.append("Consider shortening prompt for Core model")

        return {
            "original_prompt": prompt,
            "optimized_prompt": "".join(segments),
            "optimizations_applied": notes,
            "improvement_estimate": len(notes) * 15  # Rough percentage
        }

    @staticmethod
    def generate_negative_prompt(
        prompt: str,
        style: Optional[str] = None
    ) -> str:
        """Build a negative prompt from a fixed base plus style/content extras.

        Args:
            prompt: Original prompt
            style: Target style

        Returns:
            Suggested negative prompt
        """
        fragments = ["blurry, low quality, distorted, deformed, pixelated"]

        # First matching style marker wins (substring match on lower-cased style)
        if style:
            style_lc = style.lower()
            style_extras = (
                ("photographic", ", cartoon, anime, illustration"),
                ("anime", ", realistic, photographic"),
                ("art", ", photograph, realistic"),
            )
            for marker, extra in style_extras:
                if marker in style_lc:
                    fragments.append(extra)
                    break

        # Prompts mentioning people get anatomy-related negatives
        prompt_lc = prompt.lower()
        if "person" in prompt_lc or "human" in prompt_lc:
            fragments.append(", extra limbs, malformed hands, duplicate")

        return "".join(fragments)
|
||||
|
||||
|
||||
class FileManager:
    """Manager for file operations and caching."""

    @staticmethod
    async def save_result(
        content: bytes,
        filename: str,
        operation: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Save generation result to file.

        The file is written to
        ``generated_content/<operation>/<YYYY>/<MM>/<DD>/<HHMMSS>_<hash8>_<name>``,
        where ``hash8`` is the first 8 hex digits of the content's MD5 digest
        and ``name`` is ``filename`` with any directory components removed.
        When ``metadata`` is truthy it is written as JSON next to the file
        with a ``.json`` suffix appended to the full file path.

        Args:
            content: File content
            filename: Filename (only its final path component is used)
            operation: Operation type (used as a sub-directory name)
            metadata: Optional metadata

        Returns:
            File path
        """
        # Take one timestamp so the date directory and the time prefix cannot
        # disagree when the call straddles midnight.
        now = datetime.now()

        # Create directory structure
        base_dir = "generated_content"
        operation_dir = os.path.join(base_dir, operation)
        date_dir = os.path.join(operation_dir, now.strftime("%Y/%m/%d"))

        os.makedirs(date_dir, exist_ok=True)

        # Strip directory parts (both separator styles) so a caller-supplied
        # name cannot point outside date_dir or at a non-existent sub-directory.
        safe_name = os.path.basename(filename.replace("\\", "/"))

        # Generate unique filename; MD5 is used only as a short content
        # fingerprint here, not for security.
        timestamp = now.strftime("%H%M%S")
        file_hash = hashlib.md5(content).hexdigest()[:8]
        unique_filename = f"{timestamp}_{file_hash}_{safe_name}"

        file_path = os.path.join(date_dir, unique_filename)

        # Save file
        async with aiofiles.open(file_path, 'wb') as f:
            await f.write(content)

        # Save metadata if provided
        if metadata:
            metadata_path = file_path + ".json"
            # Explicit encoding keeps the output independent of the locale
            async with aiofiles.open(metadata_path, 'w', encoding='utf-8') as f:
                await f.write(json.dumps(metadata, indent=2))

        return file_path

    @staticmethod
    def generate_cache_key(operation: str, parameters: Dict[str, Any]) -> str:
        """Generate cache key for operation and parameters.

        The key is the SHA-256 hex digest of ``"<operation>:<json>"`` where the
        JSON form of ``parameters`` is produced with sorted keys, so dict
        insertion order does not affect the key.

        Args:
            operation: Operation type
            parameters: Operation parameters (must be JSON-serializable)

        Returns:
            Cache key (64 hex characters)
        """
        # Create deterministic hash from operation and parameters
        key_data = f"{operation}:{json.dumps(parameters, sort_keys=True)}"
        return hashlib.sha256(key_data.encode()).hexdigest()
|
||||
|
||||
|
||||
class ResponseFormatter:
    """Formatter for API responses."""

    @staticmethod
    def _build_payload(
        content_key: str,
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Assemble the response dict shared by all formatters.

        The binary content is base64-encoded under ``content_key``; ``metadata``
        is attached only when it is truthy.
        """
        payload = {
            content_key: base64.b64encode(content).decode(),
            "format": output_format,
            "size": len(content),
            "timestamp": datetime.utcnow().isoformat(),
        }
        if metadata:
            payload["metadata"] = metadata
        return payload

    @staticmethod
    def format_image_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Wrap image bytes in a response dict.

        Args:
            content: Image content
            output_format: Output format
            metadata: Optional metadata

        Returns:
            Dict with base64 ``image``, ``format``, ``size`` in bytes, an ISO
            ``timestamp`` and, when provided and non-empty, ``metadata``.
        """
        return ResponseFormatter._build_payload("image", content, output_format, metadata)

    @staticmethod
    def format_audio_response(
        content: bytes,
        output_format: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Wrap audio bytes in a response dict.

        Args:
            content: Audio content
            output_format: Output format
            metadata: Optional metadata

        Returns:
            Dict with base64 ``audio``, ``format``, ``size`` in bytes, an ISO
            ``timestamp`` and, when provided and non-empty, ``metadata``.
        """
        return ResponseFormatter._build_payload("audio", content, output_format, metadata)

    @staticmethod
    def format_3d_response(
        content: bytes,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Wrap 3D model bytes in a response dict.

        Args:
            content: 3D model content (GLB)
            metadata: Optional metadata

        Returns:
            Dict with base64 ``model``, ``format`` fixed to ``"glb"``, ``size``
            in bytes, an ISO ``timestamp`` and, when provided and non-empty,
            ``metadata``.
        """
        return ResponseFormatter._build_payload("model", content, "glb", metadata)
|
||||
|
||||
|
||||
class ParameterValidator:
    """Validator for operation parameters."""

    @staticmethod
    def validate_seed(seed: Optional[int]) -> int:
        """Validate and normalize seed parameter.

        Args:
            seed: Seed value

        Returns:
            ``0`` when ``seed`` is None, otherwise the seed unchanged.

        Raises:
            HTTPException: 400 when ``seed`` is not an integer in
                ``[0, 4294967294]``.
        """
        if seed is None:
            return 0

        if not isinstance(seed, int) or not (0 <= seed <= 4294967294):
            raise HTTPException(
                status_code=400,
                detail="Seed must be an integer between 0 and 4294967294"
            )

        return seed

    @staticmethod
    def validate_strength(strength: Optional[float], operation: str) -> Optional[float]:
        """Validate strength parameter for different operations.

        Args:
            strength: Strength value
            operation: Operation type

        Returns:
            None when ``strength`` is None, otherwise the value as a float.

        Raises:
            HTTPException: 400 when ``strength`` is not a number in ``[0, 1]``
                (NaN included), or is below 0.01 for ``"audio_to_audio"``.
        """
        if strength is None:
            return None

        # A chained comparison is False for NaN, so NaN is rejected here; the
        # pair of "< 0 or > 1" tests would both be False and let it through.
        if not isinstance(strength, (int, float)) or not (0 <= strength <= 1):
            raise HTTPException(
                status_code=400,
                detail="Strength must be a float between 0 and 1"
            )

        # Operation-specific validation
        if operation == "audio_to_audio" and strength < 0.01:
            raise HTTPException(
                status_code=400,
                detail="Minimum strength for audio-to-audio is 0.01"
            )

        return float(strength)

    @staticmethod
    def validate_creativity(creativity: Optional[float], operation: str) -> Optional[float]:
        """Validate creativity parameter.

        Args:
            creativity: Creativity value
            operation: Operation type; selects the allowed range, with
                ``(0, 1)`` used for operations that have no specific entry.

        Returns:
            None when ``creativity`` is None, otherwise the value as a float.

        Raises:
            HTTPException: 400 when ``creativity`` is not a number inside the
                allowed range for ``operation`` (NaN included).
        """
        if creativity is None:
            return None

        # Different operations have different creativity ranges
        ranges = {
            "upscale": (0.1, 0.5),
            "outpaint": (0, 1),
            "conservative_upscale": (0.2, 0.5)
        }

        min_val, max_val = ranges.get(operation, (0, 1))

        # Chained comparison so that NaN fails the range test
        if not isinstance(creativity, (int, float)) or not (min_val <= creativity <= max_val):
            raise HTTPException(
                status_code=400,
                detail=f"Creativity for {operation} must be between {min_val} and {max_val}"
            )

        return float(creativity)
|
||||
|
||||
|
||||
class WorkflowManager:
    """Manager for complex workflows and pipelines."""

    @staticmethod
    def validate_workflow(workflow: List[Dict[str, Any]]) -> List[str]:
        """Validate workflow steps.

        Each step must be a dict with an ``operation`` key naming a supported
        operation, and ``generate_*`` operations are only accepted as the first
        step.

        Args:
            workflow: List of workflow steps

        Returns:
            List of validation errors (empty when the workflow is valid). Steps
            are numbered from 1 in the messages.
        """
        errors = []
        supported_operations = [
            "generate_ultra", "generate_core", "generate_sd3",
            "upscale_fast", "upscale_conservative", "upscale_creative",
            "inpaint", "outpaint", "erase", "search_and_replace",
            "control_sketch", "control_structure", "control_style"
        ]

        for i, step in enumerate(workflow):
            # Non-dict steps cannot carry an operation; report them the same way
            # instead of failing on the membership test.
            if not isinstance(step, dict) or "operation" not in step:
                errors.append(f"Step {i+1}: Missing 'operation' field")
                continue

            operation = step["operation"]
            if not isinstance(operation, str):
                # Avoid calling str methods on None / numbers further down
                errors.append(f"Step {i+1}: Unsupported operation '{operation}'")
                continue

            if operation not in supported_operations:
                errors.append(f"Step {i+1}: Unsupported operation '{operation}'")

            # Validate step dependencies
            if i > 0 and operation.startswith("generate_"):
                errors.append(f"Step {i+1}: Generate operations should be first in workflow")

        return errors

    @staticmethod
    def optimize_workflow(workflow: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Optimize workflow for better performance.

        Repeated non-generate operations are dropped (first occurrence kept),
        then steps are stably sorted so that generation comes first, then
        control, then edits, then upscaling; anything else goes last. The input
        list is not modified.

        Args:
            workflow: Original workflow; every step must have an ``operation``
                key (a missing key raises ``KeyError``).

        Returns:
            Optimized workflow
        """
        # Remove redundant operations
        operations_seen = set()
        filtered_workflow = []

        for step in workflow:
            operation = step["operation"]
            if operation not in operations_seen or operation.startswith("generate_"):
                filtered_workflow.append(step)
                operations_seen.add(operation)

        # Reorder for optimal execution
        # Generation operations first, then modifications, then upscaling
        order_priority = {
            "generate": 0,
            "control": 1,
            "edit": 2,
            "upscale": 3
        }
        # The edit operations do not share an "edit" prefix, so they are listed
        # by name; otherwise they would fall through to the lowest priority and
        # end up after upscaling.
        edit_operations = {"inpaint", "outpaint", "erase", "search_and_replace"}

        def get_priority(step):
            operation = step["operation"]
            if operation in edit_operations:
                return order_priority["edit"]
            for key, priority in order_priority.items():
                if operation.startswith(key):
                    return priority
            return 999

        # list.sort is stable, so steps with equal priority keep their order
        filtered_workflow.sort(key=get_priority)

        return filtered_workflow
|
||||
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def _extract_dominant_colors(img: Image.Image, num_colors: int = 5) -> List[Tuple[int, int, int]]:
    """Extract dominant colors from image.

    The image is resized to 150x150 and its pixels are clustered with k-means;
    the cluster centers, truncated to integers, are returned.

    Args:
        img: PIL Image
        num_colors: Number of dominant colors to extract

    Returns:
        List of RGB tuples made of plain Python ints
    """
    # Imported lazily, as in the rest of this module's optional dependencies
    from sklearn.cluster import KMeans

    # The reshape below needs exactly three channels per pixel
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Resize image for faster processing
    img_small = img.resize((150, 150))

    # Convert to a flat (pixel, channel) array
    pixels = np.array(img_small).reshape(-1, 3)

    # Use k-means clustering to find dominant colors; fixed random_state keeps
    # the result reproducible for the same input.
    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # tolist() converts numpy integers to Python ints so the result matches the
    # annotated return type and can be serialized as plain JSON.
    centers = kmeans.cluster_centers_.astype(int).tolist()
    return [tuple(center) for center in centers]
|
||||
|
||||
|
||||
def _assess_image_quality(img: Image.Image) -> Dict[str, Any]:
    """Assess image quality metrics.

    All metrics are computed on the grayscale version of the image:

    * ``sharpness_score``: variance of the first-order gradients
      (``np.gradient``) divided by 100 and capped at 100.
    * ``noise_level``: standard deviation of the grayscale pixel values.
    * ``overall_score``: mean of the sharpness score and
      ``max(0, 100 - noise_level)``.
    * ``needs_enhancement``: True when the overall score is below 70.

    Args:
        img: PIL Image

    Returns:
        Quality assessment with plain Python floats and a plain bool
    """
    # Convert to grayscale for quality analysis; float64 avoids any unsigned
    # integer arithmetic in the gradient computation.
    gray_array = np.asarray(img.convert('L'), dtype=np.float64)

    # Calculate sharpness from the variance of the image gradients.
    # np.gradient needs at least two samples along every axis, so images that
    # are a single pixel wide or tall get a sharpness of zero instead of raising.
    if min(gray_array.shape) >= 2:
        gradient_var = float(np.var(np.gradient(gray_array)))
    else:
        gradient_var = 0.0
    sharpness_score = min(100.0, gradient_var / 100)

    # Calculate noise level
    noise_level = float(np.std(gray_array))

    # Overall quality score
    overall_score = (sharpness_score + max(0.0, 100 - noise_level)) / 2

    # Values are cast to built-in types above so the dict contains no numpy
    # scalars (numpy.bool_ in particular is not a bool subclass).
    return {
        "sharpness_score": round(sharpness_score, 2),
        "noise_level": round(noise_level, 2),
        "overall_score": round(overall_score, 2),
        "needs_enhancement": bool(overall_score < 70)
    }
|
||||
|
||||
|
||||
def _has_style_descriptors(prompt: str) -> bool:
|
||||
"""Check if prompt contains style descriptors."""
|
||||
style_keywords = [
|
||||
"photorealistic", "realistic", "anime", "cartoon", "digital art",
|
||||
"oil painting", "watercolor", "sketch", "illustration", "3d render",
|
||||
"cinematic", "artistic", "professional"
|
||||
]
|
||||
return any(keyword in prompt.lower() for keyword in style_keywords)
|
||||
|
||||
|
||||
def _has_quality_terms(prompt: str) -> bool:
|
||||
"""Check if prompt contains quality terms."""
|
||||
quality_keywords = [
|
||||
"high quality", "detailed", "sharp", "crisp", "clear",
|
||||
"professional", "masterpiece", "award winning"
|
||||
]
|
||||
return any(keyword in prompt.lower() for keyword in quality_keywords)
|
||||
|
||||
|
||||
def _has_technical_terms(prompt: str) -> bool:
|
||||
"""Check if prompt contains technical photography terms."""
|
||||
technical_keywords = [
|
||||
"bokeh", "depth of field", "macro", "wide angle", "telephoto",
|
||||
"iso", "aperture", "shutter speed", "lighting", "composition"
|
||||
]
|
||||
return any(keyword in prompt.lower() for keyword in technical_keywords)
|
||||
|
||||
|
||||
def _calculate_complexity_score(prompt: str) -> float:
    """Score how elaborate a prompt is.

    Two points per whitespace-separated word (capped at 50), plus fixed bonuses
    for style descriptors, quality terms, technical terms, and for mentioning
    color, lighting or composition.

    Args:
        prompt: Text prompt

    Returns:
        Complexity score (0-100)
    """
    # Base score from word count
    score = min(2 * len(prompt.split()), 50)

    # Bonus points for each kind of descriptive element present
    keyword_bonuses = (
        (_has_style_descriptors, 15),
        (_has_quality_terms, 10),
        (_has_technical_terms, 15),
    )
    for detector, bonus in keyword_bonuses:
        if detector(prompt):
            score += bonus

    # Bonus for specific details (substring match, case-insensitive)
    lowered = prompt.lower()
    if "color" in lowered or "lighting" in lowered or "composition" in lowered:
        score += 10

    return min(score, 100)
|
||||
|
||||
|
||||
def create_batch_manifest(
    operation: str,
    files: List[UploadFile],
    parameters: Dict[str, Any]
) -> Dict[str, Any]:
    """Create manifest for batch processing.

    Args:
        operation: Operation type
        files: List of files to process
        parameters: Operation parameters

    Returns:
        Batch manifest with ``batch_id``, ``operation``, ``file_count``,
        per-file ``filename``/``size`` entries, ``parameters``, ``created_at``
        (ISO timestamp), ``estimated_duration`` in seconds and
        ``estimated_cost``.
    """
    import uuid

    # One timestamp for both fields so batch_id and created_at always agree
    now = datetime.utcnow()

    # The timestamp alone has one-second resolution; a short random suffix keeps
    # ids distinct for batches created within the same second.
    batch_id = f"batch_{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

    file_count = len(files)

    return {
        "batch_id": batch_id,
        "operation": operation,
        "file_count": file_count,
        "files": [{"filename": f.filename, "size": f.size} for f in files],
        "parameters": parameters,
        "created_at": now.isoformat(),
        "estimated_duration": file_count * 30,  # 30 seconds per file estimate
        "estimated_cost": file_count * _get_operation_cost(operation)
    }
|
||||
|
||||
|
||||
def _get_operation_cost(operation: str) -> float:
    """Look up the estimated per-file cost of an operation.

    Prices come from ``MODEL_PRICING`` in ``config.stability_config``. Every
    ``generate_*`` operation is priced as "core" and every ``control_*``
    operation as "sketch"; ``upscale_*`` operations are priced by the part of
    the name that remains after removing "upscale_".

    Args:
        operation: Operation type

    Returns:
        Estimated cost in credits (5 for operations outside these families)
    """
    from config.stability_config import MODEL_PRICING

    # Map operation to pricing category
    if operation.startswith("generate_"):
        generate_prices = MODEL_PRICING["generate"]
        return generate_prices.get("core", 3)  # Default to core

    if operation.startswith("upscale_"):
        variant = operation.replace("upscale_", "")
        upscale_prices = MODEL_PRICING["upscale"]
        return upscale_prices.get(variant, 5)

    if operation.startswith("control_"):
        control_prices = MODEL_PRICING["control"]
        return control_prices.get("sketch", 5)  # Default

    return 5  # Default cost
|
||||
|
||||
|
||||
def validate_file_size(file: UploadFile, max_size: int = 10 * 1024 * 1024) -> None:
    """Reject uploads whose reported size exceeds ``max_size``.

    Uploads with no reported size (or a size of zero) are accepted.

    Args:
        file: Uploaded file
        max_size: Maximum allowed size in bytes

    Raises:
        HTTPException: 413 when the reported size is larger than ``max_size``.
    """
    reported_size = file.size

    # Nothing to enforce when the size is unknown/zero or within the limit
    if not reported_size or reported_size <= max_size:
        return

    raise HTTPException(
        status_code=413,
        detail=f"File size ({file.size} bytes) exceeds maximum allowed size ({max_size} bytes)"
    )
|
||||
|
||||
|
||||
async def convert_image_format(content: bytes, target_format: str) -> bytes:
    """Convert image to target format.

    Args:
        content: Image content
        target_format: Target format (jpeg, png, webp); case-insensitive, and
            "jpg" is accepted as an alias for "jpeg".

    Returns:
        Converted image bytes

    Raises:
        HTTPException: 400 when the image cannot be decoded or saved in the
            requested format; the original exception is attached as the cause.
    """
    # Pixel modes the JPEG writer can store directly; anything else (alpha,
    # palette, integer/float modes, ...) is converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    try:
        save_format = target_format.upper()
        # Pillow registers the writer under "JPEG"; "JPG" is not a known format
        if save_format == "JPG":
            save_format = "JPEG"

        img = Image.open(io.BytesIO(content))

        # Convert to RGB if saving as JPEG from an unsupported mode
        if save_format == "JPEG" and img.mode not in jpeg_modes:
            img = img.convert("RGB")

        output = io.BytesIO()
        img.save(output, format=save_format)
        return output.getvalue()

    except Exception as e:
        # Chain the cause so the real failure stays visible in tracebacks
        raise HTTPException(status_code=400, detail=f"Error converting image: {str(e)}") from e
|
||||
|
||||
|
||||
def estimate_processing_time(
    operation: str,
    file_size: int,
    complexity: Optional[Dict[str, Any]] = None
) -> float:
    """Estimate how long an operation will take.

    The per-operation base time (10 s for unknown operations) is multiplied by
    the file size in MiB, never by less than 1, and by a further 1.5 when the
    supplied complexity metrics report a ``complexity_score`` above 80.

    Args:
        operation: Operation type
        file_size: File size in bytes
        complexity: Optional complexity metrics

    Returns:
        Estimated time in seconds, rounded to one decimal place
    """
    bytes_per_mib = 1024 * 1024

    # Base times by operation (in seconds)
    seconds_by_operation = {
        "generate_ultra": 15,
        "generate_core": 5,
        "generate_sd3": 10,
        "upscale_fast": 2,
        "upscale_conservative": 30,
        "upscale_creative": 60,
        "inpaint": 10,
        "outpaint": 15,
        "control_sketch": 8,
        "control_structure": 8,
        "control_style": 10,
        "3d_fast": 10,
        "3d_point_aware": 20,
        "audio_text": 30,
        "audio_transform": 45,
    }

    # Files under 1 MiB are not discounted: the scale factor bottoms out at 1
    scale = max(1, file_size / bytes_per_mib)
    estimate = seconds_by_operation.get(operation, 10) * scale

    # Highly complex inputs get a 50% surcharge
    is_complex = bool(complexity) and complexity.get("complexity_score", 0) > 80
    if is_complex:
        estimate *= 1.5

    return round(estimate, 1)
|
||||
Reference in New Issue
Block a user