Podcast Maker: Fix progress modals, research JSON, header stepper, voice/podcastMode chips
This commit is contained in:
@@ -250,10 +250,6 @@ def huggingface_text_response(
|
||||
|
||||
logger.info("🚀 Making Hugging Face API call (chat completion)...")
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
@@ -403,10 +399,6 @@ def huggingface_structured_json_response(
|
||||
json_schema_str = json.dumps(schema, indent=2)
|
||||
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
|
||||
@@ -6,6 +6,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
@@ -211,7 +212,7 @@ def llm_text_gen(
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
elif gpt_provider == "wavespeed":
|
||||
provider_enum = APIProvider.OPENAI # Map to OpenAI for tracking purposes
|
||||
provider_enum = APIProvider.WAVESPEED
|
||||
actual_provider_name = "wavespeed"
|
||||
elif gpt_provider == "openai":
|
||||
provider_enum = APIProvider.OPENAI
|
||||
@@ -225,6 +226,8 @@ def llm_text_gen(
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
sub_check_start = time.time()
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] Subscription check START for user {user_id}")
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
@@ -286,6 +289,8 @@ def llm_text_gen(
|
||||
logger.info(f"[llm_text_gen] Subscription check passed for user {user_id}: provider={actual_provider_name or gpt_provider}, tokens_requested={estimated_total_tokens}, new_user_no_usage_record")
|
||||
|
||||
finally:
|
||||
sub_check_ms = (time.time() - sub_check_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] Subscription check took {sub_check_ms:.0f}ms for user {user_id}")
|
||||
db.close()
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
@@ -295,7 +300,8 @@ def llm_text_gen(
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
# STRICT: Fail on subscription check errors
|
||||
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
sub_check_ms = (time.time() - sub_check_start) * 1000
|
||||
logger.error(f"[llm_text_gen][{flow_tag}] Subscription check FAILED after {sub_check_ms:.0f}ms for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
@@ -366,6 +372,7 @@ def llm_text_gen(
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
llm_start = time.time()
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
@@ -374,6 +381,8 @@ def llm_text_gen(
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError(f"Unknown LLM provider: {gpt_provider}. Supported providers: google, huggingface, wavespeed")
|
||||
|
||||
@@ -274,10 +274,6 @@ def wavespeed_text_response(
|
||||
|
||||
logger.info("🚀 Making WaveSpeed API call (chat completion)...")
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
# Call exactly the requested model; no retries, no fallbacks, no variants
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
@@ -426,10 +422,6 @@ def wavespeed_structured_json_response(
|
||||
json_schema_str = json.dumps(schema, indent=2)
|
||||
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
|
||||
623
backend/services/podcast/broll_composer.py
Normal file
623
backend/services/podcast/broll_composer.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
Programmatic B-Roll Composer
|
||||
Layered composition pipeline: Background + Chart + Avatar Circle + Text Overlays
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from moviepy import (
|
||||
VideoFileClip, ImageClip, CompositeVideoClip,
|
||||
concatenate_videoclips,
|
||||
)
|
||||
import moviepy.video.fx as vfx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crossfade concat (Option 1: crossfadein + negative padding)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def crossfade_concat(scenes: list, fade_dur: float = 0.5):
|
||||
"""
|
||||
Concatenate scenes with a dissolve transition between each pair.
|
||||
|
||||
Each clip (except the first) gets a crossfadein effect.
|
||||
padding=-fade_dur overlaps consecutive clips so the fade actually fires
|
||||
instead of creating a black gap. set_duration on every scene is
|
||||
mandatory — CompositeVideoClip.duration can be ambiguous without it,
|
||||
which makes the overlap math wrong.
|
||||
"""
|
||||
faded = []
|
||||
for i, clip in enumerate(scenes):
|
||||
c = clip
|
||||
if i > 0:
|
||||
c = c.fx(vfx.CrossFadeIn, fade_dur)
|
||||
faded.append(c)
|
||||
return concatenate_videoclips(faded, padding=-int(fade_dur), method="compose")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class Insight:
|
||||
key_insight: str
|
||||
supporting_stat: str
|
||||
visual_cue: str # bar_chart_comparison | line_trend | bullet_points | full_avatar
|
||||
audio_tone: str
|
||||
chart_data: dict = field(default_factory=dict)
|
||||
duration: float = 10.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneAssets:
|
||||
background_img: str
|
||||
chart_img: Optional[str] = None
|
||||
avatar_video: Optional[str] = None
|
||||
bullet_img: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chart generator (Matplotlib → PNG with transparency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHART_STYLE = {
|
||||
"bg": "#0D0D0D",
|
||||
"bar_before": "#2E4057",
|
||||
"bar_after": "#E63946",
|
||||
"text": "#F1F1EF",
|
||||
"grid": "#2A2A2A",
|
||||
"accent": "#E63946",
|
||||
"pie_colors": ["#E63946", "#2E4057", "#457B9D", "#A8DADC", "#F4A261", "#2A9D8F"],
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chart generators (Matplotlib → PNG with transparency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_bar_chart(data: dict, out_path: str, title: str = "",
|
||||
show_legend: bool = True, value_suffix: str = "%",
|
||||
subtitle: str = "") -> str:
|
||||
"""Render a side-by-side comparison bar chart. Returns output path."""
|
||||
labels = data.get("labels", [])
|
||||
before = data.get("before", [])
|
||||
after = data.get("after", [])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
x = np.arange(len(labels))
|
||||
w = 0.35
|
||||
bars_b = ax.bar(x - w / 2, before, w, color=CHART_STYLE["bar_before"],
|
||||
label="Before", zorder=3, edgecolor="none")
|
||||
bars_a = ax.bar(x + w / 2, after, w, color=CHART_STYLE["bar_after"],
|
||||
label="After", zorder=3, edgecolor="none")
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
||||
ax.tick_params(axis="y", colors=CHART_STYLE["text"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
|
||||
for bar in [*bars_b, *bars_a]:
|
||||
h = bar.get_height()
|
||||
ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, f"{h:.0f}{value_suffix}",
|
||||
ha="center", va="bottom", color=CHART_STYLE["text"], fontsize=9,
|
||||
fontweight="bold")
|
||||
|
||||
if show_legend:
|
||||
legend = ax.legend(frameon=False, labelcolor=CHART_STYLE["text"],
|
||||
fontsize=10, loc="upper left")
|
||||
|
||||
# Add title and optional subtitle
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
if subtitle:
|
||||
fig.text(0.5, 0.02, subtitle, ha='center', color=CHART_STYLE["text"],
|
||||
fontsize=10, style='italic')
|
||||
|
||||
fig.tight_layout(pad=0.5, rect=(0, 0.03 if subtitle else 0, 1, 1))
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_horizontal_bar(data: dict, out_path: str, title: str = "",
|
||||
value_suffix: str = "%", bar_color: str = None) -> str:
|
||||
"""Render a horizontal bar chart (good for rankings/lists)."""
|
||||
labels = data.get("labels", [])
|
||||
values = data.get("values", data.get("y", []))
|
||||
|
||||
if not values:
|
||||
return ""
|
||||
|
||||
bar_color = bar_color or CHART_STYLE["bar_after"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
y_pos = np.arange(len(labels))
|
||||
bars = ax.barh(y_pos, values, color=bar_color, zorder=3, edgecolor="none", height=0.6)
|
||||
|
||||
ax.set_yticks(y_pos)
|
||||
ax.set_yticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
||||
ax.tick_params(axis="x", colors=CHART_STYLE["text"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.xaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
ax.invert_yaxis()
|
||||
|
||||
for i, bar in enumerate(bars):
|
||||
width = bar.get_width()
|
||||
ax.text(width + 0.5, bar.get_y() + bar.get_height()/2, f"{width:.0f}{value_suffix}",
|
||||
ha="left", va="center", color=CHART_STYLE["text"], fontsize=10,
|
||||
fontweight="bold")
|
||||
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_line_trend(data: dict, out_path: str, title: str = "",
|
||||
show_area: bool = True, show_markers: bool = True) -> str:
|
||||
"""Render a trend line chart."""
|
||||
x_vals = data.get("x", [])
|
||||
y_vals = data.get("y", [])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
line_style = data.get("line_style", "-")
|
||||
line_width = data.get("line_width", 2.5)
|
||||
|
||||
ax.plot(x_vals, y_vals, color=CHART_STYLE["accent"],
|
||||
linewidth=line_width, linestyle=line_style,
|
||||
marker="o" if show_markers else None, markersize=7, zorder=3)
|
||||
|
||||
if show_area:
|
||||
ax.fill_between(x_vals, y_vals, alpha=0.12, color=CHART_STYLE["accent"])
|
||||
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.tick_params(colors=CHART_STYLE["text"])
|
||||
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_pie_chart(data: dict, out_path: str, title: str = "",
|
||||
show_labels: bool = True, show_percent: bool = True,
|
||||
donut: bool = False) -> str:
|
||||
"""Render a pie chart."""
|
||||
labels = data.get("labels", [])
|
||||
values = data.get("values", data.get("y", []))
|
||||
|
||||
if not values:
|
||||
return ""
|
||||
|
||||
colors = CHART_STYLE["pie_colors"][:len(values)]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
if donut:
|
||||
wedges, texts, autotexts = ax.pie(
|
||||
values, labels=labels if show_labels else None,
|
||||
colors=colors, autopct=lambda p: f'{p:.1f}%' if show_percent else '',
|
||||
startangle=90, pctdistance=0.75,
|
||||
wedgeprops=dict(width=0.5, edgecolor="none")
|
||||
)
|
||||
else:
|
||||
wedges, texts, autotexts = ax.pie(
|
||||
values, labels=labels if show_labels else None,
|
||||
colors=colors, autopct=lambda p: f'{p:.1f}%' if show_percent else '',
|
||||
startangle=90, pctdistance=0.8
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
text.set_color(CHART_STYLE["text"])
|
||||
text.set_fontsize(10)
|
||||
|
||||
for autotext in autotexts:
|
||||
autotext.set_color(CHART_STYLE["text"])
|
||||
autotext.set_fontsize(9)
|
||||
autotext.set_fontweight("bold")
|
||||
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_stacked_bar(data: dict, out_path: str, title: str = "",
|
||||
stack_labels: list = None) -> str:
|
||||
"""Render a stacked bar chart."""
|
||||
labels = data.get("labels", [])
|
||||
stacks = data.get("stacks", []) # List of lists, each inner list is a stack
|
||||
|
||||
if not stacks or len(stacks) < 2:
|
||||
return ""
|
||||
|
||||
stack_labels = stack_labels or [f"Series {i+1}" for i in range(len(stacks))]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
x = np.arange(len(labels))
|
||||
bottom = np.zeros(len(labels))
|
||||
colors = CHART_STYLE["pie_colors"][:len(stacks)]
|
||||
|
||||
for i, stack in enumerate(stacks):
|
||||
bars = ax.bar(x, stack, 0.6, bottom=bottom, color=colors[i],
|
||||
label=stack_labels[i], zorder=3, edgecolor="none")
|
||||
|
||||
for j, bar in enumerate(bars):
|
||||
height = bar.get_height()
|
||||
if height > 5: # Only show label if segment is big enough
|
||||
ax.text(bar.get_x() + bar.get_width()/2,
|
||||
bottom[j] + height/2,
|
||||
f"{height:.0f}", ha="center", va="center",
|
||||
color=CHART_STYLE["text"], fontsize=8, fontweight="bold")
|
||||
|
||||
bottom = bottom + np.array(stack)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
||||
ax.tick_params(axis="y", colors=CHART_STYLE["text"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.legend(frameon=False, labelcolor=CHART_STYLE["text"], fontsize=9, loc="upper left")
|
||||
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_line_trend(data: dict, out_path: str, title: str = "") -> str:
|
||||
"""Render a trend line chart. Returns output path."""
|
||||
x_vals = data.get("x", [])
|
||||
y_vals = data.get("y", [])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
ax.plot(x_vals, y_vals, color=CHART_STYLE["accent"],
|
||||
linewidth=2.5, marker="o", markersize=7, zorder=3)
|
||||
ax.fill_between(x_vals, y_vals, alpha=0.12, color=CHART_STYLE["accent"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.tick_params(colors=CHART_STYLE["text"])
|
||||
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text / Bullet overlay (Pillow → PNG)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_bullet_overlay(lines: list[str], out_path: str,
|
||||
width: int = 900, font_size: int = 32) -> str:
|
||||
"""Render bullet points on a semi-transparent dark pill. Returns path."""
|
||||
padding = 32
|
||||
line_h = font_size + 16
|
||||
img_h = padding * 2 + len(lines) * line_h + 12
|
||||
img = Image.new("RGBA", (width, img_h), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
draw.rounded_rectangle([0, 0, width - 1, img_h - 1],
|
||||
radius=18, fill=(10, 10, 10, 185))
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
font_size)
|
||||
except OSError:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
y = padding
|
||||
for line in lines:
|
||||
draw.text((padding + 18, y), f"• {line}", font=font, fill=(241, 241, 239, 255))
|
||||
y += line_h
|
||||
|
||||
img.save(out_path, format="PNG")
|
||||
return out_path
|
||||
|
||||
|
||||
def make_insight_card(insight: str, stat: str, out_path: str,
|
||||
width: int = 960, height: int = 200) -> str:
|
||||
"""Render a bold insight card (headline + supporting stat). Returns path."""
|
||||
img = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rounded_rectangle([0, 0, width - 1, height - 1],
|
||||
radius=14, fill=(10, 10, 10, 200))
|
||||
|
||||
draw.rectangle([28, 24, 36, height - 24], fill=(230, 57, 70, 255))
|
||||
|
||||
try:
|
||||
font_lg = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 34)
|
||||
font_sm = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
|
||||
except OSError:
|
||||
font_lg = font_sm = ImageFont.load_default()
|
||||
|
||||
draw.text((58, 36), insight, font=font_lg, fill=(241, 241, 239, 255))
|
||||
draw.text((58, 90), stat, font=font_sm, fill=(180, 180, 178, 230))
|
||||
|
||||
img.save(out_path, format="PNG")
|
||||
return out_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circular avatar mask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_circle_mask(clip: VideoFileClip, diameter: int) -> VideoFileClip:
|
||||
"""Resize clip and apply a circular alpha mask."""
|
||||
clip = clip.resize(height=diameter)
|
||||
w, h = clip.size
|
||||
|
||||
Y, X = np.ogrid[:h, :w]
|
||||
cx, cy = w / 2, h / 2
|
||||
mask_arr = ((X - cx) ** 2 + (Y - cy) ** 2 <= (min(w, h) / 2) ** 2).astype(float)
|
||||
|
||||
mask_clip = ImageClip(mask_arr, ismask=True).set_duration(clip.duration)
|
||||
return clip.set_mask(mask_clip)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ken Burns zoom effect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def ken_burns(clip: ImageClip, zoom_ratio: float = 0.08) -> ImageClip:
|
||||
"""Apply a slow zoom-in over the clip duration."""
|
||||
def zoom_frame(get_frame, t):
|
||||
frame = get_frame(t)
|
||||
frac = 1 + zoom_ratio * (t / clip.duration)
|
||||
h, w = frame.shape[:2]
|
||||
new_h, new_w = int(h / frac), int(w / frac)
|
||||
y1 = (h - new_h) // 2
|
||||
x1 = (w - new_w) // 2
|
||||
cropped = frame[y1:y1 + new_h, x1:x1 + new_w]
|
||||
return np.array(Image.fromarray(cropped).resize((w, h), Image.LANCZOS))
|
||||
|
||||
return clip.fl(zoom_frame, apply_to=["mask"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scene builders (one per visual_cue type)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_data_scene(assets: SceneAssets, insight: Insight) -> CompositeVideoClip:
|
||||
"""
|
||||
Layout: Background (Ken Burns) + Chart (fade-in) + Avatar circle (corner) + Insight card
|
||||
"""
|
||||
d = insight.duration
|
||||
layers = []
|
||||
|
||||
bg = (ImageClip(assets.background_img)
|
||||
.set_duration(d)
|
||||
.resize(height=1080))
|
||||
bg = ken_burns(bg)
|
||||
bg = bg.fx(vfx.lum_contrast, 0, -40)
|
||||
layers.append(bg)
|
||||
|
||||
if assets.chart_img:
|
||||
chart = (ImageClip(assets.chart_img)
|
||||
.set_duration(d - 1.5)
|
||||
.set_start(0.5)
|
||||
.resize(width=700)
|
||||
.set_position(("center", 180))
|
||||
.fx(vfx.fadein, 0.6)
|
||||
.fx(vfx.fadeout, 0.4))
|
||||
layers.append(chart)
|
||||
|
||||
card_path = "/tmp/insight_card.png"
|
||||
make_insight_card(insight.key_insight, insight.supporting_stat, card_path)
|
||||
card = (ImageClip(card_path)
|
||||
.set_duration(d - 1)
|
||||
.set_start(0.5)
|
||||
.set_position(("center", 820))
|
||||
.fx(vfx.fadein, 0.5))
|
||||
layers.append(card)
|
||||
|
||||
if assets.avatar_video:
|
||||
avatar_raw = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
avatar = apply_circle_mask(avatar_raw, diameter=240)
|
||||
avatar = avatar.set_position((bg.w - 280, bg.h - 280))
|
||||
layers.append(avatar)
|
||||
|
||||
return CompositeVideoClip(layers, size=bg.size).set_duration(d)
|
||||
|
||||
|
||||
def build_bullet_scene(assets: SceneAssets, insight: Insight,
|
||||
bullets: list[str]) -> CompositeVideoClip:
|
||||
"""
|
||||
Layout: AI image (Ken Burns) + Bullet overlay + Avatar circle
|
||||
"""
|
||||
d = insight.duration
|
||||
layers = []
|
||||
|
||||
bg = (ImageClip(assets.background_img)
|
||||
.set_duration(d)
|
||||
.resize(height=1080))
|
||||
bg = ken_burns(bg, zoom_ratio=0.05)
|
||||
bg = bg.fx(vfx.lum_contrast, 0, -50)
|
||||
layers.append(bg)
|
||||
|
||||
bullet_path = "/tmp/bullets.png"
|
||||
make_bullet_overlay(bullets, bullet_path, width=860)
|
||||
bullets_clip = (ImageClip(bullet_path)
|
||||
.set_duration(d - 1)
|
||||
.set_start(0.5)
|
||||
.set_position(("center", "center"))
|
||||
.fx(vfx.fadein, 0.7))
|
||||
layers.append(bullets_clip)
|
||||
|
||||
if assets.avatar_video:
|
||||
avatar_raw = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
avatar = apply_circle_mask(avatar_raw, diameter=200)
|
||||
avatar = avatar.set_position((bg.w - 240, bg.h - 240))
|
||||
layers.append(avatar)
|
||||
|
||||
return CompositeVideoClip(layers, size=bg.size).set_duration(d)
|
||||
|
||||
|
||||
def build_full_avatar_scene(assets: SceneAssets, insight: Insight) -> VideoFileClip:
|
||||
"""Full-screen avatar — the expensive 'Hook' scene. No overlay."""
|
||||
d = insight.duration
|
||||
avatar = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
return avatar.resize(height=1080).set_duration(d)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scene dispatcher — maps visual_cue → builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def dispatch_scene(insight: Insight, assets: SceneAssets,
|
||||
bullet_lines: Optional[list[str]] = None):
|
||||
"""Dispatch scene based on visual_cue type."""
|
||||
cue = insight.visual_cue
|
||||
|
||||
if cue == "full_avatar":
|
||||
return build_full_avatar_scene(assets, insight)
|
||||
|
||||
elif cue in ("bar_chart_comparison", "line_trend"):
|
||||
chart_path = "/tmp/chart.png"
|
||||
if cue == "bar_chart_comparison":
|
||||
make_bar_chart(insight.chart_data, chart_path,
|
||||
title=insight.key_insight)
|
||||
else:
|
||||
make_line_trend(insight.chart_data, chart_path,
|
||||
title=insight.key_insight)
|
||||
assets.chart_img = chart_path
|
||||
return build_data_scene(assets, insight)
|
||||
|
||||
elif cue == "bullet_points":
|
||||
lines = bullet_lines or [insight.key_insight, insight.supporting_stat]
|
||||
return build_bullet_scene(assets, insight, lines)
|
||||
|
||||
else:
|
||||
return build_data_scene(assets, insight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Master compositor — assembles all scenes into one video
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compose_video(scenes: list, output_path: str = "output.mp4",
|
||||
fps: int = 24, fade_dur: float = 0.5) -> str:
|
||||
"""Concatenate scenes with crossfade transitions and write final video file."""
|
||||
final = crossfade_concat(scenes, fade_dur=fade_dur)
|
||||
final.write_videofile(
|
||||
output_path,
|
||||
fps=fps,
|
||||
codec="libx264",
|
||||
audio_codec="aac",
|
||||
threads=4,
|
||||
preset="fast",
|
||||
logger=None,
|
||||
)
|
||||
return output_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON bridge — LLM insight → assets + scene
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def pipeline_from_json(insight_json: str,
|
||||
background_img: str,
|
||||
avatar_video: Optional[str] = None) -> str:
|
||||
"""
|
||||
Full pipeline:
|
||||
1. Parse LLM insight JSON
|
||||
2. Generate chart / overlay assets
|
||||
3. Build scene
|
||||
4. Write video
|
||||
Returns path to output video.
|
||||
"""
|
||||
data = json.loads(insight_json)
|
||||
insight = Insight(**{k: data[k] for k in Insight.__dataclass_fields__ if k in data})
|
||||
assets = SceneAssets(background_img=background_img, avatar_video=avatar_video)
|
||||
scene = dispatch_scene(insight, assets,
|
||||
bullet_lines=data.get("bullet_lines"))
|
||||
out = f"/tmp/scene_{insight.visual_cue}.mp4"
|
||||
compose_video([scene], output_path=out)
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Demo / smoke-test (no real media files needed for chart generation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
sample_bar_data = {
|
||||
"labels": ["Content Velocity", "CTR", "Engagement", "Cost/Lead"],
|
||||
"before": [30, 22, 18, 60],
|
||||
"after": [72, 34, 41, 38],
|
||||
}
|
||||
chart_out = make_bar_chart(
|
||||
sample_bar_data,
|
||||
"/tmp/demo_chart.png",
|
||||
title="AI Tools Impact: Before vs After (2025)",
|
||||
)
|
||||
print(f"Chart saved → {chart_out}")
|
||||
|
||||
bullets = [
|
||||
"AI reduced content cycles by 40% in 2025",
|
||||
"HubSpot: 12% lift in CTR with AI-assisted copy",
|
||||
"Video production cost down 3x with hybrid pipeline",
|
||||
]
|
||||
bullet_out = make_bullet_overlay(bullets, "/tmp/demo_bullets.png")
|
||||
print(f"Bullets saved → {bullet_out}")
|
||||
|
||||
card_out = make_insight_card(
|
||||
"AI tools reduced content cycles by 40%",
|
||||
"HubSpot 2026 report — 12% lift in CTR",
|
||||
"/tmp/demo_card.png",
|
||||
)
|
||||
print(f"Insight card saved → {card_out}")
|
||||
|
||||
sample_json = json.dumps({
|
||||
"key_insight": "AI reduced production time by 40%",
|
||||
"supporting_stat": "HubSpot 2026: 12% CTR lift",
|
||||
"visual_cue": "bar_chart_comparison",
|
||||
"audio_tone": "authoritative_and_surprising",
|
||||
"duration": 8.0,
|
||||
"chart_data": sample_bar_data,
|
||||
})
|
||||
print("\nSample Insight JSON:\n", sample_json)
|
||||
print("\nAll asset generation tests passed.")
|
||||
print("To run full video composition, supply real background_img and avatar_video paths.")
|
||||
253
backend/services/podcast/broll_service.py
Normal file
253
backend/services/podcast/broll_service.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
B-Roll Service - Orchestrator for programmatic B-roll video composition.
|
||||
|
||||
This service handles:
|
||||
- Chart data extraction from research
|
||||
- Individual scene B-roll video generation
|
||||
- Final video composition from multiple B-roll scenes
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
# Import chart generators directly
|
||||
from services.podcast.broll_composer import (
|
||||
make_bar_chart,
|
||||
make_horizontal_bar,
|
||||
make_line_trend,
|
||||
make_pie_chart,
|
||||
make_stacked_bar,
|
||||
make_bullet_overlay,
|
||||
make_insight_card,
|
||||
)
|
||||
|
||||
|
||||
class BrollService:
|
||||
"""Orchestrates B-roll composition for podcast scenes."""
|
||||
|
||||
def __init__(self, output_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize B-roll service.
|
||||
|
||||
Args:
|
||||
output_dir: Base directory for B-roll output. Defaults to temp directory.
|
||||
"""
|
||||
if output_dir:
|
||||
self.output_dir = Path(output_dir)
|
||||
else:
|
||||
self.output_dir = Path(tempfile.gettempdir()) / "broll_output"
|
||||
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"[BrollService] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def get_output_path(self, filename: str) -> Path:
|
||||
"""Get output path for a file."""
|
||||
return self.output_dir / filename
|
||||
|
||||
def generate_chart_preview(
|
||||
self,
|
||||
chart_data: Dict[str, Any],
|
||||
chart_type: str = "bar_comparison",
|
||||
title: str = "",
|
||||
subtitle: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Generate a chart PNG preview (static, for Write phase).
|
||||
|
||||
Args:
|
||||
chart_data: Chart data dict with labels, before/after, etc.
|
||||
chart_type: Type of chart (bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet)
|
||||
title: Title for the chart
|
||||
subtitle: Optional subtitle at bottom
|
||||
|
||||
Returns:
|
||||
Path to generated PNG file
|
||||
"""
|
||||
chart_id = uuid.uuid4().hex[:8]
|
||||
out_path = str(self.get_output_path(f"chart_preview_{chart_id}.png"))
|
||||
|
||||
try:
|
||||
if chart_type == "bar_comparison":
|
||||
make_bar_chart(chart_data, out_path, title, subtitle=subtitle)
|
||||
elif chart_type == "bar_horizontal":
|
||||
make_horizontal_bar(chart_data, out_path, title)
|
||||
elif chart_type == "line_trend":
|
||||
make_line_trend(chart_data, out_path, title)
|
||||
elif chart_type == "pie":
|
||||
make_pie_chart(chart_data, out_path, title)
|
||||
elif chart_type == "pie":
|
||||
make_pie_chart(chart_data, out_path, title)
|
||||
elif chart_type == "stacked_bar":
|
||||
make_stacked_bar(chart_data, out_path, title)
|
||||
elif chart_type == "bullet":
|
||||
bullet_points = chart_data.get("bullet_points", [])
|
||||
if bullet_points:
|
||||
make_bullet_overlay(bullet_points, out_path)
|
||||
else:
|
||||
logger.warning("[BrollService] No bullet points provided")
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"[BrollService] Unknown chart type: {chart_type}")
|
||||
return ""
|
||||
|
||||
logger.info(f"[BrollService] Chart preview generated: {out_path}")
|
||||
return out_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BrollService] Failed to generate chart preview: {e}")
|
||||
return ""
|
||||
|
||||
def generate_scene_broll(
|
||||
self,
|
||||
scene_id: str,
|
||||
key_insight: str,
|
||||
supporting_stat: str,
|
||||
chart_data: Optional[Dict[str, Any]],
|
||||
visual_cue: str, # bar_chart_comparison, bullet_points, full_avatar
|
||||
duration: float,
|
||||
background_img_path: str,
|
||||
avatar_video_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a B-roll video for a single scene.
|
||||
|
||||
Args:
|
||||
scene_id: Scene identifier
|
||||
key_insight: Main insight text for overlay
|
||||
supporting_stat: Supporting statistic text
|
||||
chart_data: Chart data dict (optional)
|
||||
visual_cue: Type of scene to build
|
||||
duration: Scene duration in seconds
|
||||
background_img_path: Path to background image
|
||||
avatar_video_path: Path to avatar video (optional)
|
||||
|
||||
Returns:
|
||||
Path to generated video file
|
||||
"""
|
||||
scene_id_safe = scene_id.replace(" ", "_").replace("/", "_")
|
||||
out_path = str(self.get_output_path(f"broll_{scene_id_safe}.mp4"))
|
||||
|
||||
try:
|
||||
insight = Insight(
|
||||
key_insight=key_insight,
|
||||
supporting_stat=supporting_stat,
|
||||
visual_cue=visual_cue,
|
||||
audio_tone="neutral",
|
||||
chart_data=chart_data or {},
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
assets = SceneAssets(
|
||||
background_img=background_img_path,
|
||||
avatar_video=avatar_video_path,
|
||||
)
|
||||
|
||||
# Generate the scene
|
||||
scene = dispatch_scene(insight, assets)
|
||||
|
||||
# Write video
|
||||
compose_video([scene], output_path=out_path)
|
||||
|
||||
logger.info(f"[BrollService] B-roll scene generated: {out_path}")
|
||||
return out_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BrollService] Failed to generate B-roll scene: {e}")
|
||||
raise
|
||||
|
||||
def compose_final_video(
|
||||
self,
|
||||
video_paths: List[str],
|
||||
output_filename: str,
|
||||
fade_dur: float = 0.5,
|
||||
fps: int = 24,
|
||||
) -> str:
|
||||
"""
|
||||
Compose multiple B-roll scene videos into final video.
|
||||
|
||||
Args:
|
||||
video_paths: List of video file paths to compose
|
||||
output_filename: Output filename
|
||||
fade_dur: Crossfade duration between scenes
|
||||
fps: Output FPS
|
||||
|
||||
Returns:
|
||||
Path to final composed video
|
||||
"""
|
||||
out_path = str(self.get_output_path(output_filename))
|
||||
|
||||
try:
|
||||
scenes = []
|
||||
for video_path in video_paths:
|
||||
from moviepy import VideoFileClip
|
||||
clip = VideoFileClip(video_path)
|
||||
scenes.append(clip)
|
||||
|
||||
if not scenes:
|
||||
raise ValueError("No video clips provided")
|
||||
|
||||
# Use crossfade_concat from broll_composer
|
||||
from services.podcast.broll_composer import crossfade_concat
|
||||
|
||||
final = crossfade_concat(scenes, fade_dur=fade_dur)
|
||||
|
||||
final.write_videofile(
|
||||
out_path,
|
||||
fps=fps,
|
||||
codec="libx264",
|
||||
audio_codec="aac",
|
||||
threads=4,
|
||||
preset="fast",
|
||||
logger=None,
|
||||
)
|
||||
|
||||
# Close clips
|
||||
for clip in scenes:
|
||||
clip.close()
|
||||
|
||||
logger.info(f"[BrollService] Final video composed: {out_path}")
|
||||
return out_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BrollService] Failed to compose final video: {e}")
|
||||
raise
|
||||
|
||||
def cleanup(self, file_paths: List[str] = None):
|
||||
"""
|
||||
Clean up temporary B-roll files.
|
||||
|
||||
Args:
|
||||
file_paths: Specific files to delete. If None, cleans output directory.
|
||||
"""
|
||||
if file_paths:
|
||||
for path in file_paths:
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
logger.debug(f"[BrollService] Removed: {path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[BrollService] Failed to remove {path}: {e}")
|
||||
else:
|
||||
# Clean entire output directory
|
||||
for file in self.output_dir.glob("*"):
|
||||
try:
|
||||
file.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"[BrollService] Failed to remove {file}: {e}")
|
||||
|
||||
|
||||
# Singleton instance for reuse
|
||||
_broll_service_instance: Optional[BrollService] = None
|
||||
|
||||
|
||||
def get_broll_service(output_dir: Optional[str] = None) -> BrollService:
|
||||
"""Get or create B-roll service singleton."""
|
||||
global _broll_service_instance
|
||||
if _broll_service_instance is None:
|
||||
_broll_service_instance = BrollService(output_dir=output_dir)
|
||||
return _broll_service_instance
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
from loguru import logger
|
||||
from services.product_marketing.personalization_service import PersonalizationService
|
||||
from models.podcast_bible_models import (
|
||||
@@ -11,9 +13,14 @@ from models.podcast_bible_models import (
|
||||
ShowRules
|
||||
)
|
||||
|
||||
_BIBLE_CACHE_TTL_SECONDS = 120
|
||||
|
||||
|
||||
class PodcastBibleService:
|
||||
"""Service for generating and managing the Podcast Bible."""
|
||||
|
||||
_bible_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
from services.product_marketing.personalization_service import PersonalizationService
|
||||
@@ -22,19 +29,40 @@ class PodcastBibleService:
|
||||
logger.warning(f"Failed to initialize PersonalizationService: {e}")
|
||||
self.personalization_service = None
|
||||
|
||||
@classmethod
|
||||
def clear_user_cache(cls, user_id: str) -> int:
|
||||
"""Clear cached Bible data for a specific user. Returns number of entries cleared."""
|
||||
keys_to_remove = [key for key in cls._bible_cache if key.startswith(f"{user_id}:")]
|
||||
for key in keys_to_remove:
|
||||
del cls._bible_cache[key]
|
||||
if keys_to_remove:
|
||||
logger.info(f"[BibleCache] Cleared {len(keys_to_remove)} cache entries for user {user_id}")
|
||||
return len(keys_to_remove)
|
||||
|
||||
def generate_bible(self, user_id: str, project_id: str) -> PodcastBible:
|
||||
"""Generate a Podcast Bible from onboarding data."""
|
||||
bible_start = time.time()
|
||||
|
||||
cache_key = f"{user_id}:{project_id}"
|
||||
cached = self._bible_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > datetime.utcnow():
|
||||
elapsed_ms = (time.time() - bible_start) * 1000
|
||||
logger.warning(f"[BibleCache] HIT for {user_id} — saved 7 DB queries, overhead {elapsed_ms:.0f}ms")
|
||||
return cached['bible']
|
||||
|
||||
logger.info(f"Generating Podcast Bible for user {user_id}")
|
||||
|
||||
try:
|
||||
if not self.personalization_service:
|
||||
logger.warning("PersonalizationService not available, using default bible")
|
||||
elapsed_ms = (time.time() - bible_start) * 1000
|
||||
logger.warning(f"[BibleCache] MISS (fallback) for {user_id} — PersonalizationService unavailable, {elapsed_ms:.0f}ms")
|
||||
return self._get_default_bible(project_id)
|
||||
|
||||
try:
|
||||
preferences = self.personalization_service.get_user_preferences(user_id)
|
||||
except Exception as pref_err:
|
||||
logger.warning(f"Failed to get user preferences: {pref_err}, using defaults")
|
||||
elapsed_ms = (time.time() - bible_start) * 1000
|
||||
logger.warning(f"[BibleCache] MISS (fallback) for {user_id} — get_user_preferences failed ({pref_err}), {elapsed_ms:.0f}ms")
|
||||
return self._get_default_bible(project_id)
|
||||
|
||||
if not preferences:
|
||||
@@ -131,6 +159,12 @@ class PodcastBibleService:
|
||||
)
|
||||
|
||||
logger.info(f"Podcast Bible generated successfully for project {project_id}")
|
||||
elapsed_ms = (time.time() - bible_start) * 1000
|
||||
logger.warning(f"[BibleCache] MISS — generated in {elapsed_ms:.0f}ms (7 DB queries), cached for {_BIBLE_CACHE_TTL_SECONDS}s")
|
||||
self._bible_cache[cache_key] = {
|
||||
'bible': bible,
|
||||
'expires_at': datetime.utcnow() + timedelta(seconds=_BIBLE_CACHE_TTL_SECONDS),
|
||||
}
|
||||
return bible
|
||||
|
||||
except Exception as e:
|
||||
@@ -176,8 +210,12 @@ class PodcastBibleService:
|
||||
)
|
||||
|
||||
def serialize_bible(self, bible: PodcastBible) -> str:
|
||||
"""Serialize the Bible into a prompt-friendly text block."""
|
||||
return f"""
|
||||
"""Serialize the Bible into a prompt-friendly text block. Results are cached by project_id."""
|
||||
cache_key = f"serialized:{bible.project_id}"
|
||||
cached = self._bible_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > datetime.utcnow() and isinstance(cached.get('serialized'), str):
|
||||
return cached['serialized']
|
||||
serialized = f"""
|
||||
<podcast_bible>
|
||||
HOST PERSONA:
|
||||
- Name: {bible.host.name}
|
||||
@@ -212,3 +250,8 @@ SHOW RULES & STRUCTURE:
|
||||
- Constraints: {', '.join(bible.show_rules.constraints)}
|
||||
</podcast_bible>
|
||||
"""
|
||||
self._bible_cache[cache_key] = {
|
||||
'serialized': serialized,
|
||||
'expires_at': datetime.utcnow() + timedelta(seconds=_BIBLE_CACHE_TTL_SECONDS),
|
||||
}
|
||||
return serialized
|
||||
|
||||
@@ -4,11 +4,11 @@ Podcast Service
|
||||
Service layer for managing podcast project persistence.
|
||||
"""
|
||||
|
||||
import os
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_, or_
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from models.podcast_models import PodcastProject
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
@@ -32,8 +32,14 @@ class PodcastService:
|
||||
**kwargs
|
||||
) -> PodcastProject:
|
||||
"""Create a new podcast project."""
|
||||
# Generate Podcast Bible automatically from onboarding data
|
||||
bible = self.bible_service.generate_bible(user_id, project_id)
|
||||
# Generate Podcast Bible in full mode only — skip in podcast-only mode
|
||||
bible_data = None
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() != "podcast":
|
||||
try:
|
||||
bible = self.bible_service.generate_bible(user_id, project_id)
|
||||
bible_data = bible.model_dump() if bible else None
|
||||
except Exception:
|
||||
pass # Bible is optional, project creation continues regardless
|
||||
|
||||
project = PodcastProject(
|
||||
project_id=project_id,
|
||||
@@ -42,7 +48,7 @@ class PodcastService:
|
||||
duration=duration,
|
||||
speakers=speakers,
|
||||
budget_cap=budget_cap,
|
||||
bible=bible.model_dump() if bible else None,
|
||||
bible=bible_data,
|
||||
status="draft",
|
||||
current_step="create",
|
||||
**kwargs
|
||||
|
||||
@@ -4,6 +4,7 @@ Handles subscription limit checking and validation logic.
|
||||
Extracted from pricing_service.py for better modularity.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import text
|
||||
@@ -32,9 +33,11 @@ class LimitValidator:
|
||||
self.db = pricing_service.db
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Delegates to LimitValidator for actual validation logic.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
@@ -44,6 +47,7 @@ class LimitValidator:
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
@@ -51,12 +55,14 @@ class LimitValidator:
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
logger.warning(f"[Subscription Check] START for user {user_id}, provider {provider.value}")
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self.pricing_service._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.warning(f"[Subscription Check] Cache hit for {user_id}:{provider.value} — completed in {elapsed_ms:.0f}ms")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
@@ -139,12 +145,15 @@ class LimitValidator:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
|
||||
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Expire all objects to force fresh read from DB (critical after renewal)
|
||||
self.db.expire_all()
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
@@ -367,14 +376,18 @@ class LimitValidator:
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.warning(f"[Subscription Check] Completed in {elapsed_ms:.0f}ms for user {user_id}, provider {display_provider_name} — within limits (calls: {current_call_count}/{call_limit_value})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.warning(f"[Subscription Check] Completed in {elapsed_ms:.0f}ms for user {user_id}, provider {display_provider_name} — within limits (basic check)")
|
||||
return True, "Within limits", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[Subscription Check] Failed for user {user_id} after {elapsed_ms:.0f}ms: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
|
||||
@@ -417,9 +430,7 @@ class LimitValidator:
|
||||
except Exception as schema_err:
|
||||
logger.warning(f"Schema check failed, will retry on query error: {schema_err}")
|
||||
|
||||
# Explicitly expire any cached objects and refresh from DB to ensure fresh data
|
||||
self.db.expire_all()
|
||||
|
||||
# Explicitly refresh usage from DB to ensure fresh data (targeted instead of expire_all)
|
||||
try:
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
@@ -438,7 +449,12 @@ class LimitValidator:
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
from services.subscription.schema_utils import ensure_usage_summaries_columns
|
||||
ensure_usage_summaries_columns(self.db)
|
||||
self.db.expire_all()
|
||||
# After schema migration, only expire UsageSummary to force re-query
|
||||
# (no need to expire the entire session)
|
||||
for obj in self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).all():
|
||||
self.db.expire(obj)
|
||||
# Retry the query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
@@ -594,8 +610,9 @@ class LimitValidator:
|
||||
# Method 2: Fallback to fresh ORM query if raw SQL fails
|
||||
if not query_succeeded:
|
||||
try:
|
||||
# Expire all cached objects and do fresh query
|
||||
self.db.expire_all()
|
||||
# Only refresh usage object, don't expire entire session
|
||||
if usage:
|
||||
self.db.refresh(usage)
|
||||
fresh_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
@@ -792,7 +809,11 @@ class LimitValidator:
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
from services.subscription.schema_utils import ensure_usage_summaries_columns
|
||||
ensure_usage_summaries_columns(self.db)
|
||||
self.db.expire_all()
|
||||
# Only expire UsageSummary after schema migration, not entire session
|
||||
for obj in self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).all():
|
||||
self.db.expire(obj)
|
||||
|
||||
# Retry the query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
|
||||
Reference in New Issue
Block a user