Save local changes (GSC/Bing integrations) before merging PR #354
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from .image_generation import (
|
||||
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# User requested WaveSpeed as default provider
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
|
||||
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_with_provider: {e}")
|
||||
# Propagate specific error message if available
|
||||
error_detail = str(e)
|
||||
if "402" in error_detail or "Payment Required" in error_detail:
|
||||
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
"error": error_detail
|
||||
}
|
||||
|
||||
|
||||
import time
|
||||
from services.database import get_session_for_user
|
||||
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
|
||||
|
||||
def _collect_prompt_context(user_id: str) -> str:
    """Build the onboarding context text appended to the optimizer input.

    Looks up the user's ``OnboardingSession`` and, when one exists, reads the
    matching ``WebsiteAnalysis`` row and up to three ``CompetitorAnalysis``
    rows. A section header is only emitted when that section has at least one
    line of content, so an empty analysis never produces a bare header.

    This function is synchronous (blocking DB access) and is meant to be run
    through ``run_in_threadpool`` from async code.

    Args:
        user_id: Identifier used to open the DB session and find the
            onboarding session.

    Returns:
        The context text, or ``""`` when nothing usable was found. Lookup
        failures are logged as warnings and never raised; sections completed
        before the failure are still returned.
    """
    sections = []
    try:
        db_session = get_session_for_user(user_id)
        try:
            session = db_session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if not session:
                return ""

            # Step 2: Website Analysis
            website_analysis = db_session.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if website_analysis:
                brand_voice = website_analysis.brand_analysis
                style = website_analysis.style_guidelines
                target_audience = website_analysis.target_audience

                website_lines = []
                if target_audience:
                    website_lines.append(f"Target Audience: {target_audience}\n")
                # Only dict-shaped values are used; other shapes are skipped.
                if brand_voice and isinstance(brand_voice, dict):
                    website_lines.append(
                        f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"
                    )
                if style and isinstance(style, dict):
                    website_lines.append(
                        f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"
                    )
                if website_lines:
                    sections.append(
                        "\n\nCONTEXT FROM WEBSITE ANALYSIS:\n" + "".join(website_lines)
                    )

            # Step 3: Competitor Analysis (limit to 3 rows)
            competitors = db_session.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).limit(3).all()

            competitor_lines = []
            for comp in competitors:
                data = comp.analysis_data
                if not (data and isinstance(data, dict)):
                    continue
                highlights = data.get('highlights', [])
                if highlights:
                    comp_title = data.get('title', 'Competitor')
                    # str() each item so a non-string highlight cannot make
                    # join() raise and discard the whole competitor section.
                    joined = ', '.join(str(item) for item in highlights[:2])
                    competitor_lines.append(f"- {comp_title}: {joined}\n")
            if competitor_lines:
                sections.append(
                    "\nCOMPETITOR VISUAL INSIGHTS:\n" + "".join(competitor_lines)
                )
        finally:
            db_session.close()
    except Exception as db_ex:
        # Context is best-effort: enhancement continues without it.
        logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")

    return "".join(sections)


async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
    """
    Enhance an image prompt using WaveSpeed's prompt optimizer.

    When ``user_id`` is given, the operation is validated first, onboarding
    context (website analysis and competitor analysis) is appended to the
    optimizer input when available, and the operation is recorded through
    ``_track_image_operation_usage`` with a cost of 0.0.

    Blocking work (DB lookups and the synchronous optimizer call) is executed
    via ``run_in_threadpool`` so the event loop is not stalled.

    Args:
        prompt: The original prompt text.
        user_id: Optional user identifier; enables validation, context
            lookup, and usage tracking.

    Returns:
        The optimized prompt, or the original ``prompt`` when any step fails
        or the optimizer returns an empty result. This function does not
        raise.
    """
    # Monotonic clock: the value is only used to measure elapsed time.
    start_time = time.monotonic()
    try:
        # Imported lazily, as in the original implementation.
        from services.wavespeed.client import WaveSpeedClient

        # 1. Pre-flight Validation
        # NOTE(review): any exception raised here is caught by the handler
        # below and results in the original prompt being returned - confirm
        # that silently skipping enhancement is the intended outcome.
        if user_id:
            _validate_image_operation(
                user_id=user_id,
                operation_type="prompt-enhancement",
                num_operations=1,
                log_prefix="[Prompt Enhancement]"
            )

        # 2. Fetch Context from Step 2 & 3
        context_instruction = ""
        if user_id:
            context_instruction = await run_in_threadpool(
                _collect_prompt_context, user_id
            )

        # Combine prompt with context
        full_input_text = prompt
        if context_instruction:
            logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
            # The context is appended as an instruction for the optimizer.
            full_input_text = f"Original Request: {prompt}\n\n{context_instruction}\n\nTask: Generate a hyper-personalized, detailed image generation prompt based on the Original Request and the provided Context. Ensure the visual style aligns with the Brand Voice and Visual Style."
        else:
            logger.info(f"Enhancing prompt for user {user_id} (no context found)")

        # 3. Call WaveSpeed ('image' mode, 'photographic' style)
        client = WaveSpeedClient()
        optimized_prompt = await run_in_threadpool(
            client.optimize_prompt,
            text=full_input_text,
            mode="image",
            style="photographic",
            enable_sync_mode=True,
            timeout=30
        )

        # 4. Track Usage
        if user_id:
            duration = time.monotonic() - start_time
            try:
                # Tracked as a zero-cost operation.
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="wavespeed",
                    model="wavespeed-prompt-opt",
                    operation_type="prompt-enhancement",
                    result_bytes=b"",  # No image
                    cost=0.0,
                    prompt=prompt,
                    endpoint="/enhance-prompt",
                    metadata={"duration": duration, "context_added": bool(context_instruction)},
                    log_prefix="[Prompt Enhancement]",
                    response_time=duration
                )
            except Exception as track_ex:
                # A bookkeeping failure must not discard a successful result.
                logger.warning(f"Failed to track prompt enhancement usage: {track_ex}")

        # Honour the declared ``str`` return type if the optimizer yields
        # nothing usable.
        return optimized_prompt or prompt

    except Exception as e:
        logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
        # Fallback to original prompt on failure
        return prompt
|
||||
|
||||
|
||||
async def generate_image_variation(
|
||||
@@ -760,13 +886,123 @@ async def generate_image_variation(
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate variation of an existing image.
|
||||
Placeholder implementation.
|
||||
Generate variation of an existing image using image-to-image editing.
|
||||
Wrapper for step4_asset_routes.
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Not implemented yet"
|
||||
}
|
||||
try:
|
||||
# Handle image input (bytes, file, or base64)
|
||||
image_bytes = None
|
||||
if isinstance(image, bytes):
|
||||
image_bytes = image
|
||||
elif hasattr(image, "read"):
|
||||
image_bytes = await image.read()
|
||||
elif isinstance(image, str):
|
||||
# Assume base64 or path
|
||||
if os.path.exists(image):
|
||||
with open(image, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
else:
|
||||
# Try base64 decode
|
||||
try:
|
||||
if "base64," in image:
|
||||
image = image.split("base64,")[1]
|
||||
image_bytes = base64.b64decode(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not image_bytes:
|
||||
return {"success": False, "error": "Invalid image input"}
|
||||
|
||||
# Convert to base64 for internal function
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
# Use generate_image_edit with "variation" intent
|
||||
# For variation, we typically use general_edit with specific prompt
|
||||
result = await run_in_threadpool(
|
||||
generate_image_edit,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation="general_edit",
|
||||
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
|
||||
options=kwargs,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_base64": result_base64,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_variation: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def _coerce_image_to_bytes(image: Any) -> Optional[bytes]:
    """Normalise a supported image input into raw bytes.

    Accepted inputs:
        * ``bytes`` / ``bytearray`` - used as-is.
        * any object with a ``read`` method - ``read()`` is called, and its
          result is awaited only when it is awaitable, so both async upload
          objects and plain file-like objects work.
        * ``str`` - treated as a filesystem path when that path exists,
          otherwise as base64 text (optionally prefixed by a data-URI header
          containing ``base64,``).

    Returns:
        The image bytes, or ``None`` when the input type is unsupported or
        the string is not decodable base64.
    """
    import inspect

    if isinstance(image, (bytes, bytearray)):
        return bytes(image)

    if hasattr(image, "read"):
        data = image.read()
        if inspect.isawaitable(data):
            data = await data
        return data

    if isinstance(image, str):
        # NOTE(review): if this string can come from an untrusted caller,
        # opening it as a path allows reading arbitrary local files -
        # confirm callers only pass trusted paths.
        if os.path.exists(image):
            with open(image, "rb") as f:
                return f.read()

        payload = image.split("base64,")[1] if "base64," in image else image
        try:
            return base64.b64decode(payload)
        except ValueError:
            # binascii.Error (bad padding/characters) is a ValueError
            # subclass; non-ASCII text raises ValueError directly.
            return None

    return None


async def generate_image_enhance(
    image: Any,
    user_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Enhance/upscale an existing image.

    Wrapper for step4_asset_routes. The image is normalised to bytes,
    base64-encoded, and passed to ``generate_image_edit`` (run in a thread
    pool) as a ``general_edit`` operation with a fixed enhancement prompt,
    the ``nano-banana-pro-edit-ultra`` model, and ``resolution`` forced to
    ``"4k"``.

    Args:
        image: Raw bytes, a file-like object (sync or async ``read``), a
            filesystem path, or a base64 / data-URI string.
        user_id: Optional user identifier forwarded to ``generate_image_edit``.
        **kwargs: Extra options forwarded as ``options``; a caller-supplied
            ``resolution`` is overridden with ``"4k"``.

    Returns:
        ``{"success": True, "image_base64": ..., "metadata": ...}`` on
        success, otherwise ``{"success": False, "error": <message>}``.
        Exceptions are logged and reported in the result, not raised.
    """
    try:
        image_bytes = await _coerce_image_to_bytes(image)
        if not image_bytes:
            return {"success": False, "error": "Invalid image input"}

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        # generate_image_edit is called synchronously, so it is offloaded to
        # a worker thread to keep the event loop free.
        result = await run_in_threadpool(
            generate_image_edit,
            image_base64=image_base64,
            prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
            operation="general_edit",
            model="nano-banana-pro-edit-ultra",
            options={**kwargs, "resolution": "4k"},
            user_id=user_id
        )

        result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')

        return {
            "success": True,
            "image_base64": result_base64,
            "metadata": result.metadata
        }

    except Exception as e:
        logger.error(f"Error in generate_image_enhance: {e}")
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user