AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import Any, Dict, Literal, Optional
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
|
||||
from services.llm_providers.main_image_generation import generate_image_edit
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -213,6 +214,249 @@ class EditStudioService:
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
    """Return the supported-operations mapping used for UI rendering.

    Returns:
        The ``SUPPORTED_OPERATIONS`` mapping itself (not a copy).
    """
    supported_operations = self.SUPPORTED_OPERATIONS
    return supported_operations
|
||||
|
||||
def get_available_models(
    self,
    operation: Optional[str] = None,
    tier: Optional[str] = None,
) -> Dict[str, Any]:
    """List WaveSpeed editing models, optionally narrowed by operation and tier.

    Args:
        operation: When given, keep only the models the provider returns
            for this operation type (e.g., "general_edit").
        tier: When given, keep only the models the provider returns for
            this tier ("budget", "mid", "premium").

    Returns:
        Dictionary with "models" (one summary dict per model) and "total"
        (the number of summaries).
    """
    from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

    edit_provider = WaveSpeedEditProvider()
    catalog = edit_provider.get_available_models()

    # Each filter intersects the catalog with the provider's own lookup,
    # so the two filters compose when both are supplied.
    if operation:
        for_operation = edit_provider.get_models_by_operation(operation)
        catalog = {key: info for key, info in catalog.items() if key in for_operation}

    if tier:
        for_tier = edit_provider.get_models_by_tier(tier)
        catalog = {key: info for key, info in catalog.items() if key in for_tier}

    def summarize(model_id: str, info: Dict[str, Any]) -> Dict[str, Any]:
        """Build the API-facing summary of a single model."""
        return {
            "id": model_id,
            "name": info.get("name", model_id),
            "description": info.get("description", ""),
            "cost": info.get("cost", 0.02),
            "cost_8k": info.get("cost_8k"),  # Optional; None when absent
            "tier": info.get("tier", "mid"),
            "max_resolution": info.get("max_resolution", [2048, 2048]),
            "capabilities": info.get("capabilities", []),
            "use_cases": self._get_use_cases_for_model(model_id, info),
            "features": self._get_features_for_model(info),
            "supports_multi_image": info.get("supports_multi_image", False),
            "supports_controlnet": info.get("supports_controlnet", False),
            "languages": info.get("languages", ["en"]),
        }

    summaries = [summarize(key, info) for key, info in catalog.items()]

    return {
        "models": summaries,
        "total": len(summaries),
    }
|
||||
|
||||
def recommend_model(
    self,
    operation: str,
    image_resolution: Optional[Dict[str, int]] = None,
    user_tier: Optional[str] = None,
    preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Recommend the best WaveSpeed editing model for an operation and context.

    Candidates are the models the provider returns for ``operation``; when
    that lookup is empty, every available model is considered. Candidates are
    then filtered by resolution (if given), scored on cost, maximum
    resolution and tier, and sorted by score.

    Args:
        operation: Operation type (e.g., "general_edit").
        image_resolution: Dict with "width" and "height"; models whose
            largest supported dimension is smaller are excluded.
        user_tier: User subscription tier ("free", "pro", "enterprise");
            adds a score bonus to matching model tiers.
        preferences: Dict with "prioritize_cost" and/or "prioritize_quality"
            flags that change the weight of the cost / resolution terms.

    Returns:
        Dictionary with "recommended_model" (model id), "reason" (text) and
        "alternatives" (up to three dicts with "model_id", "name", "cost",
        "reason").

    Raises:
        ValueError: If the provider has no models at all.
    """
    from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

    default_cost = 0.02
    # Floor applied only while scoring: a free (0) or negative cost would
    # otherwise raise ZeroDivisionError or produce a negative score.
    min_scoring_cost = 0.001

    def cost_of(info: Dict[str, Any]) -> float:
        """Return the listed cost, or the default when it is missing/None."""
        listed = info.get("cost")
        return default_cost if listed is None else listed

    provider = WaveSpeedEditProvider()
    available_models = provider.get_models_by_operation(operation)

    if not available_models:
        # Fallback to all models if operation doesn't match
        available_models = provider.get_available_models()

    # Keep only models that can handle the image's largest dimension
    if image_resolution:
        width = image_resolution.get("width", 0)
        height = image_resolution.get("height", 0)
        max_dimension = max(width, height)

        filtered = {}
        for model_id, model_info in available_models.items():
            max_res = model_info.get("max_resolution", (2048, 2048))
            if max_dimension <= max(max_res[0], max_res[1]):
                filtered[model_id] = model_info
        available_models = filtered

    if not available_models:
        # Nothing survived filtering: fall back to the first model the
        # provider lists, without alternatives.
        all_models = provider.get_available_models()
        if not all_models:
            raise ValueError("No models available")
        return {
            "recommended_model": next(iter(all_models)),
            "reason": "No specific match found, using default model",
            "alternatives": [],
        }

    prioritize_cost = bool(preferences and preferences.get("prioritize_cost", False))
    prioritize_quality = bool(preferences and preferences.get("prioritize_quality", False))

    # Weights: cost and resolution terms count double when prioritized
    cost_weight = 100 if prioritize_cost else 50
    resolution_divisor = 10 if prioritize_quality else 20

    # Bonus per model tier, keyed by the user's subscription tier
    if user_tier == "free":
        tier_bonus = {"budget": 50, "mid": 20}
    elif user_tier in ("pro", "enterprise"):
        tier_bonus = {"premium": 50, "mid": 30}
    else:
        tier_bonus = {}

    scored_models = []
    for model_id, model_info in available_models.items():
        max_res = model_info.get("max_resolution", (2048, 2048))
        max_resolution = max(max_res[0], max_res[1])

        # Cheaper models score higher (inverse cost, floored to stay finite)
        score = (1.0 / max(cost_of(model_info), min_scoring_cost)) * cost_weight
        # Higher maximum resolution is used as the quality signal
        score += max_resolution / resolution_divisor
        score += tier_bonus.get(model_info.get("tier", "mid"), 0)

        scored_models.append((model_id, model_info, score))

    # Highest score first; the sort is stable, so ties keep provider order
    scored_models.sort(key=lambda entry: entry[2], reverse=True)

    recommended_id, recommended_info, _ = scored_models[0]
    recommended_cost = cost_of(recommended_info)

    # Build reason
    reasons = []
    if prioritize_cost:
        reasons.append("Lowest cost option")
    if prioritize_quality:
        reasons.append("Best quality")
    if image_resolution:
        reasons.append(f"Supports {image_resolution.get('width')}×{image_resolution.get('height')} resolution")
    if user_tier == "free" and recommended_info.get("tier") == "budget":
        reasons.append("Budget-friendly for free tier")

    reason = ", ".join(reasons) if reasons else "Best match for your requirements"

    # Up to three runners-up, each labelled relative to the recommendation.
    # The same cost default is used for comparison and reporting so a model
    # without a listed cost is not compared as if it were free.
    alternatives = []
    for model_id, model_info, _ in scored_models[1:4]:
        alt_cost = cost_of(model_info)
        alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
        if alt_cost < recommended_cost:
            alt_reason += ", lower cost"
        elif alt_cost > recommended_cost:
            alt_reason += ", higher quality"
        alternatives.append({
            "model_id": model_id,
            "name": model_info.get("name", model_id),
            "cost": alt_cost,
            "reason": alt_reason,
        })

    return {
        "recommended_model": recommended_id,
        "reason": reason,
        "alternatives": alternatives,
    }
|
||||
|
||||
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
|
||||
"""Get use cases for a model based on its capabilities."""
|
||||
use_cases_map = {
|
||||
"general_edit": ["Quick edits", "Style changes", "Background replacement"],
|
||||
"style_transfer": ["Apply artistic styles", "Style transformations"],
|
||||
"text_edit": ["Add text to images", "Edit text in images"],
|
||||
"multi_image": ["Batch editing", "Consistent character work"],
|
||||
"high_res": ["Professional work", "Print materials", "4K/8K editing"],
|
||||
"professional": ["Marketing campaigns", "Brand assets"],
|
||||
"typography": ["Text-heavy edits", "Typography generation"],
|
||||
"portrait_retouching": ["Portrait edits", "Beauty retouching"],
|
||||
"fashion_edit": ["Fashion photography", "Outfit changes"],
|
||||
"product_edit": ["E-commerce", "Product photography"],
|
||||
}
|
||||
|
||||
capabilities = model_info.get("capabilities", [])
|
||||
use_cases = []
|
||||
for cap in capabilities:
|
||||
if cap in use_cases_map:
|
||||
use_cases.extend(use_cases_map[cap])
|
||||
|
||||
# Remove duplicates
|
||||
return list(set(use_cases)) if use_cases else ["General image editing"]
|
||||
|
||||
def _get_features_for_model(self, model_info: Dict[str, Any]) -> list:
|
||||
"""Get feature list for a model."""
|
||||
features = []
|
||||
|
||||
if model_info.get("supports_multi_image"):
|
||||
max_images = model_info.get("api_params", {}).get("max_images", 0)
|
||||
if max_images:
|
||||
features.append(f"Multi-image ({max_images} images)")
|
||||
else:
|
||||
features.append("Multi-image support")
|
||||
|
||||
if model_info.get("supports_controlnet"):
|
||||
features.append("ControlNet support")
|
||||
|
||||
languages = model_info.get("languages", [])
|
||||
if len(languages) > 1:
|
||||
features.append(f"Multilingual ({', '.join(languages)})")
|
||||
elif "multilingual" in languages:
|
||||
features.append("Multilingual support")
|
||||
|
||||
max_res = model_info.get("max_resolution", (2048, 2048))
|
||||
if max(max_res) >= 4096:
|
||||
features.append("4K/8K support")
|
||||
elif max(max_res) >= 2048:
|
||||
features.append("2K support")
|
||||
|
||||
api_params = model_info.get("api_params", {})
|
||||
if api_params.get("supports_guidance_scale"):
|
||||
features.append("Guidance scale control")
|
||||
|
||||
return features if features else ["Standard editing"]
|
||||
|
||||
async def process_edit(
|
||||
self,
|
||||
@@ -221,6 +465,9 @@ class EditStudioService:
|
||||
) -> Dict[str, Any]:
|
||||
"""Process edit request and return normalized response."""
|
||||
|
||||
# Pre-flight validation: Use specific validator for editing operations
|
||||
# Note: Editing uses validate_image_editing_operations (different from generation)
|
||||
# This is intentional as editing may have different subscription limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
@@ -386,29 +633,109 @@ class EditStudioService:
|
||||
mask_bytes: Optional[bytes],
|
||||
user_id: Optional[str],
|
||||
) -> bytes:
|
||||
"""Execute Hugging Face powered general editing (synchronous API)."""
|
||||
"""Execute general editing - routes to WaveSpeed (unified entry) or HuggingFace (legacy).
|
||||
|
||||
If model is a WaveSpeed model (qwen-edit-plus, nano-banana-pro-edit-ultra, seedream-v4.5-edit),
|
||||
uses unified entry point. Otherwise falls back to HuggingFace for backward compatibility.
|
||||
"""
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for general edits")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
mask_bytes, # Optional mask for selective editing
|
||||
# Check if model is a WaveSpeed editing model
|
||||
from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
provider = WaveSpeedEditProvider()
|
||||
wavespeed_models = set(provider.get_available_models().keys())
|
||||
|
||||
# Also check if provider is explicitly set to "wavespeed"
|
||||
is_wavespeed = (
|
||||
request.provider == "wavespeed" or
|
||||
(request.model and request.model in wavespeed_models)
|
||||
)
|
||||
|
||||
# Auto-detect: If no model specified and operation is general_edit, recommend one
|
||||
if not request.model and not is_wavespeed and request.operation == "general_edit":
|
||||
# Auto-select recommended model
|
||||
try:
|
||||
# Get image dimensions for recommendation
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
image_resolution = {"width": img.width, "height": img.height}
|
||||
|
||||
recommendation = self.recommend_model(
|
||||
operation=request.operation,
|
||||
image_resolution=image_resolution,
|
||||
preferences={"prioritize_cost": True}, # Default to cost-optimized
|
||||
)
|
||||
recommended_model = recommendation.get("recommended_model")
|
||||
if recommended_model and recommended_model in wavespeed_models:
|
||||
logger.info(f"[Edit Studio] Auto-selected model: {recommended_model} (reason: {recommendation.get('reason')})")
|
||||
request.model = recommended_model
|
||||
is_wavespeed = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[Edit Studio] Auto-detection failed: {e}, falling back to HuggingFace")
|
||||
|
||||
if is_wavespeed:
|
||||
# Use unified entry point for WaveSpeed models
|
||||
logger.info(f"[Edit Studio] Using WaveSpeed unified entry for model={request.model}")
|
||||
|
||||
# Convert image bytes to base64
|
||||
import base64
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
# Prepare options for unified entry point
|
||||
edit_options = {
|
||||
"mask_base64": None,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"width": None, # Will be determined from image if needed
|
||||
"height": None,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# Add mask if provided
|
||||
if mask_bytes:
|
||||
edit_options["mask_base64"] = base64.b64encode(mask_bytes).decode("utf-8")
|
||||
|
||||
# Extract dimensions from image if needed
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
edit_options["width"] = img.width
|
||||
edit_options["height"] = img.height
|
||||
|
||||
# Call unified entry point (synchronous, so run in thread)
|
||||
result = await asyncio.to_thread(
|
||||
generate_image_edit,
|
||||
image_base64=image_base64,
|
||||
prompt=request.prompt,
|
||||
operation=request.operation or "general_edit",
|
||||
model=request.model, # Will auto-select if None
|
||||
options=edit_options,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
else:
|
||||
# Fall back to HuggingFace for backward compatibility
|
||||
logger.info("[Edit Studio] Using HuggingFace (legacy) for general edit")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
return result.image_bytes
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
mask_bytes, # Optional mask for selective editing
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user