AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.
This commit is contained in:
@@ -30,6 +30,13 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 15,
|
||||
},
|
||||
"flux-kontext-pro": {
|
||||
"name": "FLUX Kontext Pro",
|
||||
"description": "Professional typography and text rendering with improved prompt adherence",
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 20,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,6 +184,55 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
|
||||
|
||||
def _generate_flux_kontext_pro(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using FLUX Kontext Pro.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[FLUX Kontext Pro] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed FLUX Kontext Pro API
|
||||
params = {
|
||||
"model": "flux-kontext-pro",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["flux-kontext-pro"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[FLUX Kontext Pro] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[FLUX Kontext Pro] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"FLUX Kontext Pro generation failed: {str(e)}")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
"""Generate image using WaveSpeed AI models.
|
||||
|
||||
@@ -201,6 +257,8 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
image_bytes = self._generate_ideogram_v3(options)
|
||||
elif model == "qwen-image":
|
||||
image_bytes = self._generate_qwen_image(options)
|
||||
elif model == "flux-kontext-pro":
|
||||
image_bytes = self._generate_flux_kontext_pro(options)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
|
||||
@@ -144,6 +144,9 @@ def generate_audio(
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
logger.info(f"[audio_gen] Filtered kwargs (removed None values): {filtered_kwargs}")
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
audio_bytes = client.generate_speech(
|
||||
text=text,
|
||||
@@ -155,8 +158,9 @@ def generate_audio(
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
**filtered_kwargs
|
||||
)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
|
||||
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes in {response_time:.2f}s")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -228,19 +232,29 @@ def generate_audio(
|
||||
# Create usage log
|
||||
# Store the text parameter in a local variable before any imports to prevent shadowing
|
||||
text_param = text # Capture function parameter before any potential shadowing
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Google, OpenAI, etc.)
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.AUDIO,
|
||||
model_name="minimax/speech-02-hd",
|
||||
endpoint="/audio-generation/wavespeed"
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
endpoint="/audio-generation/wavespeed",
|
||||
method="POST",
|
||||
model_used="minimax/speech-02-hd",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, etc.)
|
||||
tokens_input=character_count,
|
||||
tokens_output=0,
|
||||
tokens_total=character_count,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=len(text_param.encode("utf-8")), # Use captured parameter
|
||||
response_size=len(audio_bytes),
|
||||
|
||||
@@ -138,7 +138,8 @@ def _track_image_operation_usage(
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
@@ -165,6 +166,7 @@ def _track_image_operation_usage(
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
@@ -215,6 +217,13 @@ def _track_image_operation_usage(
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
@@ -223,13 +232,14 @@ def _track_image_operation_usage(
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
@@ -327,21 +337,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
|
||||
# Normalize obvious model/provider mismatches
|
||||
model_lower = (image_options.model or "").lower()
|
||||
|
||||
# Detect Wavespeed models and remap provider if needed
|
||||
wavespeed_models = ["qwen-image", "ideogram-v3-turbo", "flux-kontext-pro"]
|
||||
if model_lower in wavespeed_models and provider_name != "wavespeed":
|
||||
logger.info("Remapping provider to wavespeed for model=%s", image_options.model)
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# Detect HuggingFace models and remap provider if needed
|
||||
if provider_name == "stability" and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Remapping provider to huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
# Detect HuggingFace models when provider is not explicitly set
|
||||
if not opts.get("provider") and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Auto-detecting provider as huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
if provider_name == "huggingface" and not image_options.model:
|
||||
# Provide a sensible default HF model if none specified
|
||||
image_options.model = "black-forest-labs/FLUX.1-Krea-dev"
|
||||
|
||||
if provider_name == "wavespeed" and not image_options.model:
|
||||
# Provide a sensible default WaveSpeed model if none specified
|
||||
image_options.model = "ideogram-v3-turbo"
|
||||
# Default to cost-effective model: Qwen Image ($0.05/image, optimized for blog images)
|
||||
image_options.model = "qwen-image"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
start_time = time.time()
|
||||
result = provider.generate(image_options)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and result and result.image_bytes:
|
||||
@@ -352,12 +380,14 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
# Fallback: estimate based on provider/model (OSS-focused pricing)
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
estimated_cost = 0.05 # Qwen Image: $0.05/image
|
||||
elif result.model and "ideogram" in result.model.lower():
|
||||
estimated_cost = 0.10 # Ideogram V3 Turbo: $0.10/image
|
||||
else:
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
estimated_cost = 0.05 # Default to Qwen Image pricing
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
@@ -374,7 +404,8 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Generation]"
|
||||
log_prefix="[Image Generation]",
|
||||
response_time=response_time
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
@@ -27,6 +27,7 @@ except ImportError:
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
@@ -508,6 +509,11 @@ async def ai_video_generate(
|
||||
|
||||
# Generate video based on operation type
|
||||
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
|
||||
|
||||
# Track response time for video generation
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
@@ -620,6 +626,7 @@ async def ai_video_generate(
|
||||
|
||||
# Track usage (same pattern as text generation)
|
||||
# Use cost from result_dict if available, otherwise calculate
|
||||
response_time = time.time() - start_time
|
||||
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
@@ -628,6 +635,7 @@ async def ai_video_generate(
|
||||
prompt=result_dict.get("prompt", prompt or ""),
|
||||
video_bytes=video_bytes,
|
||||
cost_override=cost_override,
|
||||
response_time=response_time,
|
||||
)
|
||||
|
||||
# Progress callback: Complete
|
||||
@@ -662,6 +670,7 @@ def track_video_usage(
|
||||
prompt: str,
|
||||
video_bytes: bytes,
|
||||
cost_override: Optional[float] = None,
|
||||
response_time: float = 0.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Track subscription usage for any video generation (text-to-video or image-to-video).
|
||||
@@ -732,19 +741,27 @@ def track_video_usage(
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Detect actual provider name (WaveSpeed, HuggingFace, Google, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=APIProvider.VIDEO,
|
||||
model_name=model_name,
|
||||
endpoint=f"/video-generation/{provider}"
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.VIDEO,
|
||||
endpoint=f"/video-generation/{provider}",
|
||||
method="POST",
|
||||
model_used=model_name,
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, HuggingFace, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost_per_video,
|
||||
response_time=0.0,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=len((prompt or "").encode("utf-8")),
|
||||
response_size=len(video_bytes),
|
||||
|
||||
Reference in New Issue
Block a user