AI podcast project
This commit is contained in:
@@ -137,6 +137,9 @@ def generate_audio(
|
||||
|
||||
# Generate audio using WaveSpeed
|
||||
try:
|
||||
# Avoid passing duplicate enable_sync_mode; allow override via kwargs
|
||||
enable_sync_mode = kwargs.pop("enable_sync_mode", True)
|
||||
|
||||
client = WaveSpeedClient()
|
||||
audio_bytes = client.generate_speech(
|
||||
text=text,
|
||||
@@ -145,7 +148,7 @@ def generate_audio(
|
||||
volume=volume,
|
||||
pitch=pitch,
|
||||
emotion=emotion,
|
||||
enable_sync_mode=True,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
11
backend/services/podcast/__init__.py
Normal file
11
backend/services/podcast/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Podcast Services Module
|
||||
|
||||
Dedicated services for podcast generation functionality.
|
||||
Separate from story writer services to maintain clear separation of concerns.
|
||||
"""
|
||||
|
||||
from .video_combination_service import PodcastVideoCombinationService
|
||||
|
||||
__all__ = ["PodcastVideoCombinationService"]
|
||||
|
||||
382
backend/services/podcast/video_combination_service.py
Normal file
382
backend/services/podcast/video_combination_service.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
Podcast Video Combination Service
|
||||
|
||||
Dedicated service for combining podcast scene videos into final episodes.
|
||||
Separate from StoryVideoGenerationService to avoid breaking story writer functionality.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
import time
|
||||
import threading
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class PodcastVideoCombinationService:
|
||||
"""Service for combining podcast scene videos into final episodes."""
|
||||
|
||||
def __init__(self, output_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the podcast video combination service.
|
||||
|
||||
Parameters:
|
||||
output_dir (str, optional): Directory to save combined videos.
|
||||
Defaults to 'backend/podcast_videos/Final_Videos' if not provided.
|
||||
"""
|
||||
if output_dir:
|
||||
self.output_dir = Path(output_dir)
|
||||
else:
|
||||
# Default to podcast_videos/Final_Videos directory
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
self.output_dir = base_dir / "podcast_videos" / "Final_Videos"
|
||||
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"[PodcastVideoCombination] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def combine_videos(
|
||||
self,
|
||||
video_paths: List[str],
|
||||
podcast_title: str,
|
||||
fps: int = 30,
|
||||
progress_callback: Optional[callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Combine multiple video files into a single final podcast video.
|
||||
|
||||
This method is specifically designed for podcast videos that already have
|
||||
embedded audio. It does not require separate audio files.
|
||||
|
||||
Parameters:
|
||||
video_paths (List[str]): List of video file paths to combine.
|
||||
podcast_title (str): Title of the podcast episode.
|
||||
fps (int): Frames per second for output video (default: 30).
|
||||
progress_callback (callable, optional): Callback function for progress updates.
|
||||
Signature: callback(progress: float, message: str)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Video metadata including file path, URL, duration, and file size.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid video files are provided.
|
||||
RuntimeError: If video combination fails.
|
||||
"""
|
||||
if not video_paths:
|
||||
raise ValueError("No video paths provided")
|
||||
|
||||
# Validate all video files exist
|
||||
valid_video_paths = []
|
||||
for video_path in video_paths:
|
||||
path = Path(video_path)
|
||||
if path.exists() and path.is_file():
|
||||
valid_video_paths.append(str(path))
|
||||
else:
|
||||
logger.warning(f"[PodcastVideoCombination] Video not found: {video_path}")
|
||||
|
||||
if not valid_video_paths:
|
||||
raise ValueError("No valid video files found to combine")
|
||||
|
||||
logger.info(f"[PodcastVideoCombination] Combining {len(valid_video_paths)} videos")
|
||||
|
||||
try:
|
||||
# Import MoviePy
|
||||
try:
|
||||
from moviepy import VideoFileClip, concatenate_videoclips
|
||||
except Exception as e:
|
||||
logger.error(f"[PodcastVideoCombination] MoviePy not installed: {e}")
|
||||
raise RuntimeError("MoviePy is not installed. Please install it to combine videos.")
|
||||
|
||||
# Suppress MoviePy warnings about incomplete frames (common with some video encodings)
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="moviepy")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(10.0, "Loading video clips...")
|
||||
|
||||
# Load all video clips
|
||||
video_clips = []
|
||||
total_duration = 0.0
|
||||
|
||||
for idx, video_path in enumerate(valid_video_paths):
|
||||
try:
|
||||
logger.info(f"[PodcastVideoCombination] Loading video {idx + 1}/{len(valid_video_paths)}: {video_path}")
|
||||
|
||||
# Load video clip with error handling for incomplete files
|
||||
# MoviePy will use the last valid frame if frames are missing at the end
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
video_clip = VideoFileClip(str(video_path))
|
||||
|
||||
# Validate clip was loaded successfully
|
||||
if video_clip.duration <= 0:
|
||||
logger.warning(f"[PodcastVideoCombination] Video {video_path} has invalid duration, skipping")
|
||||
video_clip.close()
|
||||
continue
|
||||
|
||||
# Videos already have embedded audio, no need to replace
|
||||
video_clips.append(video_clip)
|
||||
total_duration += video_clip.duration
|
||||
|
||||
if progress_callback:
|
||||
progress = 10.0 + ((idx + 1) / len(valid_video_paths)) * 60.0
|
||||
progress_callback(progress, f"Loaded video {idx + 1}/{len(valid_video_paths)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PodcastVideoCombination] Failed to load video {video_path}: {e}")
|
||||
# Continue with other videos instead of failing completely
|
||||
continue
|
||||
|
||||
if not video_clips:
|
||||
raise RuntimeError("No valid video clips were loaded")
|
||||
|
||||
logger.info(f"[PodcastVideoCombination] Loaded {len(video_clips)} clips, total duration: {total_duration:.2f}s")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(75.0, f"Concatenating {len(video_clips)} videos ({total_duration:.1f}s total)...")
|
||||
|
||||
# Concatenate all video clips
|
||||
logger.info(f"[PodcastVideoCombination] Concatenating {len(video_clips)} video clips (total duration: {total_duration:.2f}s)")
|
||||
final_video = concatenate_videoclips(video_clips, method="compose")
|
||||
logger.info(f"[PodcastVideoCombination] Concatenation complete, final video duration: {final_video.duration:.2f}s")
|
||||
|
||||
# Generate output filename
|
||||
video_filename = self._generate_video_filename(podcast_title)
|
||||
video_path = self.output_dir / video_filename
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(85.0, f"Rendering final video ({total_duration:.1f}s total)...")
|
||||
|
||||
# Write final video file
|
||||
logger.info(
|
||||
f"[PodcastVideoCombination] Rendering final video to: {video_path} "
|
||||
f"(duration: {total_duration:.2f}s, {len(video_clips)} clips)"
|
||||
)
|
||||
|
||||
# Use faster preset for quicker encoding (still good quality)
|
||||
# 'ultrafast' is fastest but lower quality, 'fast' is good balance
|
||||
encoding_preset = 'fast' # Faster than 'medium' but still good quality
|
||||
|
||||
# Suppress warnings during video writing as well
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
||||
# Write video with optimized settings
|
||||
# Note: write_videofile is blocking and can take several minutes for longer videos
|
||||
# Estimated time: ~1-2 minutes per minute of video content
|
||||
estimated_time_minutes = max(1, int(total_duration / 60) * 2)
|
||||
logger.info(
|
||||
f"[PodcastVideoCombination] Starting video encoding "
|
||||
f"(estimated time: ~{estimated_time_minutes} minutes for {total_duration:.1f}s video)..."
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Start a thread to update progress periodically during encoding
|
||||
# Since write_videofile is blocking, we'll simulate progress
|
||||
progress_thread = None
|
||||
encoding_done = threading.Event()
|
||||
|
||||
if progress_callback:
|
||||
def update_progress_periodically():
|
||||
"""Update progress every 5 seconds during encoding"""
|
||||
base_progress = 87.0
|
||||
max_progress = 98.0
|
||||
progress_range = max_progress - base_progress
|
||||
update_interval = 5.0 # Update every 5 seconds
|
||||
elapsed = 0.0
|
||||
|
||||
try:
|
||||
while not encoding_done.is_set():
|
||||
elapsed += update_interval
|
||||
# Simulate progress: start at 87%, gradually increase to 98%
|
||||
# Use logarithmic curve to slow down as we approach completion
|
||||
progress = base_progress + (progress_range * min(1.0, elapsed / (estimated_time_minutes * 60)))
|
||||
progress = min(max_progress, progress)
|
||||
|
||||
remaining_minutes = max(0, estimated_time_minutes - int(elapsed / 60))
|
||||
message = f"Encoding video... ({remaining_minutes} min remaining)"
|
||||
if remaining_minutes == 0:
|
||||
message = "Finalizing video..."
|
||||
|
||||
try:
|
||||
progress_callback(progress, message)
|
||||
except Exception as e:
|
||||
logger.warning(f"[PodcastVideoCombination] Error in progress callback: {e}")
|
||||
break
|
||||
|
||||
# Use wait with timeout instead of sleep to check event more frequently
|
||||
if encoding_done.wait(timeout=update_interval):
|
||||
break # Event was set, exit immediately
|
||||
except Exception as e:
|
||||
logger.warning(f"[PodcastVideoCombination] Error in progress thread: {e}")
|
||||
|
||||
progress_thread = threading.Thread(target=update_progress_periodically, daemon=True)
|
||||
progress_thread.start()
|
||||
|
||||
# Write video file - this is the blocking operation
|
||||
logger.info(f"[PodcastVideoCombination] Calling write_videofile...")
|
||||
try:
|
||||
final_video.write_videofile(
|
||||
str(video_path),
|
||||
fps=fps,
|
||||
codec='libx264',
|
||||
audio_codec='aac',
|
||||
preset=encoding_preset, # Faster encoding
|
||||
threads=4,
|
||||
logger=None, # Disable MoviePy's default logger
|
||||
bitrate=None, # Let encoder choose optimal bitrate
|
||||
audio_bitrate='192k', # Good quality audio
|
||||
temp_audiofile=str(video_path.with_suffix('.m4a')), # Temporary audio file
|
||||
remove_temp=True, # Clean up temp files
|
||||
write_logfile=False, # Don't write log file
|
||||
)
|
||||
logger.info(f"[PodcastVideoCombination] write_videofile completed successfully")
|
||||
except Exception as write_error:
|
||||
logger.error(f"[PodcastVideoCombination] Error in write_videofile: {write_error}")
|
||||
# Check if file was created despite error
|
||||
if video_path.exists() and video_path.stat().st_size > 0:
|
||||
logger.warning(f"[PodcastVideoCombination] Video file exists despite error, continuing...")
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
# Always signal that encoding is done - don't wait for progress thread
|
||||
if progress_thread:
|
||||
encoding_done.set()
|
||||
# Don't join - let it finish on its own (daemon thread)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"[PodcastVideoCombination] Video encoding completed in {elapsed_time:.1f} seconds "
|
||||
f"({elapsed_time/60:.1f} minutes)"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(99.0, "Video encoding complete! Finalizing...")
|
||||
|
||||
# Verify file was created and get file size
|
||||
# Use retry logic in case file is still being written
|
||||
max_retries = 5
|
||||
file_size = 0
|
||||
for retry in range(max_retries):
|
||||
if video_path.exists():
|
||||
file_size = video_path.stat().st_size
|
||||
if file_size > 0:
|
||||
break
|
||||
if retry < max_retries - 1:
|
||||
logger.info(f"[PodcastVideoCombination] Waiting for video file to be written (retry {retry + 1}/{max_retries})...")
|
||||
time.sleep(1)
|
||||
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video file was not created: {video_path}")
|
||||
|
||||
if file_size == 0:
|
||||
raise RuntimeError(f"Video file is empty: {video_path}")
|
||||
|
||||
logger.info(f"[PodcastVideoCombination] Video file verified: {video_path} ({file_size} bytes)")
|
||||
|
||||
# Clean up clips immediately but quickly - don't block
|
||||
# Close clips synchronously but with timeout protection
|
||||
try:
|
||||
final_video.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[PodcastVideoCombination] Error closing final video clip: {e}")
|
||||
|
||||
# Close individual clips quickly
|
||||
for clip in video_clips:
|
||||
try:
|
||||
clip.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[PodcastVideoCombination] Error closing video clip: {e}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(100.0, "Video combination complete!")
|
||||
|
||||
logger.info(f"[PodcastVideoCombination] Saved combined video to: {video_path} ({file_size} bytes)")
|
||||
|
||||
# Return video metadata immediately - don't wait for cleanup
|
||||
# This prevents blocking if cleanup hangs
|
||||
return {
|
||||
"video_path": str(video_path),
|
||||
"video_filename": video_filename,
|
||||
"video_url": f"/api/podcast/final-videos/{video_filename}",
|
||||
"duration": total_duration,
|
||||
"fps": fps,
|
||||
"file_size": file_size,
|
||||
"num_scenes": len(video_clips),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[PodcastVideoCombination] Error combining videos: {e}")
|
||||
raise RuntimeError(f"Failed to combine videos: {str(e)}") from e
|
||||
|
||||
def save_scene_video(self, video_bytes: bytes, scene_number: int, user_id: str) -> Dict[str, str]:
|
||||
"""
|
||||
Save a single scene video to disk.
|
||||
|
||||
This is a utility method for saving individual scene videos before combination.
|
||||
Separate from story writer to maintain clear separation of concerns.
|
||||
|
||||
Parameters:
|
||||
video_bytes (bytes): Raw video file bytes.
|
||||
scene_number (int): Scene number for filename.
|
||||
user_id (str): User ID for filename.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Dictionary with 'video_filename', 'video_path', 'video_url', and 'file_size'.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
try:
|
||||
# Generate unique filename matching story writer format
|
||||
clean_user_id = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in user_id[:16])
|
||||
timestamp = str(uuid.uuid4())[:8]
|
||||
video_filename = f"scene_{scene_number}_{clean_user_id}_{timestamp}.mp4"
|
||||
|
||||
# Save to AI_Videos subdirectory (scene videos before combination)
|
||||
# output_dir is Final_Videos, so parent is podcast_videos, then AI_Videos
|
||||
scene_videos_dir = self.output_dir.parent / "AI_Videos"
|
||||
scene_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_path = scene_videos_dir / video_filename
|
||||
|
||||
# Write video bytes to file
|
||||
with open(video_path, "wb") as f:
|
||||
f.write(video_bytes)
|
||||
|
||||
file_size = video_path.stat().st_size
|
||||
logger.info(f"[PodcastVideoCombination] Saved scene {scene_number} video: {video_filename} ({file_size} bytes)")
|
||||
|
||||
# Generate URL path (relative to /api/podcast/videos/)
|
||||
video_url = f"/api/podcast/videos/{video_filename}"
|
||||
|
||||
return {
|
||||
"video_filename": video_filename,
|
||||
"video_url": video_url,
|
||||
"video_path": str(video_path),
|
||||
"file_size": file_size,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PodcastVideoCombination] Error saving scene video: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Failed to save scene video: {str(e)}") from e
|
||||
|
||||
def _generate_video_filename(self, podcast_title: str) -> str:
|
||||
"""
|
||||
Generate a unique filename for the combined video.
|
||||
|
||||
Parameters:
|
||||
podcast_title (str): Title of the podcast episode.
|
||||
|
||||
Returns:
|
||||
str: Generated filename.
|
||||
"""
|
||||
# Sanitize title for filename
|
||||
safe_title = "".join(c for c in podcast_title if c.isalnum() or c in (' ', '-', '_')).strip()
|
||||
safe_title = safe_title.replace(' ', '_')[:50] # Limit length
|
||||
|
||||
# Add unique ID and timestamp
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
timestamp = int(Path(__file__).stat().st_mtime) # Use file modification time as simple timestamp
|
||||
|
||||
return f"podcast_{safe_title}_{unique_id}_{timestamp}.mp4"
|
||||
|
||||
@@ -301,6 +301,12 @@ class StoryAudioGenerationService:
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
english_normalization: bool = False,
|
||||
sample_rate: Optional[int] = None,
|
||||
bitrate: Optional[int] = None,
|
||||
channel: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
language_boost: Optional[str] = None,
|
||||
enable_sync_mode: Optional[bool] = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate AI audio for a single scene using main_audio_generation.
|
||||
@@ -340,6 +346,12 @@ class StoryAudioGenerationService:
|
||||
emotion=emotion,
|
||||
user_id=user_id,
|
||||
english_normalization=english_normalization,
|
||||
sample_rate=sample_rate,
|
||||
bitrate=bitrate,
|
||||
channel=channel,
|
||||
format=format,
|
||||
language_boost=language_boost,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
)
|
||||
|
||||
# Save audio to file
|
||||
|
||||
@@ -252,10 +252,14 @@ class StoryVideoGenerationService:
|
||||
if len(scenes) != len(audio_paths):
|
||||
raise ValueError("Number of scenes and audio paths must match")
|
||||
|
||||
video_paths = video_paths or [None] * len(scenes)
|
||||
if len(video_paths) != len(scenes):
|
||||
# Ensure video_paths is a list and matches scenes length
|
||||
if video_paths is None:
|
||||
video_paths = [None] * len(scenes)
|
||||
elif len(video_paths) != len(scenes):
|
||||
video_paths = video_paths + [None] * (len(scenes) - len(video_paths))
|
||||
|
||||
logger.debug(f"[StoryVideoGeneration] video_paths length: {len(video_paths)}, scenes length: {len(scenes)}")
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryVideoGeneration] Generating story video for {len(scenes)} scenes")
|
||||
|
||||
@@ -311,49 +315,64 @@ class StoryVideoGenerationService:
|
||||
scene_title = scene.get("title", "Untitled")
|
||||
|
||||
logger.info(f"[StoryVideoGeneration] Processing scene {scene_number}/{len(scenes)}: {scene_title}")
|
||||
|
||||
audio_file = Path(audio_path)
|
||||
if not audio_file.exists():
|
||||
logger.warning(f"[StoryVideoGeneration] Audio not found: {audio_path}, skipping scene {scene_number}")
|
||||
continue
|
||||
|
||||
# Load audio
|
||||
audio_clip = AudioFileClip(str(audio_file))
|
||||
audio_duration = audio_clip.duration
|
||||
logger.debug(f"[StoryVideoGeneration] Scene {scene_number} paths - video: {video_path}, audio: {audio_path}, image: {image_path}")
|
||||
|
||||
# Prefer animated video if available
|
||||
if video_path and Path(video_path).exists():
|
||||
# Check video_path is not None and is a valid string before calling Path()
|
||||
if video_path is not None and isinstance(video_path, (str, Path)) and video_path and Path(video_path).exists():
|
||||
logger.info(f"[StoryVideoGeneration] Using animated video for scene {scene_number}: {video_path}")
|
||||
# Load animated video
|
||||
if VideoFileClip is None:
|
||||
raise RuntimeError("VideoFileClip not available - MoviePy may not be fully installed")
|
||||
video_clip = VideoFileClip(str(video_path))
|
||||
# Replace audio with the preferred audio (AI or free)
|
||||
video_clip = video_clip.with_audio(audio_clip)
|
||||
# Match duration to audio if needed
|
||||
if video_clip.duration > audio_duration:
|
||||
video_clip = video_clip.subclip(0, audio_duration)
|
||||
elif video_clip.duration < audio_duration:
|
||||
# Loop the video if it's shorter than audio
|
||||
loops_needed = int(audio_duration / video_clip.duration) + 1
|
||||
video_clip = concatenate_videoclips([video_clip] * loops_needed).subclip(0, audio_duration)
|
||||
|
||||
# Handle audio: use embedded audio if no separate audio_path provided
|
||||
if audio_path is not None and isinstance(audio_path, (str, Path)) and audio_path and Path(audio_path).exists():
|
||||
# Load separate audio file and replace video's audio
|
||||
logger.info(f"[StoryVideoGeneration] Replacing video audio with separate audio file: {audio_path}")
|
||||
audio_clip = AudioFileClip(str(audio_path))
|
||||
audio_duration = audio_clip.duration
|
||||
video_clip = video_clip.with_audio(audio_clip)
|
||||
elif image_path and Path(image_path).exists():
|
||||
# Fall back to static image
|
||||
logger.info(f"[StoryVideoGeneration] Using static image for scene {scene_number}: {image_path}")
|
||||
image_file = Path(image_path)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
# Match duration to audio if needed
|
||||
if video_clip.duration > audio_duration:
|
||||
video_clip = video_clip.subclip(0, audio_duration)
|
||||
elif video_clip.duration < audio_duration:
|
||||
# Loop the video if it's shorter than audio
|
||||
loops_needed = int(audio_duration / video_clip.duration) + 1
|
||||
video_clip = concatenate_videoclips([video_clip] * loops_needed).subclip(0, audio_duration)
|
||||
video_clip = video_clip.with_audio(audio_clip)
|
||||
else:
|
||||
# Use embedded audio from video
|
||||
logger.info(f"[StoryVideoGeneration] Using embedded audio from video for scene {scene_number}")
|
||||
audio_duration = video_clip.duration
|
||||
# Video already has audio, no need to replace
|
||||
|
||||
scene_clips.append(video_clip)
|
||||
total_duration += audio_duration
|
||||
elif audio_path is not None and isinstance(audio_path, (str, Path)) and audio_path and Path(audio_path).exists():
|
||||
# No video, but we have audio - use with image or create blank
|
||||
audio_file = Path(audio_path)
|
||||
audio_clip = AudioFileClip(str(audio_file))
|
||||
audio_duration = audio_clip.duration
|
||||
|
||||
if image_path is not None and isinstance(image_path, (str, Path)) and image_path and Path(image_path).exists():
|
||||
# Fall back to static image with audio
|
||||
logger.info(f"[StoryVideoGeneration] Using static image for scene {scene_number}: {image_path}")
|
||||
image_file = Path(image_path)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
scene_clips.append(video_clip)
|
||||
total_duration += audio_duration
|
||||
else:
|
||||
logger.warning(f"[StoryVideoGeneration] Audio provided but no video or image for scene {scene_number}, skipping")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"[StoryVideoGeneration] No video or image found for scene {scene_number}, skipping")
|
||||
logger.warning(f"[StoryVideoGeneration] No video, audio, or image found for scene {scene_number}, skipping")
|
||||
continue
|
||||
|
||||
scene_clips.append(video_clip)
|
||||
total_duration += audio_duration
|
||||
|
||||
# Call progress callback if provided
|
||||
if progress_callback:
|
||||
progress = ((idx + 1) / len(scenes)) * 90 # Reserve 10% for final composition
|
||||
@@ -362,7 +381,12 @@ class StoryVideoGenerationService:
|
||||
logger.info(f"[StoryVideoGeneration] Processed scene {idx + 1}/{len(scenes)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryVideoGeneration] Failed to process scene {idx + 1}: {e}")
|
||||
logger.error(
|
||||
f"[StoryVideoGeneration] Failed to process scene {idx + 1} ({scene_number}): {e}\n"
|
||||
f" video_path: {video_path} (type: {type(video_path)})\n"
|
||||
f" audio_path: {audio_path} (type: {type(audio_path)})\n"
|
||||
f" image_path: {image_path} (type: {type(image_path)})"
|
||||
)
|
||||
# Continue with next scene instead of failing completely
|
||||
continue
|
||||
|
||||
|
||||
@@ -71,13 +71,16 @@ class WaveSpeedClient:
|
||||
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def get_prediction_result(self, prediction_id: str, timeout: int = 120) -> Dict[str, Any]:
|
||||
def get_prediction_result(self, prediction_id: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch the current status/result for a prediction.
|
||||
Matches the example pattern: simple GET request, check status_code == 200, return data.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/predictions/{prediction_id}/result"
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=timeout)
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
except requests_exceptions.Timeout as exc:
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
@@ -98,7 +101,15 @@ class WaveSpeedClient:
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
if response.status_code != 200:
|
||||
|
||||
# Match example pattern: check status_code == 200, then get data
|
||||
if response.status_code == 200:
|
||||
result = response.json().get("data")
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
|
||||
return result
|
||||
else:
|
||||
# Non-200 status - log and raise error (matching example's break behavior)
|
||||
logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
@@ -109,59 +120,116 @@ class WaveSpeedClient:
|
||||
},
|
||||
)
|
||||
|
||||
result = response.json().get("data")
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
|
||||
return result
|
||||
|
||||
def poll_until_complete(
|
||||
self,
|
||||
prediction_id: str,
|
||||
timeout_seconds: int = 240,
|
||||
timeout_seconds: Optional[int] = None,
|
||||
interval_seconds: float = 1.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll WaveSpeed until the job completes, fails, or times out.
|
||||
Poll WaveSpeed until the job completes or fails.
|
||||
Matches the example pattern: simple polling loop until status is "completed" or "failed".
|
||||
|
||||
Args:
|
||||
prediction_id: The prediction ID to poll for
|
||||
timeout_seconds: Optional timeout in seconds. If None, polls indefinitely until completion/failure.
|
||||
interval_seconds: Seconds to wait between polling attempts (default: 1.0, faster than 2.0)
|
||||
|
||||
Returns:
|
||||
Dict containing the completed result
|
||||
|
||||
Raises:
|
||||
HTTPException: If the task fails, polling fails, or times out (if timeout_seconds is set)
|
||||
"""
|
||||
start_time = time.time()
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 6 # safety guard for non-transient errors
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = self.get_prediction_result(prediction_id)
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
detail.setdefault("error", detail.get("error", "WaveSpeed polling failed"))
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
|
||||
|
||||
# Determine underlying status code (WaveSpeed vs proxy)
|
||||
status_code = detail.get("status_code", exc.status_code)
|
||||
|
||||
# Treat 5xx as transient: keep polling indefinitely with backoff
|
||||
if 500 <= int(status_code) < 600:
|
||||
consecutive_errors += 1
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Transient polling error {consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# For non-transient (typically 4xx) errors, apply safety cap
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error(
|
||||
f"[WaveSpeed] Too many polling errors ({consecutive_errors}) for {prediction_id}, "
|
||||
f"status_code={status_code}. Giving up."
|
||||
)
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
|
||||
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Polling error {consecutive_errors}/{max_consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# Extract status from result (matching example pattern)
|
||||
status = result.get("status")
|
||||
|
||||
if status == "completed":
|
||||
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed.")
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed in {elapsed:.1f}s")
|
||||
return result
|
||||
|
||||
if status == "failed":
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {result.get('error')}")
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed animation failed",
|
||||
"prediction_id": prediction_id,
|
||||
"details": result.get("error"),
|
||||
},
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed animation timed out",
|
||||
"error": "WaveSpeed task failed",
|
||||
"prediction_id": prediction_id,
|
||||
"message": error_msg,
|
||||
"details": result,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"[WaveSpeed] Prediction {prediction_id} status={status}. Waiting...")
|
||||
# Check timeout only if specified
|
||||
if timeout_seconds is not None:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed task timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"current_status": status,
|
||||
"message": f"Task did not complete within {timeout_seconds} seconds. Status: {status}",
|
||||
},
|
||||
)
|
||||
|
||||
# Log progress periodically (every 30 seconds)
|
||||
elapsed = time.time() - start_time
|
||||
if int(elapsed) % 30 == 0 and elapsed > 0:
|
||||
logger.info(f"[WaveSpeed] Polling {prediction_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||
|
||||
# Poll faster (1.0s instead of 2.0s) to match example's responsiveness
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
def optimize_prompt(
|
||||
@@ -469,7 +537,9 @@ class WaveSpeedClient:
|
||||
|
||||
# Fetch image bytes
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
# Use reasonable timeout for downloading the final image (60s should be enough)
|
||||
# The timeout parameter is for polling, not for downloading
|
||||
image_response = requests.get(image_url, timeout=60)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
@@ -481,6 +551,208 @@ class WaveSpeedClient:
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
|
||||
def generate_character_image(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_image_bytes: bytes,
|
||||
style: str = "Auto",
|
||||
aspect_ratio: str = "16:9",
|
||||
rendering_speed: str = "Default",
|
||||
timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using Ideogram Character API to maintain character consistency.
|
||||
Creates variations of a reference character image while respecting the base appearance.
|
||||
|
||||
Note: This API is always async and requires polling for results.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the scene/context for the character
|
||||
reference_image_bytes: Reference image bytes (base avatar)
|
||||
style: Character style type ("Auto", "Fiction", or "Realistic")
|
||||
aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4")
|
||||
rendering_speed: Rendering speed ("Default", "Turbo", "Quality")
|
||||
timeout: Total timeout in seconds for submission + polling (default: 180)
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
import base64
|
||||
|
||||
# Encode reference image to base64
|
||||
image_base64 = base64.b64encode(reference_image_bytes).decode('utf-8')
|
||||
# Add data URI prefix
|
||||
image_data_uri = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
url = f"{self.BASE_URL}/ideogram-ai/ideogram-character"
|
||||
|
||||
# Note: enable_sync_mode is not a valid parameter for Ideogram Character API
|
||||
# The API is always async and requires polling
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"image": image_data_uri,
|
||||
"style": style,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"rendering_speed": rendering_speed,
|
||||
}
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating character image via Ideogram Character (prompt_length={len(prompt)})")
|
||||
# POST request should return quickly with just the task ID
|
||||
# Use reasonable timeouts for the initial submission
|
||||
# Connection timeout: 30s (increased for reliability - network may be slow)
|
||||
# Read timeout: 30s (should be enough to get task ID response)
|
||||
# Retry logic for transient connection failures
|
||||
max_retries = 2
|
||||
retry_delay = 2.0 # seconds
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._headers(),
|
||||
json=payload,
|
||||
timeout=(30, 30) # (connect_timeout, read_timeout) - increased for network reliability
|
||||
)
|
||||
break # Success, exit retry loop
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(f"[WaveSpeed] Connection attempt {attempt + 1}/{max_retries + 1} failed, retrying in {retry_delay}s: {e}")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2 # Exponential backoff
|
||||
continue
|
||||
else:
|
||||
# Final attempt failed
|
||||
error_type = "Connection timeout" if isinstance(e, requests_exceptions.ConnectTimeout) else "Connection error"
|
||||
logger.error(f"[WaveSpeed] {error_type} to Ideogram Character API after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504 if isinstance(e, requests_exceptions.ConnectTimeout) else 502,
|
||||
detail={
|
||||
"error": f"{error_type} to WaveSpeed Ideogram Character API",
|
||||
"message": "Unable to establish connection to the image generation service after multiple attempts. Please check your network connection and try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
logger.error(f"[WaveSpeed] Request timeout to Ideogram Character API: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Request timeout to WaveSpeed Ideogram Character API",
|
||||
"message": "The image generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Character image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Extract prediction ID
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character response missing prediction id",
|
||||
)
|
||||
|
||||
# Ideogram Character API is always async - check status and poll if needed
|
||||
outputs = data.get("outputs") or []
|
||||
status = data.get("status", "unknown")
|
||||
|
||||
logger.info(f"[WaveSpeed] Ideogram Character task created: prediction_id={prediction_id}, status={status}")
|
||||
|
||||
# If status is already completed, use outputs directly (unlikely but possible)
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed] Got immediate results from Ideogram Character")
|
||||
else:
|
||||
# Always need to poll for results (API is async)
|
||||
logger.info(f"[WaveSpeed] Polling for Ideogram Character result (status: {status}, prediction_id: {prediction_id})")
|
||||
# Poll until complete - use timeout if provided, otherwise poll indefinitely
|
||||
# Match example pattern exactly: simple while True loop, check status, break on completed/failed
|
||||
polling_timeout = timeout if timeout else None # None means poll indefinitely
|
||||
result = self.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=polling_timeout,
|
||||
interval_seconds=0.5, # Poll every 0.5s (closer to example's 0.1s)
|
||||
)
|
||||
# Safely extract outputs and status
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"[WaveSpeed] Unexpected result type: {type(result)}, value: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned unexpected response format",
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
status = result.get("status", "unknown")
|
||||
|
||||
if status != "completed":
|
||||
# Safely extract error message
|
||||
error_msg = "Unknown error"
|
||||
if isinstance(result, dict):
|
||||
error_msg = result.get("error") or result.get("message") or str(result.get("details", "Unknown error"))
|
||||
else:
|
||||
error_msg = str(result)
|
||||
|
||||
logger.error(f"[WaveSpeed] Ideogram Character task did not complete: status={status}, error={error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character task failed",
|
||||
"status": status,
|
||||
"message": error_msg,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract image URL from outputs
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs after polling: status={status}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned no outputs",
|
||||
)
|
||||
|
||||
image_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url")
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"[WaveSpeed] No image URL in outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character response missing image URL",
|
||||
)
|
||||
|
||||
# Download image
|
||||
logger.info(f"[WaveSpeed] Downloading character image from: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=60)
|
||||
if image_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download image: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download generated character image",
|
||||
)
|
||||
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] ✅ Successfully generated character image: {len(image_bytes)} bytes")
|
||||
return image_bytes
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
@@ -490,7 +762,7 @@ class WaveSpeedClient:
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 60,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
@@ -537,7 +809,51 @@ class WaveSpeedClient:
|
||||
payload[param] = kwargs[param]
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
|
||||
# Retry on transient connection issues
|
||||
max_retries = 2
|
||||
retry_delay = 2.0
|
||||
last_error = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._headers(),
|
||||
json=payload,
|
||||
timeout=(30, 60), # connect, read
|
||||
)
|
||||
break
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
last_error = e
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Speech connection attempt {attempt + 1}/{max_retries + 1} failed, "
|
||||
f"retrying in {retry_delay}s: {e}"
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
logger.error(f"[WaveSpeed] Speech connection failed after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Connection to WaveSpeed speech API timed out",
|
||||
"message": "Unable to reach the speech service. Please try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
last_error = e
|
||||
logger.error(f"[WaveSpeed] Speech request timeout: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed speech request timed out",
|
||||
"message": "The speech generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}")
|
||||
|
||||
@@ -8,7 +8,6 @@ from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
from .kling_animation import generate_animation_prompt
|
||||
|
||||
INFINITALK_MODEL_PATH = "wavespeed-ai/infinitetalk"
|
||||
INFINITALK_MODEL_NAME = "wavespeed-ai/infinitetalk"
|
||||
@@ -22,6 +21,67 @@ def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def _generate_simple_infinitetalk_prompt(
|
||||
scene_data: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate a balanced, concise prompt for InfiniteTalk.
|
||||
InfiniteTalk is audio-driven, so the prompt should describe the scene and suggest
|
||||
subtle motion, but avoid overly elaborate cinematic descriptions.
|
||||
|
||||
Returns None if no meaningful prompt can be generated.
|
||||
"""
|
||||
title = (scene_data.get("title") or "").strip()
|
||||
description = (scene_data.get("description") or "").strip()
|
||||
image_prompt = (scene_data.get("image_prompt") or "").strip()
|
||||
|
||||
# Build a balanced prompt: scene description + simple motion hint
|
||||
parts = []
|
||||
|
||||
# Start with the main subject/scene
|
||||
if title and len(title) > 5 and title.lower() not in ("scene", "podcast", "episode"):
|
||||
parts.append(title)
|
||||
elif description:
|
||||
# Take first sentence or first 60 chars
|
||||
desc_part = description.split('.')[0][:60].strip()
|
||||
if desc_part:
|
||||
parts.append(desc_part)
|
||||
elif image_prompt:
|
||||
# Take first sentence or first 60 chars
|
||||
img_part = image_prompt.split('.')[0][:60].strip()
|
||||
if img_part:
|
||||
parts.append(img_part)
|
||||
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
# Add a simple, subtle motion suggestion (not elaborate camera movements)
|
||||
# Keep it natural and audio-driven
|
||||
motion_hints = [
|
||||
"with subtle movement",
|
||||
"with gentle motion",
|
||||
"with natural animation",
|
||||
]
|
||||
|
||||
# Combine scene description with subtle motion hint
|
||||
if len(parts[0]) < 80:
|
||||
# Room for a motion hint
|
||||
prompt = f"{parts[0]}, {motion_hints[0]}"
|
||||
else:
|
||||
# Just use the description if it's already long enough
|
||||
prompt = parts[0]
|
||||
|
||||
# Keep it concise - max 120 characters (allows for scene + motion hint)
|
||||
prompt = prompt[:120].strip()
|
||||
|
||||
# Clean up trailing commas or incomplete sentences
|
||||
if prompt.endswith(','):
|
||||
prompt = prompt[:-1].strip()
|
||||
|
||||
return prompt if len(prompt) >= 15 else None
|
||||
|
||||
|
||||
def animate_scene_with_voiceover(
|
||||
*,
|
||||
image_bytes: bytes,
|
||||
@@ -31,6 +91,8 @@ def animate_scene_with_voiceover(
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
prompt_override: Optional[str] = None,
|
||||
mask_image_bytes: Optional[bytes] = None,
|
||||
seed: Optional[int] = -1,
|
||||
image_mime: str = "image/png",
|
||||
audio_mime: str = "audio/mpeg",
|
||||
client: Optional[WaveSpeedClient] = None,
|
||||
@@ -59,21 +121,28 @@ def animate_scene_with_voiceover(
|
||||
if resolution not in {"480p", "720p"}:
|
||||
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
|
||||
|
||||
animation_prompt = prompt_override or generate_animation_prompt(scene_data, story_context, user_id)
|
||||
# Generate simple, concise prompt for InfiniteTalk (audio-driven, less need for elaborate descriptions)
|
||||
animation_prompt = prompt_override or _generate_simple_infinitetalk_prompt(scene_data, story_context)
|
||||
|
||||
payload = {
|
||||
payload: Dict[str, Any] = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"audio": _as_data_uri(audio_bytes, audio_mime),
|
||||
"resolution": resolution,
|
||||
}
|
||||
# Only include prompt if we have a meaningful one (InfiniteTalk works fine without it)
|
||||
if animation_prompt:
|
||||
payload["prompt"] = animation_prompt
|
||||
if mask_image_bytes:
|
||||
payload["mask_image"] = _as_data_uri(mask_image_bytes, image_mime)
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
client = client or WaveSpeedClient()
|
||||
prediction_id = client.submit_image_to_video(INFINITALK_MODEL_PATH, payload, timeout=60)
|
||||
|
||||
try:
|
||||
result = client.poll_until_complete(prediction_id, timeout_seconds=600, interval_seconds=1.0)
|
||||
# Poll faster (0.5s) to mirror reference pattern; allow up to 10 minutes
|
||||
result = client.poll_until_complete(prediction_id, timeout_seconds=600, interval_seconds=0.5)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
|
||||
Reference in New Issue
Block a user