AI Image Studio Phase 1
This commit is contained in:
@@ -312,6 +312,175 @@ class WaveSpeedClient:
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image).
|
||||
|
||||
Args:
|
||||
model: Model to use ("ideogram-v3-turbo" or "qwen-image")
|
||||
prompt: Text prompt for image generation
|
||||
width: Image width (default: 1024)
|
||||
height: Image height (default: 1024)
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale for generation
|
||||
negative_prompt: Negative prompt (what to avoid)
|
||||
seed: Random seed for reproducibility
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 120)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes
|
||||
"""
|
||||
# Map model names to WaveSpeed API paths
|
||||
model_paths = {
|
||||
"ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
|
||||
"qwen-image": "wavespeed-ai/qwen-image/text-to-image",
|
||||
}
|
||||
|
||||
model_path = model_paths.get(model)
|
||||
if not model_path:
|
||||
raise ValueError(f"Unsupported image model: {model}. Supported: {list(model_paths.keys())}")
|
||||
|
||||
url = f"{self.BASE_URL}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if num_inference_steps is not None:
|
||||
payload["num_inference_steps"] = num_inference_steps
|
||||
if guidance_scale is not None:
|
||||
payload["guidance_scale"] = guidance_scale
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Add any extra parameters
|
||||
for key, value in kwargs.items():
|
||||
if key not in payload:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator returned no outputs",
|
||||
)
|
||||
|
||||
# Extract image URL from outputs
|
||||
image_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
logger.error(f"[WaveSpeed] Invalid image URL in outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch image bytes from URL
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
|
||||
# Async mode - poll for result
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs")
|
||||
|
||||
# Extract image URL and fetch
|
||||
image_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch image bytes
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
|
||||
Reference in New Issue
Block a user