Add Stability AI integration with comprehensive endpoints and features
Co-authored-by: ajay.calsoft <ajay.calsoft@gmail.com>
This commit is contained in:
1166
backend/routers/stability.py
Normal file
1166
backend/routers/stability.py
Normal file
File diff suppressed because it is too large
Load Diff
737
backend/routers/stability_admin.py
Normal file
737
backend/routers/stability_admin.py
Normal file
@@ -0,0 +1,737 @@
|
||||
"""Admin endpoints for Stability AI service management."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from services.stability_service import get_stability_service, StabilityAIService
|
||||
from middleware.stability_middleware import get_middleware_stats
|
||||
from config.stability_config import (
|
||||
MODEL_PRICING, IMAGE_LIMITS, AUDIO_LIMITS, WORKFLOW_TEMPLATES,
|
||||
get_stability_config, get_model_recommendations, calculate_estimated_cost
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/stability/admin", tags=["Stability AI Admin"])
|
||||
|
||||
|
||||
# ==================== MONITORING ENDPOINTS ====================
|
||||
|
||||
@router.get("/stats", summary="Get Service Statistics")
|
||||
async def get_service_stats():
|
||||
"""Get comprehensive statistics about Stability AI service usage."""
|
||||
return {
|
||||
"service_info": {
|
||||
"name": "Stability AI Integration",
|
||||
"version": "1.0.0",
|
||||
"uptime": "N/A", # Would track actual uptime
|
||||
"last_restart": datetime.utcnow().isoformat()
|
||||
},
|
||||
"middleware_stats": get_middleware_stats(),
|
||||
"pricing_info": MODEL_PRICING,
|
||||
"limits": {
|
||||
"image": IMAGE_LIMITS,
|
||||
"audio": AUDIO_LIMITS
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health/detailed", summary="Detailed Health Check")
|
||||
async def detailed_health_check(
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Perform detailed health check of Stability AI service."""
|
||||
health_status = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"overall_status": "healthy",
|
||||
"checks": {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Test API connectivity
|
||||
async with stability_service:
|
||||
account_info = await stability_service.get_account_details()
|
||||
health_status["checks"]["api_connectivity"] = {
|
||||
"status": "healthy",
|
||||
"response_time": "N/A",
|
||||
"account_id": account_info.get("id", "unknown")
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["checks"]["api_connectivity"] = {
|
||||
"status": "unhealthy",
|
||||
"error": str(e)
|
||||
}
|
||||
health_status["overall_status"] = "degraded"
|
||||
|
||||
try:
|
||||
# Test account balance
|
||||
async with stability_service:
|
||||
balance_info = await stability_service.get_account_balance()
|
||||
credits = balance_info.get("credits", 0)
|
||||
|
||||
health_status["checks"]["account_balance"] = {
|
||||
"status": "healthy" if credits > 10 else "warning",
|
||||
"credits": credits,
|
||||
"warning": "Low credit balance" if credits < 10 else None
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["checks"]["account_balance"] = {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Check configuration
|
||||
try:
|
||||
config = get_stability_config()
|
||||
health_status["checks"]["configuration"] = {
|
||||
"status": "healthy",
|
||||
"api_key_configured": bool(config.api_key),
|
||||
"base_url": config.base_url
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["checks"]["configuration"] = {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
health_status["overall_status"] = "unhealthy"
|
||||
|
||||
return health_status
|
||||
|
||||
|
||||
@router.get("/usage/summary", summary="Get Usage Summary")
|
||||
async def get_usage_summary(
|
||||
days: Optional[int] = Query(7, description="Number of days to analyze")
|
||||
):
|
||||
"""Get usage summary for the specified time period."""
|
||||
# In a real implementation, this would query a database
|
||||
# For now, return mock data
|
||||
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
return {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat(),
|
||||
"days": days
|
||||
},
|
||||
"usage_summary": {
|
||||
"total_requests": 156,
|
||||
"successful_requests": 148,
|
||||
"failed_requests": 8,
|
||||
"success_rate": 94.87,
|
||||
"total_credits_used": 450.5,
|
||||
"average_credits_per_request": 2.89
|
||||
},
|
||||
"operation_breakdown": {
|
||||
"generate_ultra": {"requests": 25, "credits": 200},
|
||||
"generate_core": {"requests": 45, "credits": 135},
|
||||
"upscale_fast": {"requests": 30, "credits": 60},
|
||||
"inpaint": {"requests": 20, "credits": 100},
|
||||
"control_sketch": {"requests": 15, "credits": 75}
|
||||
},
|
||||
"daily_usage": [
|
||||
{"date": (end_date - timedelta(days=i)).strftime("%Y-%m-%d"),
|
||||
"requests": 20 + i * 2,
|
||||
"credits": 50 + i * 5}
|
||||
for i in range(days)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/costs/estimate", summary="Estimate Operation Costs")
|
||||
async def estimate_operation_costs(
|
||||
operations: str = Query(..., description="JSON array of operations to estimate"),
|
||||
model_preferences: Optional[str] = Query(None, description="JSON object of model preferences")
|
||||
):
|
||||
"""Estimate costs for a list of operations."""
|
||||
try:
|
||||
ops_list = json.loads(operations)
|
||||
preferences = json.loads(model_preferences) if model_preferences else {}
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in parameters")
|
||||
|
||||
estimates = []
|
||||
total_cost = 0
|
||||
|
||||
for op in ops_list:
|
||||
operation = op.get("operation")
|
||||
model = preferences.get(operation) or op.get("model")
|
||||
steps = op.get("steps")
|
||||
|
||||
cost = calculate_estimated_cost(operation, model, steps)
|
||||
total_cost += cost
|
||||
|
||||
estimates.append({
|
||||
"operation": operation,
|
||||
"model": model,
|
||||
"estimated_credits": cost,
|
||||
"description": f"Estimated cost for {operation}"
|
||||
})
|
||||
|
||||
return {
|
||||
"estimates": estimates,
|
||||
"total_estimated_credits": total_cost,
|
||||
"currency_equivalent": f"${total_cost * 0.01:.2f}", # Assuming $0.01 per credit
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== CONFIGURATION ENDPOINTS ====================
|
||||
|
||||
@router.get("/config", summary="Get Current Configuration")
|
||||
async def get_current_config():
|
||||
"""Get current Stability AI service configuration."""
|
||||
try:
|
||||
config = get_stability_config()
|
||||
return {
|
||||
"base_url": config.base_url,
|
||||
"timeout": config.timeout,
|
||||
"max_retries": config.max_retries,
|
||||
"max_file_size": config.max_file_size,
|
||||
"supported_image_formats": config.supported_image_formats,
|
||||
"supported_audio_formats": config.supported_audio_formats,
|
||||
"api_key_configured": bool(config.api_key),
|
||||
"api_key_preview": f"{config.api_key[:8]}..." if config.api_key else None
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Configuration error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/models/recommendations", summary="Get Model Recommendations")
|
||||
async def get_model_recommendations_endpoint(
|
||||
use_case: str = Query(..., description="Use case (portrait, landscape, art, product, concept)"),
|
||||
quality_preference: str = Query("standard", description="Quality preference (draft, standard, premium)"),
|
||||
speed_preference: str = Query("balanced", description="Speed preference (fast, balanced, quality)")
|
||||
):
|
||||
"""Get model recommendations based on use case and preferences."""
|
||||
recommendations = get_model_recommendations(use_case, quality_preference, speed_preference)
|
||||
|
||||
# Add detailed information
|
||||
recommendations["use_case_info"] = {
|
||||
"description": f"Recommendations optimized for {use_case} use case",
|
||||
"quality_level": quality_preference,
|
||||
"speed_priority": speed_preference
|
||||
}
|
||||
|
||||
# Add cost information
|
||||
primary_cost = calculate_estimated_cost("generate", recommendations["primary"])
|
||||
alternative_cost = calculate_estimated_cost("generate", recommendations["alternative"])
|
||||
|
||||
recommendations["cost_comparison"] = {
|
||||
"primary_model_cost": primary_cost,
|
||||
"alternative_model_cost": alternative_cost,
|
||||
"cost_difference": abs(primary_cost - alternative_cost)
|
||||
}
|
||||
|
||||
return recommendations
|
||||
|
||||
|
||||
@router.get("/workflows/templates", summary="Get Workflow Templates")
|
||||
async def get_workflow_templates():
|
||||
"""Get available workflow templates."""
|
||||
return {
|
||||
"templates": WORKFLOW_TEMPLATES,
|
||||
"template_count": len(WORKFLOW_TEMPLATES),
|
||||
"categories": list(set(
|
||||
template["description"].split()[0].lower()
|
||||
for template in WORKFLOW_TEMPLATES.values()
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@router.post("/workflows/validate", summary="Validate Custom Workflow")
|
||||
async def validate_custom_workflow(
|
||||
workflow: dict
|
||||
):
|
||||
"""Validate a custom workflow configuration."""
|
||||
from utils.stability_utils import WorkflowManager
|
||||
|
||||
steps = workflow.get("steps", [])
|
||||
|
||||
if not steps:
|
||||
raise HTTPException(status_code=400, detail="Workflow must contain at least one step")
|
||||
|
||||
# Validate workflow
|
||||
errors = WorkflowManager.validate_workflow(steps)
|
||||
|
||||
if errors:
|
||||
return {
|
||||
"is_valid": False,
|
||||
"errors": errors,
|
||||
"workflow": workflow
|
||||
}
|
||||
|
||||
# Calculate estimated cost and time
|
||||
total_cost = sum(calculate_estimated_cost(step.get("operation", "unknown")) for step in steps)
|
||||
estimated_time = len(steps) * 30 # Rough estimate
|
||||
|
||||
# Optimize workflow
|
||||
optimized_steps = WorkflowManager.optimize_workflow(steps)
|
||||
|
||||
return {
|
||||
"is_valid": True,
|
||||
"original_workflow": workflow,
|
||||
"optimized_workflow": {"steps": optimized_steps},
|
||||
"estimates": {
|
||||
"total_credits": total_cost,
|
||||
"estimated_time_seconds": estimated_time,
|
||||
"step_count": len(steps)
|
||||
},
|
||||
"optimizations_applied": len(steps) != len(optimized_steps)
|
||||
}
|
||||
|
||||
|
||||
# ==================== CACHE MANAGEMENT ====================
|
||||
|
||||
@router.post("/cache/clear", summary="Clear Service Cache")
|
||||
async def clear_cache():
|
||||
"""Clear all cached data."""
|
||||
from middleware.stability_middleware import caching
|
||||
|
||||
caching.clear_cache()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Cache cleared successfully",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cache/stats", summary="Get Cache Statistics")
|
||||
async def get_cache_stats():
|
||||
"""Get cache usage statistics."""
|
||||
from middleware.stability_middleware import caching
|
||||
|
||||
return {
|
||||
"cache_stats": caching.get_cache_stats(),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== RATE LIMITING MANAGEMENT ====================
|
||||
|
||||
@router.get("/rate-limit/status", summary="Get Rate Limit Status")
|
||||
async def get_rate_limit_status():
|
||||
"""Get current rate limiting status."""
|
||||
from middleware.stability_middleware import rate_limiter
|
||||
|
||||
return {
|
||||
"rate_limit_config": {
|
||||
"requests_per_window": rate_limiter.requests_per_window,
|
||||
"window_seconds": rate_limiter.window_seconds
|
||||
},
|
||||
"current_blocks": len(rate_limiter.blocked_until),
|
||||
"active_clients": len(rate_limiter.request_times),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/rate-limit/reset", summary="Reset Rate Limits")
|
||||
async def reset_rate_limits():
|
||||
"""Reset rate limiting for all clients (admin only)."""
|
||||
from middleware.stability_middleware import rate_limiter
|
||||
|
||||
# Clear all rate limiting data
|
||||
rate_limiter.request_times.clear()
|
||||
rate_limiter.blocked_until.clear()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Rate limits reset for all clients",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== ACCOUNT MANAGEMENT ====================
|
||||
|
||||
@router.get("/account/detailed", summary="Get Detailed Account Information")
|
||||
async def get_detailed_account_info(
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Get detailed account information including usage and limits."""
|
||||
async with stability_service:
|
||||
account_info = await stability_service.get_account_details()
|
||||
balance_info = await stability_service.get_account_balance()
|
||||
engines_info = await stability_service.list_engines()
|
||||
|
||||
return {
|
||||
"account": account_info,
|
||||
"balance": balance_info,
|
||||
"available_engines": engines_info,
|
||||
"service_limits": {
|
||||
"rate_limit": "150 requests per 10 seconds",
|
||||
"max_file_size": "10MB for images, 50MB for audio",
|
||||
"result_storage": "24 hours for async generations"
|
||||
},
|
||||
"pricing": MODEL_PRICING,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== DEBUGGING ENDPOINTS ====================
|
||||
|
||||
@router.post("/debug/test-connection", summary="Test API Connection")
|
||||
async def test_api_connection(
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Test connection to Stability AI API."""
|
||||
test_results = {}
|
||||
|
||||
try:
|
||||
async with stability_service:
|
||||
# Test account endpoint
|
||||
start_time = datetime.utcnow()
|
||||
account_info = await stability_service.get_account_details()
|
||||
end_time = datetime.utcnow()
|
||||
|
||||
test_results["account_test"] = {
|
||||
"status": "success",
|
||||
"response_time_ms": (end_time - start_time).total_seconds() * 1000,
|
||||
"account_id": account_info.get("id")
|
||||
}
|
||||
except Exception as e:
|
||||
test_results["account_test"] = {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
try:
|
||||
async with stability_service:
|
||||
# Test engines endpoint
|
||||
start_time = datetime.utcnow()
|
||||
engines = await stability_service.list_engines()
|
||||
end_time = datetime.utcnow()
|
||||
|
||||
test_results["engines_test"] = {
|
||||
"status": "success",
|
||||
"response_time_ms": (end_time - start_time).total_seconds() * 1000,
|
||||
"engine_count": len(engines)
|
||||
}
|
||||
except Exception as e:
|
||||
test_results["engines_test"] = {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
overall_status = "healthy" if all(
|
||||
test["status"] == "success"
|
||||
for test in test_results.values()
|
||||
) else "unhealthy"
|
||||
|
||||
return {
|
||||
"overall_status": overall_status,
|
||||
"tests": test_results,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/debug/request-logs", summary="Get Recent Request Logs")
|
||||
async def get_request_logs(
|
||||
limit: int = Query(50, description="Maximum number of log entries to return"),
|
||||
operation_filter: Optional[str] = Query(None, description="Filter by operation type")
|
||||
):
|
||||
"""Get recent request logs for debugging."""
|
||||
from middleware.stability_middleware import request_logging
|
||||
|
||||
logs = request_logging.get_recent_logs(limit)
|
||||
|
||||
if operation_filter:
|
||||
logs = [
|
||||
log for log in logs
|
||||
if operation_filter in log.get("path", "")
|
||||
]
|
||||
|
||||
return {
|
||||
"logs": logs,
|
||||
"total_entries": len(logs),
|
||||
"filter_applied": operation_filter,
|
||||
"summary": request_logging.get_log_summary()
|
||||
}
|
||||
|
||||
|
||||
# ==================== MAINTENANCE ENDPOINTS ====================
|
||||
|
||||
@router.post("/maintenance/cleanup", summary="Cleanup Service Resources")
|
||||
async def cleanup_service_resources():
|
||||
"""Cleanup service resources and temporary files."""
|
||||
cleanup_results = {}
|
||||
|
||||
try:
|
||||
# Clear caches
|
||||
from middleware.stability_middleware import caching
|
||||
caching.clear_cache()
|
||||
cleanup_results["cache_cleanup"] = "success"
|
||||
except Exception as e:
|
||||
cleanup_results["cache_cleanup"] = f"error: {str(e)}"
|
||||
|
||||
try:
|
||||
# Clean up temporary files (if any)
|
||||
import os
|
||||
import glob
|
||||
|
||||
temp_files = glob.glob("/tmp/stability_*")
|
||||
removed_count = 0
|
||||
|
||||
for temp_file in temp_files:
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
removed_count += 1
|
||||
except:
|
||||
pass
|
||||
|
||||
cleanup_results["temp_file_cleanup"] = f"removed {removed_count} files"
|
||||
except Exception as e:
|
||||
cleanup_results["temp_file_cleanup"] = f"error: {str(e)}"
|
||||
|
||||
return {
|
||||
"cleanup_results": cleanup_results,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/maintenance/optimize", summary="Optimize Service Performance")
|
||||
async def optimize_service_performance():
|
||||
"""Optimize service performance by adjusting configurations."""
|
||||
optimizations = []
|
||||
|
||||
# Check and optimize cache settings
|
||||
from middleware.stability_middleware import caching
|
||||
cache_stats = caching.get_cache_stats()
|
||||
|
||||
if cache_stats["total_entries"] > 100:
|
||||
caching.clear_cache()
|
||||
optimizations.append("Cleared large cache to free memory")
|
||||
|
||||
# Check rate limiting efficiency
|
||||
from middleware.stability_middleware import rate_limiter
|
||||
if len(rate_limiter.blocked_until) > 10:
|
||||
# Reset old blocks
|
||||
import time
|
||||
current_time = time.time()
|
||||
expired_blocks = [
|
||||
client_id for client_id, block_time in rate_limiter.blocked_until.items()
|
||||
if current_time > block_time
|
||||
]
|
||||
|
||||
for client_id in expired_blocks:
|
||||
del rate_limiter.blocked_until[client_id]
|
||||
|
||||
optimizations.append(f"Cleared {len(expired_blocks)} expired rate limit blocks")
|
||||
|
||||
return {
|
||||
"optimizations_applied": optimizations,
|
||||
"optimization_count": len(optimizations),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== FEATURE FLAGS ====================
|
||||
|
||||
@router.get("/features", summary="Get Feature Flags")
|
||||
async def get_feature_flags():
|
||||
"""Get current feature flag status."""
|
||||
from config.stability_config import FEATURE_FLAGS
|
||||
|
||||
return {
|
||||
"features": FEATURE_FLAGS,
|
||||
"enabled_count": sum(1 for enabled in FEATURE_FLAGS.values() if enabled),
|
||||
"total_features": len(FEATURE_FLAGS)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/features/{feature_name}/toggle", summary="Toggle Feature Flag")
|
||||
async def toggle_feature_flag(feature_name: str):
|
||||
"""Toggle a feature flag on/off."""
|
||||
from config.stability_config import FEATURE_FLAGS
|
||||
|
||||
if feature_name not in FEATURE_FLAGS:
|
||||
raise HTTPException(status_code=404, detail=f"Feature '{feature_name}' not found")
|
||||
|
||||
# Toggle the feature
|
||||
FEATURE_FLAGS[feature_name] = not FEATURE_FLAGS[feature_name]
|
||||
|
||||
return {
|
||||
"feature": feature_name,
|
||||
"new_status": FEATURE_FLAGS[feature_name],
|
||||
"message": f"Feature '{feature_name}' {'enabled' if FEATURE_FLAGS[feature_name] else 'disabled'}",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== EXPORT ENDPOINTS ====================
|
||||
|
||||
@router.get("/export/config", summary="Export Configuration")
|
||||
async def export_configuration():
|
||||
"""Export current service configuration."""
|
||||
config = get_stability_config()
|
||||
|
||||
export_data = {
|
||||
"service_config": {
|
||||
"base_url": config.base_url,
|
||||
"timeout": config.timeout,
|
||||
"max_retries": config.max_retries,
|
||||
"max_file_size": config.max_file_size
|
||||
},
|
||||
"pricing": MODEL_PRICING,
|
||||
"limits": {
|
||||
"image": IMAGE_LIMITS,
|
||||
"audio": AUDIO_LIMITS
|
||||
},
|
||||
"workflows": WORKFLOW_TEMPLATES,
|
||||
"export_timestamp": datetime.utcnow().isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
return export_data
|
||||
|
||||
|
||||
@router.get("/export/usage-report", summary="Export Usage Report")
|
||||
async def export_usage_report(
|
||||
format_type: str = Query("json", description="Export format (json, csv)"),
|
||||
days: int = Query(30, description="Number of days to include")
|
||||
):
|
||||
"""Export detailed usage report."""
|
||||
# In a real implementation, this would query actual usage data
|
||||
|
||||
usage_data = {
|
||||
"report_info": {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"period_days": days,
|
||||
"format": format_type
|
||||
},
|
||||
"summary": {
|
||||
"total_requests": 500,
|
||||
"total_credits_used": 1250,
|
||||
"average_daily_usage": 41.67,
|
||||
"most_used_operation": "generate_core"
|
||||
},
|
||||
"detailed_usage": [
|
||||
{
|
||||
"date": (datetime.utcnow() - timedelta(days=i)).strftime("%Y-%m-%d"),
|
||||
"requests": 15 + (i % 5),
|
||||
"credits": 37.5 + (i % 5) * 2.5,
|
||||
"top_operation": "generate_core"
|
||||
}
|
||||
for i in range(days)
|
||||
]
|
||||
}
|
||||
|
||||
if format_type == "csv":
|
||||
# Convert to CSV format
|
||||
import csv
|
||||
import io
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.DictWriter(output, fieldnames=["date", "requests", "credits", "top_operation"])
|
||||
writer.writeheader()
|
||||
writer.writerows(usage_data["detailed_usage"])
|
||||
|
||||
return Response(
|
||||
content=output.getvalue(),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=stability_usage_{days}days.csv"}
|
||||
)
|
||||
|
||||
return usage_data
|
||||
|
||||
|
||||
# ==================== SYSTEM INFO ENDPOINTS ====================
|
||||
|
||||
@router.get("/system/info", summary="Get System Information")
|
||||
async def get_system_info():
|
||||
"""Get comprehensive system information."""
|
||||
import sys
|
||||
import platform
|
||||
import psutil
|
||||
|
||||
return {
|
||||
"system": {
|
||||
"platform": platform.platform(),
|
||||
"python_version": sys.version,
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"memory_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
|
||||
"memory_available_gb": round(psutil.virtual_memory().available / (1024**3), 2)
|
||||
},
|
||||
"service": {
|
||||
"name": "Stability AI Integration",
|
||||
"version": "1.0.0",
|
||||
"uptime": "N/A", # Would track actual uptime
|
||||
"active_connections": "N/A"
|
||||
},
|
||||
"api_info": {
|
||||
"base_url": "https://api.stability.ai",
|
||||
"supported_versions": ["v2beta", "v1"],
|
||||
"rate_limit": "150 requests per 10 seconds"
|
||||
},
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/system/dependencies", summary="Get Service Dependencies")
|
||||
async def get_service_dependencies():
|
||||
"""Get information about service dependencies."""
|
||||
dependencies = {
|
||||
"required": {
|
||||
"fastapi": "Web framework",
|
||||
"aiohttp": "HTTP client for API calls",
|
||||
"pydantic": "Data validation",
|
||||
"pillow": "Image processing",
|
||||
"loguru": "Logging"
|
||||
},
|
||||
"optional": {
|
||||
"scikit-learn": "Color analysis",
|
||||
"numpy": "Numerical operations",
|
||||
"psutil": "System monitoring"
|
||||
},
|
||||
"external_services": {
|
||||
"stability_ai_api": {
|
||||
"url": "https://api.stability.ai",
|
||||
"status": "unknown", # Would check actual status
|
||||
"description": "Stability AI REST API"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dependencies
|
||||
|
||||
|
||||
# ==================== WEBHOOK MANAGEMENT ====================
|
||||
|
||||
@router.get("/webhooks/config", summary="Get Webhook Configuration")
|
||||
async def get_webhook_config():
|
||||
"""Get current webhook configuration."""
|
||||
return {
|
||||
"webhooks_enabled": True,
|
||||
"supported_events": [
|
||||
"generation.completed",
|
||||
"generation.failed",
|
||||
"upscale.completed",
|
||||
"edit.completed"
|
||||
],
|
||||
"webhook_url": "/api/stability/webhook/generation-complete",
|
||||
"retry_policy": {
|
||||
"max_retries": 3,
|
||||
"retry_delay_seconds": 5
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/webhooks/test", summary="Test Webhook Delivery")
|
||||
async def test_webhook_delivery():
|
||||
"""Test webhook delivery mechanism."""
|
||||
test_payload = {
|
||||
"event": "generation.completed",
|
||||
"generation_id": "test_generation_id",
|
||||
"status": "success",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# In a real implementation, this would send to configured webhook URLs
|
||||
|
||||
return {
|
||||
"test_status": "success",
|
||||
"payload_sent": test_payload,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
817
backend/routers/stability_advanced.py
Normal file
817
backend/routers/stability_advanced.py
Normal file
@@ -0,0 +1,817 @@
|
||||
"""Advanced Stability AI endpoints with specialized features."""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from services.stability_service import get_stability_service, StabilityAIService
|
||||
|
||||
router = APIRouter(prefix="/api/stability/advanced", tags=["Stability AI Advanced"])
|
||||
|
||||
|
||||
# ==================== ADVANCED GENERATION WORKFLOWS ====================
|
||||
|
||||
@router.post("/workflow/image-enhancement", summary="Complete Image Enhancement Workflow")
|
||||
async def image_enhancement_workflow(
|
||||
image: UploadFile = File(..., description="Image to enhance"),
|
||||
enhancement_type: str = Form("auto", description="Enhancement type: auto, upscale, denoise, sharpen"),
|
||||
prompt: Optional[str] = Form(None, description="Optional prompt for guided enhancement"),
|
||||
target_resolution: Optional[str] = Form("4k", description="Target resolution: 4k, 2k, hd"),
|
||||
preserve_style: Optional[bool] = Form(True, description="Preserve original style"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Complete image enhancement workflow with automatic optimization.
|
||||
|
||||
This workflow automatically determines the best enhancement approach based on
|
||||
the input image characteristics and user preferences.
|
||||
"""
|
||||
async with stability_service:
|
||||
# Analyze image first
|
||||
content = await image.read()
|
||||
img_info = await _analyze_image(content)
|
||||
|
||||
# Reset file pointer
|
||||
await image.seek(0)
|
||||
|
||||
# Determine enhancement strategy
|
||||
strategy = _determine_enhancement_strategy(img_info, enhancement_type, target_resolution)
|
||||
|
||||
# Execute enhancement workflow
|
||||
results = []
|
||||
|
||||
for step in strategy["steps"]:
|
||||
if step["operation"] == "upscale_fast":
|
||||
result = await stability_service.upscale_fast(image=image)
|
||||
elif step["operation"] == "upscale_conservative":
|
||||
result = await stability_service.upscale_conservative(
|
||||
image=image,
|
||||
prompt=prompt or step["default_prompt"]
|
||||
)
|
||||
elif step["operation"] == "upscale_creative":
|
||||
result = await stability_service.upscale_creative(
|
||||
image=image,
|
||||
prompt=prompt or step["default_prompt"]
|
||||
)
|
||||
|
||||
results.append({
|
||||
"step": step["name"],
|
||||
"operation": step["operation"],
|
||||
"status": "completed",
|
||||
"result_size": len(result) if isinstance(result, bytes) else None
|
||||
})
|
||||
|
||||
# Use result as input for next step if needed
|
||||
if isinstance(result, bytes) and len(strategy["steps"]) > 1:
|
||||
# Convert bytes back to UploadFile-like object for next step
|
||||
image = _bytes_to_upload_file(result, image.filename)
|
||||
|
||||
# Return final result
|
||||
if isinstance(result, bytes):
|
||||
return Response(
|
||||
content=result,
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"X-Enhancement-Strategy": json.dumps(strategy),
|
||||
"X-Processing-Steps": str(len(results))
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"strategy": strategy,
|
||||
"steps_completed": results,
|
||||
"generation_id": result.get("id") if isinstance(result, dict) else None
|
||||
}
|
||||
|
||||
|
||||
@router.post("/workflow/creative-suite", summary="Creative Suite Multi-Step Workflow")
|
||||
async def creative_suite_workflow(
|
||||
base_image: Optional[UploadFile] = File(None, description="Base image (optional for text-to-image)"),
|
||||
prompt: str = Form(..., description="Main creative prompt"),
|
||||
style_reference: Optional[UploadFile] = File(None, description="Style reference image"),
|
||||
workflow_steps: str = Form(..., description="JSON array of workflow steps"),
|
||||
output_format: Optional[str] = Form("png", description="Output format"),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Execute a multi-step creative workflow combining various Stability AI services.
|
||||
|
||||
This endpoint allows you to chain multiple operations together for complex
|
||||
creative workflows.
|
||||
"""
|
||||
try:
|
||||
steps = json.loads(workflow_steps)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in workflow_steps")
|
||||
|
||||
async with stability_service:
|
||||
current_image = base_image
|
||||
results = []
|
||||
|
||||
for i, step in enumerate(steps):
|
||||
operation = step.get("operation")
|
||||
params = step.get("parameters", {})
|
||||
|
||||
try:
|
||||
if operation == "generate_core" and not current_image:
|
||||
result = await stability_service.generate_core(prompt=prompt, **params)
|
||||
elif operation == "control_style" and style_reference:
|
||||
result = await stability_service.control_style(
|
||||
image=style_reference, prompt=prompt, **params
|
||||
)
|
||||
elif operation == "inpaint" and current_image:
|
||||
result = await stability_service.inpaint(
|
||||
image=current_image, prompt=prompt, **params
|
||||
)
|
||||
elif operation == "upscale_fast" and current_image:
|
||||
result = await stability_service.upscale_fast(image=current_image, **params)
|
||||
else:
|
||||
raise ValueError(f"Unsupported operation or missing requirements: {operation}")
|
||||
|
||||
# Convert result to next step input if needed
|
||||
if isinstance(result, bytes):
|
||||
current_image = _bytes_to_upload_file(result, f"step_{i}_output.png")
|
||||
|
||||
results.append({
|
||||
"step": i + 1,
|
||||
"operation": operation,
|
||||
"status": "completed",
|
||||
"result_type": "image" if isinstance(result, bytes) else "json"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
results.append({
|
||||
"step": i + 1,
|
||||
"operation": operation,
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
})
|
||||
break
|
||||
|
||||
# Return final result
|
||||
if isinstance(result, bytes):
|
||||
return Response(
|
||||
content=result,
|
||||
media_type=f"image/{output_format}",
|
||||
headers={"X-Workflow-Steps": json.dumps(results)}
|
||||
)
|
||||
|
||||
return {"workflow_results": results, "final_result": result}
|
||||
|
||||
|
||||
# ==================== COMPARISON ENDPOINTS ====================
|
||||
|
||||
@router.post("/compare/models", summary="Compare Different Models")
|
||||
async def compare_models(
|
||||
prompt: str = Form(..., description="Text prompt for comparison"),
|
||||
models: str = Form(..., description="JSON array of models to compare"),
|
||||
seed: Optional[int] = Form(42, description="Seed for consistent comparison"),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Generate images using different models for comparison.
|
||||
|
||||
This endpoint generates the same prompt using different Stability AI models
|
||||
to help you compare quality and style differences.
|
||||
"""
|
||||
try:
|
||||
model_list = json.loads(models)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in models")
|
||||
|
||||
async with stability_service:
|
||||
results = {}
|
||||
|
||||
for model in model_list:
|
||||
try:
|
||||
if model == "ultra":
|
||||
result = await stability_service.generate_ultra(
|
||||
prompt=prompt, seed=seed, output_format="webp"
|
||||
)
|
||||
elif model == "core":
|
||||
result = await stability_service.generate_core(
|
||||
prompt=prompt, seed=seed, output_format="webp"
|
||||
)
|
||||
elif model.startswith("sd3"):
|
||||
result = await stability_service.generate_sd3(
|
||||
prompt=prompt, model=model, seed=seed, output_format="webp"
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(result, bytes):
|
||||
results[model] = {
|
||||
"status": "success",
|
||||
"image": base64.b64encode(result).decode(),
|
||||
"size": len(result)
|
||||
}
|
||||
else:
|
||||
results[model] = {"status": "async", "generation_id": result.get("id")}
|
||||
|
||||
except Exception as e:
|
||||
results[model] = {"status": "error", "error": str(e)}
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"seed": seed,
|
||||
"comparison_results": results,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== STYLE TRANSFER WORKFLOWS ====================
|
||||
|
||||
@router.post("/style/multi-style-transfer", summary="Multi-Style Transfer")
|
||||
async def multi_style_transfer(
|
||||
content_image: UploadFile = File(..., description="Content image"),
|
||||
style_images: List[UploadFile] = File(..., description="Multiple style reference images"),
|
||||
blend_weights: Optional[str] = Form(None, description="JSON array of blend weights"),
|
||||
output_format: Optional[str] = Form("png", description="Output format"),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Apply multiple styles to a single content image with blending.
|
||||
|
||||
This endpoint applies multiple style references to a content image,
|
||||
optionally with specified blend weights.
|
||||
"""
|
||||
weights = None
|
||||
if blend_weights:
|
||||
try:
|
||||
weights = json.loads(blend_weights)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in blend_weights")
|
||||
|
||||
if weights and len(weights) != len(style_images):
|
||||
raise HTTPException(status_code=400, detail="Number of weights must match number of style images")
|
||||
|
||||
async with stability_service:
|
||||
results = []
|
||||
|
||||
for i, style_image in enumerate(style_images):
|
||||
weight = weights[i] if weights else 1.0
|
||||
|
||||
result = await stability_service.control_style_transfer(
|
||||
init_image=content_image,
|
||||
style_image=style_image,
|
||||
style_strength=weight,
|
||||
output_format=output_format
|
||||
)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
results.append({
|
||||
"style_index": i,
|
||||
"weight": weight,
|
||||
"image": base64.b64encode(result).decode(),
|
||||
"size": len(result)
|
||||
})
|
||||
|
||||
# Reset content image file pointer for next iteration
|
||||
await content_image.seek(0)
|
||||
|
||||
return {
|
||||
"content_image": content_image.filename,
|
||||
"style_count": len(style_images),
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
# ==================== ANIMATION & SEQUENCE ENDPOINTS ====================
|
||||
|
||||
@router.post("/animation/image-sequence", summary="Generate Image Sequence")
|
||||
async def generate_image_sequence(
|
||||
base_prompt: str = Form(..., description="Base prompt for sequence"),
|
||||
sequence_prompts: str = Form(..., description="JSON array of sequence variations"),
|
||||
seed_start: Optional[int] = Form(42, description="Starting seed"),
|
||||
seed_increment: Optional[int] = Form(1, description="Seed increment per frame"),
|
||||
output_format: Optional[str] = Form("png", description="Output format"),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Generate a sequence of related images for animation or storytelling.
|
||||
|
||||
This endpoint generates a series of images with slight variations to create
|
||||
animation frames or story sequences.
|
||||
"""
|
||||
try:
|
||||
prompts = json.loads(sequence_prompts)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in sequence_prompts")
|
||||
|
||||
async with stability_service:
|
||||
sequence_results = []
|
||||
current_seed = seed_start
|
||||
|
||||
for i, variation in enumerate(prompts):
|
||||
full_prompt = f"{base_prompt}, {variation}"
|
||||
|
||||
result = await stability_service.generate_core(
|
||||
prompt=full_prompt,
|
||||
seed=current_seed,
|
||||
output_format=output_format
|
||||
)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
sequence_results.append({
|
||||
"frame": i + 1,
|
||||
"prompt": full_prompt,
|
||||
"seed": current_seed,
|
||||
"image": base64.b64encode(result).decode(),
|
||||
"size": len(result)
|
||||
})
|
||||
|
||||
current_seed += seed_increment
|
||||
|
||||
return {
|
||||
"base_prompt": base_prompt,
|
||||
"frame_count": len(sequence_results),
|
||||
"sequence": sequence_results
|
||||
}
|
||||
|
||||
|
||||
# ==================== QUALITY ANALYSIS ENDPOINTS ====================
|
||||
|
||||
@router.post("/analysis/generation-quality", summary="Analyze Generation Quality")
|
||||
async def analyze_generation_quality(
|
||||
image: UploadFile = File(..., description="Generated image to analyze"),
|
||||
original_prompt: str = Form(..., description="Original generation prompt"),
|
||||
model_used: str = Form(..., description="Model used for generation")
|
||||
):
|
||||
"""Analyze the quality and characteristics of a generated image.
|
||||
|
||||
This endpoint provides detailed analysis of generated images including
|
||||
quality metrics, style adherence, and improvement suggestions.
|
||||
"""
|
||||
from PIL import Image, ImageStat
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
content = await image.read()
|
||||
img = Image.open(io.BytesIO(content))
|
||||
|
||||
# Basic image statistics
|
||||
stat = ImageStat.Stat(img)
|
||||
|
||||
# Convert to RGB if needed for analysis
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
|
||||
# Calculate quality metrics
|
||||
img_array = np.array(img)
|
||||
|
||||
# Brightness analysis
|
||||
brightness = np.mean(img_array)
|
||||
|
||||
# Contrast analysis
|
||||
contrast = np.std(img_array)
|
||||
|
||||
# Color distribution
|
||||
color_channels = np.mean(img_array, axis=(0, 1))
|
||||
|
||||
# Sharpness estimation (using Laplacian variance)
|
||||
gray = img.convert('L')
|
||||
gray_array = np.array(gray)
|
||||
laplacian_var = np.var(np.gradient(gray_array))
|
||||
|
||||
quality_score = min(100, (contrast / 50) * (laplacian_var / 1000) * 100)
|
||||
|
||||
analysis = {
|
||||
"image_info": {
|
||||
"dimensions": f"{img.width}x{img.height}",
|
||||
"format": img.format,
|
||||
"mode": img.mode,
|
||||
"file_size": len(content)
|
||||
},
|
||||
"quality_metrics": {
|
||||
"overall_score": round(quality_score, 2),
|
||||
"brightness": round(brightness, 2),
|
||||
"contrast": round(contrast, 2),
|
||||
"sharpness": round(laplacian_var, 2)
|
||||
},
|
||||
"color_analysis": {
|
||||
"red_channel": round(float(color_channels[0]), 2),
|
||||
"green_channel": round(float(color_channels[1]), 2),
|
||||
"blue_channel": round(float(color_channels[2]), 2),
|
||||
"color_balance": "balanced" if max(color_channels) - min(color_channels) < 30 else "imbalanced"
|
||||
},
|
||||
"generation_info": {
|
||||
"original_prompt": original_prompt,
|
||||
"model_used": model_used,
|
||||
"analysis_timestamp": datetime.utcnow().isoformat()
|
||||
},
|
||||
"recommendations": _generate_quality_recommendations(quality_score, brightness, contrast)
|
||||
}
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Error analyzing image: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/analysis/prompt-optimization", summary="Optimize Text Prompts")
|
||||
async def optimize_prompt(
|
||||
prompt: str = Form(..., description="Original prompt to optimize"),
|
||||
target_style: Optional[str] = Form(None, description="Target style"),
|
||||
target_quality: Optional[str] = Form("high", description="Target quality level"),
|
||||
model: Optional[str] = Form("ultra", description="Target model"),
|
||||
include_negative: Optional[bool] = Form(True, description="Include negative prompt suggestions")
|
||||
):
|
||||
"""Analyze and optimize text prompts for better generation results.
|
||||
|
||||
This endpoint analyzes your prompt and provides suggestions for improvement
|
||||
based on best practices and model-specific optimizations.
|
||||
"""
|
||||
analysis = {
|
||||
"original_prompt": prompt,
|
||||
"prompt_length": len(prompt),
|
||||
"word_count": len(prompt.split()),
|
||||
"optimization_suggestions": []
|
||||
}
|
||||
|
||||
# Analyze prompt structure
|
||||
suggestions = []
|
||||
|
||||
# Check for style descriptors
|
||||
style_keywords = ["photorealistic", "digital art", "oil painting", "watercolor", "sketch"]
|
||||
has_style = any(keyword in prompt.lower() for keyword in style_keywords)
|
||||
if not has_style and target_style:
|
||||
suggestions.append(f"Add style descriptor: {target_style}")
|
||||
|
||||
# Check for quality enhancers
|
||||
quality_keywords = ["high quality", "detailed", "sharp", "crisp", "professional"]
|
||||
has_quality = any(keyword in prompt.lower() for keyword in quality_keywords)
|
||||
if not has_quality and target_quality == "high":
|
||||
suggestions.append("Add quality enhancers: 'high quality, detailed, sharp'")
|
||||
|
||||
# Check for composition elements
|
||||
composition_keywords = ["composition", "lighting", "perspective", "framing"]
|
||||
has_composition = any(keyword in prompt.lower() for keyword in composition_keywords)
|
||||
if not has_composition:
|
||||
suggestions.append("Consider adding composition details: lighting, perspective, framing")
|
||||
|
||||
# Model-specific optimizations
|
||||
if model == "ultra":
|
||||
suggestions.append("For Ultra model: Use detailed, specific descriptions")
|
||||
elif model == "core":
|
||||
suggestions.append("For Core model: Keep prompts concise but descriptive")
|
||||
|
||||
# Generate optimized prompt
|
||||
optimized_prompt = prompt
|
||||
if suggestions:
|
||||
optimized_prompt = _apply_prompt_optimizations(prompt, suggestions, target_style)
|
||||
|
||||
# Generate negative prompt suggestions
|
||||
negative_suggestions = []
|
||||
if include_negative:
|
||||
negative_suggestions = _generate_negative_prompt_suggestions(prompt, target_style)
|
||||
|
||||
analysis.update({
|
||||
"optimization_suggestions": suggestions,
|
||||
"optimized_prompt": optimized_prompt,
|
||||
"negative_prompt_suggestions": negative_suggestions,
|
||||
"estimated_improvement": len(suggestions) * 10, # Rough estimate
|
||||
"model_compatibility": _check_model_compatibility(optimized_prompt, model)
|
||||
})
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
# ==================== BATCH PROCESSING ENDPOINTS ====================
|
||||
|
||||
@router.post("/batch/process-folder", summary="Process Multiple Images")
|
||||
async def batch_process_folder(
|
||||
images: List[UploadFile] = File(..., description="Multiple images to process"),
|
||||
operation: str = Form(..., description="Operation to perform on all images"),
|
||||
operation_params: str = Form("{}", description="JSON parameters for operation"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""Process multiple images with the same operation in batch.
|
||||
|
||||
This endpoint allows you to apply the same operation to multiple images
|
||||
efficiently.
|
||||
"""
|
||||
try:
|
||||
params = json.loads(operation_params)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in operation_params")
|
||||
|
||||
# Validate operation
|
||||
supported_operations = [
|
||||
"upscale_fast", "remove_background", "erase", "generate_ultra", "generate_core"
|
||||
]
|
||||
if operation not in supported_operations:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported operation. Supported: {supported_operations}"
|
||||
)
|
||||
|
||||
# Start batch processing in background
|
||||
batch_id = f"batch_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
background_tasks.add_task(
|
||||
_process_batch_images,
|
||||
batch_id,
|
||||
images,
|
||||
operation,
|
||||
params,
|
||||
stability_service
|
||||
)
|
||||
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"status": "started",
|
||||
"image_count": len(images),
|
||||
"operation": operation,
|
||||
"estimated_completion": (datetime.utcnow() + timedelta(minutes=len(images) * 2)).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/batch/{batch_id}/status", summary="Get Batch Processing Status")
|
||||
async def get_batch_status(batch_id: str):
|
||||
"""Get the status of a batch processing operation.
|
||||
|
||||
Returns the current status and progress of a batch operation.
|
||||
"""
|
||||
# In a real implementation, you'd store batch status in a database
|
||||
# For now, return a mock response
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"status": "processing",
|
||||
"progress": {
|
||||
"completed": 2,
|
||||
"total": 5,
|
||||
"percentage": 40
|
||||
},
|
||||
"estimated_completion": (datetime.utcnow() + timedelta(minutes=5)).isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
async def _analyze_image(content: bytes) -> Dict[str, Any]:
|
||||
"""Analyze image characteristics."""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(io.BytesIO(content))
|
||||
total_pixels = img.width * img.height
|
||||
|
||||
return {
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
"total_pixels": total_pixels,
|
||||
"aspect_ratio": img.width / img.height,
|
||||
"format": img.format,
|
||||
"mode": img.mode,
|
||||
"is_low_res": total_pixels < 500000, # Less than 0.5MP
|
||||
"is_high_res": total_pixels > 2000000, # More than 2MP
|
||||
"needs_upscaling": total_pixels < 1000000 # Less than 1MP
|
||||
}
|
||||
|
||||
|
||||
def _determine_enhancement_strategy(img_info: Dict[str, Any], enhancement_type: str, target_resolution: str) -> Dict[str, Any]:
|
||||
"""Determine the best enhancement strategy based on image characteristics."""
|
||||
strategy = {"steps": []}
|
||||
|
||||
if enhancement_type == "auto":
|
||||
if img_info["is_low_res"]:
|
||||
if img_info["total_pixels"] < 100000: # Very low res
|
||||
strategy["steps"].append({
|
||||
"name": "Creative Upscale",
|
||||
"operation": "upscale_creative",
|
||||
"default_prompt": "high quality, detailed, sharp"
|
||||
})
|
||||
else:
|
||||
strategy["steps"].append({
|
||||
"name": "Conservative Upscale",
|
||||
"operation": "upscale_conservative",
|
||||
"default_prompt": "enhance quality, preserve details"
|
||||
})
|
||||
else:
|
||||
strategy["steps"].append({
|
||||
"name": "Fast Upscale",
|
||||
"operation": "upscale_fast",
|
||||
"default_prompt": ""
|
||||
})
|
||||
elif enhancement_type == "upscale":
|
||||
if target_resolution == "4k":
|
||||
strategy["steps"].append({
|
||||
"name": "Conservative Upscale to 4K",
|
||||
"operation": "upscale_conservative",
|
||||
"default_prompt": "4K resolution, high quality"
|
||||
})
|
||||
else:
|
||||
strategy["steps"].append({
|
||||
"name": "Fast Upscale",
|
||||
"operation": "upscale_fast",
|
||||
"default_prompt": ""
|
||||
})
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
def _bytes_to_upload_file(content: bytes, filename: str):
|
||||
"""Convert bytes to UploadFile-like object."""
|
||||
from fastapi import UploadFile
|
||||
from io import BytesIO
|
||||
|
||||
file_obj = BytesIO(content)
|
||||
file_obj.seek(0)
|
||||
|
||||
# Create a mock UploadFile
|
||||
class MockUploadFile:
|
||||
def __init__(self, file_obj, filename):
|
||||
self.file = file_obj
|
||||
self.filename = filename
|
||||
self.content_type = "image/png"
|
||||
|
||||
async def read(self):
|
||||
return self.file.read()
|
||||
|
||||
async def seek(self, position):
|
||||
self.file.seek(position)
|
||||
|
||||
return MockUploadFile(file_obj, filename)
|
||||
|
||||
|
||||
def _generate_quality_recommendations(quality_score: float, brightness: float, contrast: float) -> List[str]:
|
||||
"""Generate quality improvement recommendations."""
|
||||
recommendations = []
|
||||
|
||||
if quality_score < 50:
|
||||
recommendations.append("Consider using a higher quality model like Ultra")
|
||||
|
||||
if brightness < 100:
|
||||
recommendations.append("Image appears dark, consider adjusting lighting in prompt")
|
||||
elif brightness > 200:
|
||||
recommendations.append("Image appears bright, consider reducing exposure in prompt")
|
||||
|
||||
if contrast < 30:
|
||||
recommendations.append("Low contrast detected, add 'high contrast' to prompt")
|
||||
|
||||
if not recommendations:
|
||||
recommendations.append("Image quality looks good!")
|
||||
|
||||
return recommendations
|
||||
|
||||
|
||||
def _apply_prompt_optimizations(prompt: str, suggestions: List[str], target_style: Optional[str]) -> str:
|
||||
"""Apply optimization suggestions to prompt."""
|
||||
optimized = prompt
|
||||
|
||||
# Add style if suggested
|
||||
if target_style and f"Add style descriptor: {target_style}" in suggestions:
|
||||
optimized = f"{optimized}, {target_style} style"
|
||||
|
||||
# Add quality enhancers if suggested
|
||||
if any("quality enhancer" in s for s in suggestions):
|
||||
optimized = f"{optimized}, high quality, detailed, sharp"
|
||||
|
||||
return optimized.strip()
|
||||
|
||||
|
||||
def _generate_negative_prompt_suggestions(prompt: str, target_style: Optional[str]) -> List[str]:
|
||||
"""Generate negative prompt suggestions based on prompt analysis."""
|
||||
suggestions = []
|
||||
|
||||
# Common negative prompts
|
||||
suggestions.extend([
|
||||
"blurry, low quality, pixelated",
|
||||
"distorted, deformed, malformed",
|
||||
"oversaturated, undersaturated"
|
||||
])
|
||||
|
||||
# Style-specific negative prompts
|
||||
if target_style:
|
||||
if "photorealistic" in target_style.lower():
|
||||
suggestions.append("cartoon, anime, illustration")
|
||||
elif "anime" in target_style.lower():
|
||||
suggestions.append("realistic, photographic")
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
def _check_model_compatibility(prompt: str, model: str) -> Dict[str, Any]:
|
||||
"""Check prompt compatibility with specific models."""
|
||||
compatibility = {"score": 100, "notes": []}
|
||||
|
||||
if model == "ultra":
|
||||
if len(prompt.split()) < 5:
|
||||
compatibility["score"] -= 20
|
||||
compatibility["notes"].append("Ultra model works best with detailed prompts")
|
||||
elif model == "core":
|
||||
if len(prompt) > 500:
|
||||
compatibility["score"] -= 10
|
||||
compatibility["notes"].append("Core model works well with concise prompts")
|
||||
|
||||
return compatibility
|
||||
|
||||
|
||||
async def _process_batch_images(
|
||||
batch_id: str,
|
||||
images: List[UploadFile],
|
||||
operation: str,
|
||||
params: Dict[str, Any],
|
||||
stability_service: StabilityAIService
|
||||
):
|
||||
"""Background task for processing multiple images."""
|
||||
# In a real implementation, you'd store progress in a database
|
||||
# This is a simplified version for demonstration
|
||||
|
||||
async with stability_service:
|
||||
for i, image in enumerate(images):
|
||||
try:
|
||||
if operation == "upscale_fast":
|
||||
await stability_service.upscale_fast(image=image, **params)
|
||||
elif operation == "remove_background":
|
||||
await stability_service.remove_background(image=image, **params)
|
||||
# Add other operations as needed
|
||||
|
||||
# Log progress (in real implementation, update database)
|
||||
logger.info(f"Batch {batch_id}: Completed image {i+1}/{len(images)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {batch_id}: Error processing image {i+1}: {str(e)}")
|
||||
|
||||
|
||||
# ==================== EXPERIMENTAL ENDPOINTS ====================
|
||||
|
||||
@router.post("/experimental/ai-director", summary="AI Director Mode")
|
||||
async def ai_director_mode(
|
||||
concept: str = Form(..., description="High-level creative concept"),
|
||||
target_audience: Optional[str] = Form(None, description="Target audience"),
|
||||
mood: Optional[str] = Form(None, description="Desired mood"),
|
||||
color_palette: Optional[str] = Form(None, description="Preferred color palette"),
|
||||
iterations: Optional[int] = Form(3, description="Number of iterations"),
|
||||
stability_service: StabilityAIService = Depends(get_stability_service)
|
||||
):
|
||||
"""AI Director mode for automated creative decision making.
|
||||
|
||||
This experimental endpoint acts as an AI creative director, making
|
||||
intelligent decisions about style, composition, and execution based on
|
||||
high-level creative concepts.
|
||||
"""
|
||||
# Generate detailed prompts based on concept
|
||||
director_prompts = _generate_director_prompts(concept, target_audience, mood, color_palette)
|
||||
|
||||
async with stability_service:
|
||||
iterations_results = []
|
||||
|
||||
for i in range(iterations):
|
||||
prompt = director_prompts[i % len(director_prompts)]
|
||||
|
||||
result = await stability_service.generate_ultra(
|
||||
prompt=prompt,
|
||||
output_format="webp"
|
||||
)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
iterations_results.append({
|
||||
"iteration": i + 1,
|
||||
"prompt": prompt,
|
||||
"image": base64.b64encode(result).decode(),
|
||||
"size": len(result)
|
||||
})
|
||||
|
||||
return {
|
||||
"concept": concept,
|
||||
"director_analysis": {
|
||||
"target_audience": target_audience,
|
||||
"mood": mood,
|
||||
"color_palette": color_palette
|
||||
},
|
||||
"generated_prompts": director_prompts,
|
||||
"iterations": iterations_results
|
||||
}
|
||||
|
||||
|
||||
def _generate_director_prompts(concept: str, audience: Optional[str], mood: Optional[str], colors: Optional[str]) -> List[str]:
|
||||
"""Generate creative prompts based on director inputs."""
|
||||
base_prompt = concept
|
||||
|
||||
# Add audience-specific elements
|
||||
if audience:
|
||||
if "professional" in audience.lower():
|
||||
base_prompt += ", professional, clean, sophisticated"
|
||||
elif "creative" in audience.lower():
|
||||
base_prompt += ", artistic, innovative, expressive"
|
||||
elif "casual" in audience.lower():
|
||||
base_prompt += ", friendly, approachable, relaxed"
|
||||
|
||||
# Add mood elements
|
||||
if mood:
|
||||
base_prompt += f", {mood} mood"
|
||||
|
||||
# Add color palette
|
||||
if colors:
|
||||
base_prompt += f", {colors} color palette"
|
||||
|
||||
# Generate variations
|
||||
variations = [
|
||||
f"{base_prompt}, high quality, detailed",
|
||||
f"{base_prompt}, cinematic lighting, professional photography",
|
||||
f"{base_prompt}, artistic composition, creative perspective"
|
||||
]
|
||||
|
||||
return variations
|
||||
Reference in New Issue
Block a user