AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -1,17 +1,10 @@
|
||||
"""Create Studio service for AI-powered image generation."""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List, Literal
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm_providers.image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.image_generation import ImageGenerationResult
|
||||
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -75,29 +68,8 @@ class CreateStudioService:
|
||||
self.template_manager = TemplateManager()
|
||||
logger.info("[Create Studio] Initialized with template manager")
|
||||
|
||||
def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get provider instance by name.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider
|
||||
api_key: Optional API key (uses env vars if not provided)
|
||||
|
||||
Returns:
|
||||
Provider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
|
||||
elif provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
|
||||
elif provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
|
||||
elif provider_name == "gemini":
|
||||
return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider_name}")
|
||||
# Removed _get_provider_instance() - now using unified entry point
|
||||
# Provider selection is handled by main_image_generation.generate_image()
|
||||
|
||||
def _select_provider_and_model(
|
||||
self,
|
||||
@@ -289,30 +261,17 @@ class CreateStudioService:
|
||||
logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
|
||||
request.prompt[:100], request.template_id)
|
||||
|
||||
# Pre-flight validation: Check subscription and usage limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=request.num_variations
|
||||
)
|
||||
logger.info(f"[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
|
||||
except HTTPException as http_ex:
|
||||
logger.error(f"[Create Studio] ❌ Pre-flight validation failed - blocking generation")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
# Pre-flight validation: Reuse unified helper
|
||||
# Note: Validation for num_variations will be done per-image in generate_image()
|
||||
# We validate once upfront to fail fast if user has no credits
|
||||
if user_id and request.num_variations > 0:
|
||||
from services.llm_providers.main_image_generation import _validate_image_operation
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="create-studio-generation",
|
||||
num_operations=request.num_variations,
|
||||
log_prefix="[Create Studio]"
|
||||
)
|
||||
|
||||
# Load template if specified
|
||||
template = None
|
||||
@@ -337,36 +296,37 @@ class CreateStudioService:
|
||||
# Select provider and model
|
||||
provider_name, model = self._select_provider_and_model(request, template)
|
||||
|
||||
# Get provider instance
|
||||
try:
|
||||
provider = self._get_provider_instance(provider_name)
|
||||
except Exception as e:
|
||||
logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
|
||||
provider_name, str(e))
|
||||
raise RuntimeError(f"Provider initialization failed: {str(e)}")
|
||||
|
||||
# Generate images
|
||||
# Generate images using unified entry point
|
||||
# This ensures consistent validation, tracking, and error handling
|
||||
results = []
|
||||
for i in range(request.num_variations):
|
||||
logger.info("[Create Studio] Generating variation %d/%d",
|
||||
i + 1, request.num_variations)
|
||||
|
||||
try:
|
||||
# Prepare options
|
||||
options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed + i if request.seed else None,
|
||||
model=model,
|
||||
extra={"style_preset": request.style_preset} if request.style_preset else {}
|
||||
)
|
||||
# Prepare options for unified entry point
|
||||
options = {
|
||||
"provider": provider_name,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed + i if request.seed else None,
|
||||
}
|
||||
|
||||
# Generate image
|
||||
result: ImageGenerationResult = provider.generate(options)
|
||||
# Add style preset to extra if specified
|
||||
if request.style_preset:
|
||||
options["extra"] = {"style_preset": request.style_preset}
|
||||
|
||||
# Generate image using unified entry point
|
||||
# This handles validation, provider selection, generation, and tracking automatically
|
||||
result: ImageGenerationResult = generate_image(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
results.append({
|
||||
"image_bytes": result.image_bytes,
|
||||
|
||||
Reference in New Issue
Block a user