Issue #518 - Subscription not updating after checkout: - Fix stale closure in SubscriptionContext checkout polling (use subscriptionRef) - Move checkout success polling from InitialRouteHandler into SubscriptionContext - Remove redundant polling code from InitialRouteHandler - Fix plan label: 'Free' instead of 'No Plan', proper capitalization - Add plan refresh button in UserBadge - Add 'View Costing Details' to UserBadge dropdown - Rename 'ALwrity Podcast Maker' to 'Podcast Creator' across UI - Clean subscription=success URL param after verification Blog Writer WYSIWYG Editor enhancements: - Per-section preview toggle (view/edit icons) - Enhanced hover-based toolbar - Circular SVG progress stats bar with detailed tooltip - Research tool chips in stats bar footer - Per-section TTS with useTextToSpeech hook (browser native) - Full blog preview modal with print/PDF support - PlayAllTTSButton: sequential playback with progress bar - OnThisPageNav: floating sidebar with scroll tracking - Section data attributes for scroll anchoring GSC Brainstorm Topics feature: - Backend: gsc_brainstorm_service.py (rule-based + LLM recommendations) - Backend: POST /gsc/brainstorm endpoint with 3-word minimum validation - Frontend: gscBrainstorm.ts API client - Frontend: useGSCBrainstormConnection hook (popup OAuth, no /onboarding redirect) - Frontend: useGSCBrainstorm hook (connect check + brainstorm call) - Frontend: GSCBrainstormModal (3-tab results: Opportunities, Gaps, AI Recs) - Frontend: BrainstormButton (visible at 3+ words, GSC connect overlay) - Wire BrainstormButton into ManualResearchForm and ResearchAction - Add blog_writer to gsc_auth router features for ALWRITY_ENABLED_FEATURES
951 lines
36 KiB
Python
951 lines
36 KiB
Python
"""
|
|
Chart Service — Shared chart generation for Blog Writer, Podcast Maker, and future modules.
|
|
|
|
Extracts the chart rendering logic from podcast/broll_composer into a reusable service
|
|
that any module can call. Supports:
|
|
- Direct chart rendering (caller provides chart_type + chart_data)
|
|
- AI-driven chart inference (caller provides text, LLM infers chart_type + chart_data)
|
|
|
|
Chart types: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points
|
|
"""
|
|
|
|
import uuid
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, List
|
|
from dataclasses import dataclass, field
|
|
from loguru import logger
|
|
|
|
import numpy as np
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
from services.llm_providers.main_text_generation import llm_text_gen
|
|
|
|
|
|
# Dark-theme palette shared by the renderers in this module.
# "bar_before"/"bar_after" colour the two series of a before/after bar chart;
# "pie_colors" doubles as the series palette for stacked bars.
CHART_STYLE = dict(
    bg="#0D0D0D",
    bar_before="#2E4057",
    bar_after="#E63946",
    text="#F1F1EF",
    grid="#2A2A2A",
    accent="#E63946",
    pie_colors=["#E63946", "#2E4057", "#457B9D", "#A8DADC", "#F4A261", "#2A9D8F"],
)

# Chart types accepted from callers and from LLM output. Aliases are listed
# next to their canonical names; _normalize_chart_type folds them together.
VALID_CHART_TYPES = [
    "bar_comparison",
    "bar_chart_comparison",  # alias of bar_comparison
    "bar_horizontal",
    "line_trend",
    "pie",
    "stacked_bar",
    "bullet",  # alias of bullet_points
    "bullet_points",
]
|
|
|
|
# Prompts for single-shot chart inference (ChartService.infer_chart_from_text).
# The system prompt pins the JSON contract the response parser expects. The
# user prompt is filled via str.format(text=...), so it must contain no other
# braces.
CHART_INFERENCE_SYSTEM_PROMPT = """You are a data visualization expert. Given text content, determine the most appropriate chart type and extract structured data for rendering.

You MUST respond with ONLY a valid JSON object (no markdown, no explanation) with this exact structure:
{
  "chart_type": "one of: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points",
  "chart_data": { ... appropriate data structure for the chart type ... },
  "title": "A clear, concise chart title"
}

Chart data structures by type:
- bar_comparison: {"labels": [...], "before": [...], "after": [...]} OR {"labels": [...], "values": [...]}
- bar_horizontal: {"labels": [...], "values": [...]}
- line_trend: {"labels": [...], "values": [...]}
- pie: {"labels": [...], "values": [...]}
- stacked_bar: {"labels": [...], "stacks": [[...], [...]]}
- bullet_points: {"bullet_points": [...]}

Rules:
1. Choose the chart type that best represents the information in the text.
2. Use bar_comparison for before/after comparisons.
3. Use line_trend for time-series or sequential data.
4. Use pie for proportional breakdowns of a whole.
5. Use bar_horizontal for rankings or comparisons.
6. Use bullet_points if the text is qualitative with no strong numeric data.
7. Extract realistic numeric values from the text when available.
8. If no data is extractable, use bullet_points and list key points.
9. Keep labels short (under 20 chars)."""


CHART_INFERENCE_USER_PROMPT = """Create a chart from this text:

{text}

Return ONLY the JSON object with chart_type, chart_data, and title."""
|
|
|
|
|
|
# Prompts for stage 1 of the research pipeline (ChartService._analyze_chart_potential).
# The reply is parsed as strict JSON, so the response template contains no
# comments: a model that copies a "// ..." comment would produce unparseable
# output. Constraints on the fields are therefore stated in prose below the
# template.
CHART_ANALYSIS_SYSTEM_PROMPT = """You are a data visualization analyst. Given text from a blog section, your job is to:
1. Determine whether the text contains enough specific numeric data to create a meaningful chart
2. If YES: explain what data is available and suggest a chart type
3. If NO: suggest 2-3 specific search queries that would find relevant statistics/data to create a chart for this topic

You MUST respond with ONLY a valid JSON object (no markdown, no explanation, no comments):
{
  "has_data": true or false,
  "data_description": "brief description of what data is available or why it's insufficient",
  "suggested_chart_type": "best chart type if has_data is true, otherwise null",
  "search_queries": ["query1", "query2", "query3"]
}

"search_queries" must be an empty array when has_data is true.
"suggested_chart_type" must be one of: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points (or null).

Be optimistic — if there's ANY numeric claim, percentage, comparison, or trend in the text, set has_data to true.
Only set has_data to false if the text is purely qualitative with no numbers, percentages, comparisons, or trends."""


# Filled via str.format(section_heading=..., key_points_section=..., text=...).
# key_points_section is either "" or a block that already starts with newlines.
CHART_ANALYSIS_USER_PROMPT = """Analyze this text for chart potential:

Section: {section_heading}
{key_points_section}
Text: {text}

Determine if this text contains enough data for a chart, or suggest search queries to find the data."""
|
|
|
|
|
|
# Prompts for stage 3 of the research pipeline
# (ChartService._synthesize_chart_from_research). The reply's optional "source"
# field is copied into chart_data["source"] and later drawn onto the chart.
# The user prompt is filled via str.format(text=..., research=...).
CHART_SYNTHESIS_SYSTEM_PROMPT = """You are a data visualization expert. You have been given:
1. Original text from a blog section
2. Research data found from web searches

Create a chart that visualizes the most interesting insight from the combination of the original text and research data.

You MUST respond with ONLY a valid JSON object (no markdown, no explanation) with this exact structure:
{
  "chart_type": "one of: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points",
  "chart_data": { ... appropriate data structure ... },
  "title": "A clear, concise chart title",
  "source": "Brief source attribution"
}

Chart data structures by type:
- bar_comparison: {"labels": [...], "before": [...], "after": [...]} OR {"labels": [...], "values": [...]}
- bar_horizontal: {"labels": [...], "values": [...]}
- line_trend: {"labels": [...], "values": [...]}
- pie: {"labels": [...], "values": [...]}
- stacked_bar: {"labels": [...], "stacks": [[...], [...]]}
- bullet_points: {"bullet_points": [...]}

Rules:
1. Use the research data to create accurate, fact-based charts
2. Prefer bar_comparison for before/after or categorical comparisons
3. Prefer line_trend for trends over time
4. Prefer pie for market share or proportional breakdowns
5. Keep labels short (under 20 characters)
6. Use realistic values from the research — do NOT invent numbers
7. Always include a source attribution based on where the data came from
8. If the research doesn't contain useful numeric data, fall back to bullet_points with key insights"""


CHART_SYNTHESIS_USER_PROMPT = """Original text:
{text}

Research data found:
{research}

Create a chart that visualizes the most interesting data insight from the combination above."""
|
|
|
|
|
|
def _normalize_chart_type(chart_type: str) -> str:
|
|
"""Normalize chart type aliases."""
|
|
mapping = {
|
|
"bar_chart_comparison": "bar_comparison",
|
|
"bullet": "bullet_points",
|
|
}
|
|
return mapping.get(chart_type, chart_type)
|
|
|
|
|
|
def _add_source_overlay(image_path: str, source: str) -> None:
|
|
"""Add a source attribution overlay to a chart image (in-place)."""
|
|
if not source or not os.path.exists(image_path):
|
|
return
|
|
try:
|
|
img = Image.open(image_path).convert("RGBA")
|
|
draw = ImageDraw.Draw(img)
|
|
source_text = f"Source: {source[:80]}"
|
|
try:
|
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11)
|
|
except (OSError, IOError):
|
|
try:
|
|
font = ImageFont.truetype("arial.ttf", 11)
|
|
except (OSError, IOError):
|
|
font = ImageFont.load_default()
|
|
text_bbox = draw.textbbox((0, 0), source_text, font=font)
|
|
text_w = text_bbox[2] - text_bbox[0]
|
|
text_h = text_bbox[3] - text_bbox[1]
|
|
x = img.width - text_w - 12
|
|
y = img.height - text_h - 8
|
|
draw.rectangle([x - 4, y - 2, x + text_w + 4, y + text_h + 2], fill=(0, 0, 0, 140))
|
|
draw.text((x, y), source_text, fill=(200, 200, 200, 220), font=font)
|
|
img.save(image_path)
|
|
except Exception as e:
|
|
logger.warning(f"[ChartService] Source overlay failed (non-fatal): {e}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Chart generators (Matplotlib → PNG with transparency)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def make_bar_chart(data: dict, out_path: str, title: str = "",
|
|
show_legend: bool = True, value_suffix: str = "%",
|
|
subtitle: str = "") -> str:
|
|
labels = data.get("labels", [])
|
|
before = data.get("before", [])
|
|
after = data.get("after", [])
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
|
ax.set_facecolor("none")
|
|
|
|
if not before and not after:
|
|
values = data.get("values", [])
|
|
if values and labels:
|
|
n = min(len(labels), len(values))
|
|
labels = labels[:n]
|
|
before = [0] * n
|
|
after = values[:n]
|
|
data = {**data, "labels": labels, "before": before, "after": after}
|
|
|
|
x = np.arange(len(labels))
|
|
w = 0.35
|
|
bars_b = ax.bar(x - w / 2, before, w, color=CHART_STYLE["bar_before"],
|
|
label="Before", zorder=3, edgecolor="none")
|
|
bars_a = ax.bar(x + w / 2, after, w, color=CHART_STYLE["bar_after"],
|
|
label="After", zorder=3, edgecolor="none")
|
|
|
|
ax.set_xticks(x)
|
|
ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
|
ax.tick_params(axis="y", colors=CHART_STYLE["text"])
|
|
ax.spines[:].set_visible(False)
|
|
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
|
ax.set_axisbelow(True)
|
|
|
|
for bar in [*bars_b, *bars_a]:
|
|
h = bar.get_height()
|
|
ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, f"{h:.0f}{value_suffix}",
|
|
ha="center", va="bottom", color=CHART_STYLE["text"], fontsize=9,
|
|
fontweight="bold")
|
|
|
|
if show_legend:
|
|
ax.legend(frameon=False, labelcolor=CHART_STYLE["text"],
|
|
fontsize=10, loc="upper left")
|
|
|
|
if title:
|
|
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
|
fontweight="bold", pad=12)
|
|
if subtitle:
|
|
fig.text(0.5, 0.02, subtitle, ha='center', color=CHART_STYLE["text"],
|
|
fontsize=10, style='italic')
|
|
|
|
fig.tight_layout(pad=0.5, rect=(0, 0.03 if subtitle else 0, 1, 1))
|
|
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
|
plt.close(fig)
|
|
return out_path
|
|
|
|
|
|
def make_horizontal_bar(data: dict, out_path: str, title: str = "",
                        value_suffix: str = "%", bar_color: Optional[str] = None) -> str:
    """Render a horizontal (ranking-style) bar chart to a transparent PNG.

    Args:
        data: ``{"labels": [...], "values": [...]}``. ``"y"`` is accepted as a
            fallback key for the values. Labels and values are truncated to
            their common length.
        out_path: Destination PNG path.
        title: Optional chart title.
        value_suffix: Appended to each bar's value label.
        bar_color: Bar colour; defaults to ``CHART_STYLE["bar_after"]``.

    Returns:
        ``out_path`` on success, or ``""`` when there are no label/value pairs.
    """
    labels = list(data.get("labels") or [])
    values = list(data.get("values", data.get("y")) or [])

    # matplotlib rejects mismatched lengths, so plot only complete pairs.
    n = min(len(labels), len(values))
    if n == 0:
        return ""
    labels, values = labels[:n], values[:n]

    bar_color = bar_color or CHART_STYLE["bar_after"]

    fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        y_pos = np.arange(n)
        bars = ax.barh(y_pos, values, color=bar_color, zorder=3, edgecolor="none", height=0.6)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
        ax.tick_params(axis="x", colors=CHART_STYLE["text"])
        ax.spines[:].set_visible(False)
        ax.xaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
        ax.set_axisbelow(True)
        # First label at the top, matching reading order.
        ax.invert_yaxis()

        for bar in bars:
            width = bar.get_width()
            ax.text(width + 0.5, bar.get_y() + bar.get_height() / 2, f"{width:.0f}{value_suffix}",
                    ha="left", va="center", color=CHART_STYLE["text"], fontsize=10,
                    fontweight="bold")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)

        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Release the pyplot figure even if drawing/saving fails.
        plt.close(fig)
    return out_path
|
|
|
|
|
|
def make_pie_chart(data: dict, out_path: str, title: str = "",
                   show_labels: bool = True, show_percent: bool = True,
                   donut: bool = False) -> str:
    """Render a pie (or donut) chart to a transparent PNG.

    Args:
        data: ``{"labels": [...], "values": [...]}``. ``"y"`` is accepted as a
            fallback key for the values. Values may be numbers or numeric
            strings. Missing labels are left blank and extra labels are dropped.
        out_path: Destination PNG path.
        title: Optional chart title.
        show_labels: Draw wedge labels.
        show_percent: Draw percentage text inside each wedge.
        donut: Draw a ring instead of a full pie.

    Returns:
        ``out_path`` on success. Returns ``""`` when there is nothing a pie can
        show: no values, non-numeric or non-finite values, negative values, or
        a zero total.
    """
    import math

    labels = list(data.get("labels") or [])
    values = list(data.get("values", data.get("y")) or [])
    if not values:
        return ""

    try:
        sizes = [float(v) for v in values]
    except (TypeError, ValueError):
        return ""
    # matplotlib raises on negative wedge sizes and cannot draw a zero total.
    if not all(math.isfinite(s) and s >= 0 for s in sizes) or sum(sizes) <= 0:
        return ""

    # Exactly one label per wedge, as matplotlib requires.
    labels = (labels + [""] * len(sizes))[:len(sizes)]
    colors = CHART_STYLE["pie_colors"][:len(sizes)]

    fig, ax = plt.subplots(figsize=(6, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        pie_kwargs = dict(
            labels=labels if show_labels else None,
            colors=colors,
            autopct=lambda p: f'{p:.1f}%' if show_percent else '',
            startangle=90,
        )
        if donut:
            _wedges, texts, autotexts = ax.pie(
                sizes, pctdistance=0.75,
                wedgeprops=dict(width=0.5, edgecolor="none"), **pie_kwargs
            )
        else:
            _wedges, texts, autotexts = ax.pie(sizes, pctdistance=0.8, **pie_kwargs)

        for text in texts:
            text.set_color(CHART_STYLE["text"])
            text.set_fontsize(10)

        for autotext in autotexts:
            autotext.set_color(CHART_STYLE["text"])
            autotext.set_fontsize(9)
            autotext.set_fontweight("bold")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)

        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Release the pyplot figure even if drawing/saving fails.
        plt.close(fig)
    return out_path
|
|
|
|
|
|
def make_stacked_bar(data: dict, out_path: str, title: str = "",
|
|
stack_labels: list = None) -> str:
|
|
labels = data.get("labels", [])
|
|
stacks = data.get("stacks", [])
|
|
|
|
if not stacks or len(stacks) < 2:
|
|
return ""
|
|
|
|
stack_labels = stack_labels or [f"Series {i+1}" for i in range(len(stacks))]
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
|
ax.set_facecolor("none")
|
|
|
|
x = np.arange(len(labels))
|
|
bottom = np.zeros(len(labels))
|
|
colors = CHART_STYLE["pie_colors"][:len(stacks)]
|
|
|
|
for i, stack in enumerate(stacks):
|
|
bars = ax.bar(x, stack, 0.6, bottom=bottom, color=colors[i],
|
|
label=stack_labels[i], zorder=3, edgecolor="none")
|
|
|
|
for j, bar in enumerate(bars):
|
|
height = bar.get_height()
|
|
if height > 5:
|
|
ax.text(bar.get_x() + bar.get_width()/2,
|
|
bottom[j] + height/2,
|
|
f"{height:.0f}", ha="center", va="center",
|
|
color=CHART_STYLE["text"], fontsize=8, fontweight="bold")
|
|
|
|
bottom = bottom + np.array(stack)
|
|
|
|
ax.set_xticks(x)
|
|
ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
|
ax.tick_params(axis="y", colors=CHART_STYLE["text"])
|
|
ax.spines[:].set_visible(False)
|
|
ax.legend(frameon=False, labelcolor=CHART_STYLE["text"], fontsize=9, loc="upper left")
|
|
|
|
if title:
|
|
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
|
fontweight="bold", pad=12)
|
|
|
|
fig.tight_layout(pad=0.5)
|
|
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
|
plt.close(fig)
|
|
return out_path
|
|
|
|
|
|
def make_line_trend(data: dict, out_path: str, title: str = "") -> str:
|
|
x_labels = data.get("labels", data.get("x", []))
|
|
y_vals = data.get("values", data.get("y", []))
|
|
|
|
if not x_labels or not y_vals:
|
|
return ""
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
|
ax.set_facecolor("none")
|
|
|
|
try:
|
|
x_vals = [float(v) for v in x_labels]
|
|
except (ValueError, TypeError):
|
|
x_vals = list(range(len(x_labels)))
|
|
|
|
ax.plot(x_vals, y_vals, color=CHART_STYLE["accent"],
|
|
linewidth=2.5, marker="o", markersize=7, zorder=3)
|
|
ax.fill_between(x_vals, y_vals, alpha=0.12, color=CHART_STYLE["accent"])
|
|
ax.spines[:].set_visible(False)
|
|
ax.tick_params(colors=CHART_STYLE["text"])
|
|
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
|
|
|
try:
|
|
x_labels_f = [float(v) for v in x_labels]
|
|
except (ValueError, TypeError):
|
|
ax.set_xticks(x_vals)
|
|
ax.set_xticklabels(x_labels, color=CHART_STYLE["text"], fontsize=10)
|
|
|
|
if title:
|
|
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
|
fontweight="bold", pad=12)
|
|
fig.tight_layout(pad=0.5)
|
|
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
|
plt.close(fig)
|
|
return out_path
|
|
|
|
|
|
def make_bullet_overlay(lines: list, out_path: str,
|
|
width: int = 900, font_size: int = 32) -> str:
|
|
padding = 32
|
|
line_h = font_size + 16
|
|
img_h = padding * 2 + len(lines) * line_h + 12
|
|
img = Image.new("RGBA", (width, img_h), (0, 0, 0, 0))
|
|
draw = ImageDraw.Draw(img)
|
|
|
|
draw.rounded_rectangle([0, 0, width - 1, img_h - 1],
|
|
radius=18, fill=(10, 10, 10, 185))
|
|
|
|
try:
|
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
|
font_size)
|
|
except OSError:
|
|
font = ImageFont.load_default()
|
|
|
|
y = padding
|
|
for line in lines:
|
|
draw.text((padding + 18, y), f"\u2022 {line}", font=font, fill=(241, 241, 239, 255))
|
|
y += line_h
|
|
|
|
img.save(out_path, format="PNG")
|
|
return out_path
|
|
|
|
|
|
# Lookup table from chart type (canonical names plus their aliases) to the
# renderer function. The renderers do not share one signature:
# make_bullet_overlay takes a list of lines rather than a data dict.
# ChartService._render_chart dispatches through its own if/elif chain rather
# than this table.
CHART_RENDERERS = {
    **dict.fromkeys(("bar_comparison", "bar_chart_comparison"), make_bar_chart),
    "bar_horizontal": make_horizontal_bar,
    "line_trend": make_line_trend,
    "pie": make_pie_chart,
    "stacked_bar": make_stacked_bar,
    **dict.fromkeys(("bullet_points", "bullet"), make_bullet_overlay),
}
|
|
|
|
|
|
class ChartService:
    """Shared chart generation service for all modules.

    Renders charts to PNG files under ``output_dir``. It can also infer a chart
    spec (type, data, title) from free text with an LLM, optionally backed by
    Exa web research when the text lacks numbers.

    NOTE(review): ``llm_text_gen`` is called synchronously, including from the
    ``async`` methods below, so those coroutines block the event loop for the
    duration of each LLM call. Consider ``asyncio.to_thread`` once
    ``llm_text_gen`` is confirmed to be thread-safe.
    """

    def __init__(self, output_dir: Optional[str] = None, user_id: Optional[str] = None):
        """Create the service and ensure its output directory exists.

        Args:
            output_dir: Explicit directory for rendered PNGs. Takes precedence
                over ``user_id``.
            user_id: When ``output_dir`` is not given, used to resolve a
                per-user media directory (see ``_default_chart_dir``).
        """
        if output_dir:
            self.output_dir = Path(output_dir)
        else:
            self.output_dir = self._default_chart_dir(user_id)

        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"[ChartService] Initialized with output directory: {self.output_dir}")

    @staticmethod
    def _default_chart_dir(user_id: Optional[str] = None) -> Path:
        """Resolve the default chart directory.

        With a ``user_id`` this tries the workspace-aware podcast media
        directory. It presumably returns a ``Path``; ``__init__`` calls
        ``.mkdir`` on the result. Otherwise, or if that lookup fails, it falls
        back to ``~/.alwrity/charts``.
        """
        if user_id:
            try:
                from api.podcast.constants import get_podcast_media_dir
                return get_podcast_media_dir("chart", user_id, ensure_exists=True)
            except Exception as e:
                # Best effort: fall back to the shared directory, but leave a
                # trace instead of silently ignoring the failure.
                logger.warning(
                    f"[ChartService] Per-user chart dir unavailable for user={user_id}, "
                    f"using default: {e}"
                )

        base = Path.home() / ".alwrity" / "charts"
        base.mkdir(parents=True, exist_ok=True)
        return base

    def get_output_path(self, filename: str) -> Path:
        """Return ``filename`` joined onto the service's output directory."""
        return self.output_dir / filename

    def get_chart_preview_path(self, chart_id: str) -> Path:
        """Return the PNG path used for ``chart_id``.

        NOTE(review): ``chart_id`` is interpolated into the filename unchanged,
        so callers should only pass trusted/sanitized ids.
        """
        return self.get_output_path(f"chart_preview_{chart_id}.png")

    # ------------------------------------------------------------------
    # Small parsing/normalization helpers for LLM output
    # ------------------------------------------------------------------

    @staticmethod
    def _llm_response_text(result: Any) -> str:
        """Extract raw text from an ``llm_text_gen`` result (a dict with ``"text"`` or a plain value)."""
        if isinstance(result, dict):
            return str(result.get("text") or "")
        return str(result) if result else ""

    @staticmethod
    def _parse_llm_json(raw: str) -> Dict[str, Any]:
        """Parse the first JSON object in an LLM reply.

        Markdown fences (```` ```json ... ``` ````) and chatter before or after
        the object are tolerated, because decoding starts at the first ``{`` and
        stops at the end of that object.

        Raises:
            ValueError: no ``{`` found, or the object is not valid JSON
                (``json.JSONDecodeError`` subclasses ``ValueError``).
        """
        import json

        text = (raw or "").strip()
        start = text.find("{")
        if start == -1:
            raise ValueError("LLM response contains no JSON object")
        parsed, _end = json.JSONDecoder().raw_decode(text[start:])
        return parsed

    @staticmethod
    def _validated_chart_type(chart_type: Any, default: Optional[str] = "bullet_points") -> Optional[str]:
        """Return ``chart_type`` if it is valid (possibly after alias normalization), else ``default``."""
        if not isinstance(chart_type, str):
            return default
        if chart_type in VALID_CHART_TYPES:
            return chart_type
        normalized = _normalize_chart_type(chart_type)
        return normalized if normalized in VALID_CHART_TYPES else default

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret LLM booleans; string forms such as ``"false"`` must not count as truthy."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    @staticmethod
    def _as_query_list(value: Any) -> List[str]:
        """Coerce an LLM ``search_queries`` value into a list of non-empty strings.

        A bare string becomes a one-item list; slicing it directly would
        otherwise treat each character as a query.
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []
        return [q.strip() for q in value if isinstance(q, str) and q.strip()]

    @staticmethod
    def _key_sentences(text: str, limit: int = 5) -> List[str]:
        """Return up to ``limit`` sentences from ``text`` longer than 10 characters.

        Splits only where ``.``, ``!`` or ``?`` is followed by whitespace, so
        decimals such as ``"3.5%"`` stay intact. A trailing period is dropped.
        """
        import re

        sentences = []
        for chunk in re.split(r"(?<=[.!?])\s+", text or ""):
            sentence = chunk.strip().rstrip(".").strip()
            if len(sentence) > 10:
                sentences.append(sentence)
                if len(sentences) == limit:
                    break
        return sentences

    @staticmethod
    def _fallback_search_queries(text: str, section_heading: Optional[str]) -> List[str]:
        """Build search queries from the heading and the first 10 words of ``text``.

        Queries built from an empty heading (e.g. ``" statistics data"``) are
        skipped.
        """
        heading = (section_heading or "").strip()
        keywords = " ".join((text or "").split()[:10])
        queries = []
        if heading:
            queries.extend([f"{heading} statistics data", f"{heading} trends report"])
        if keywords:
            queries.append(f"{keywords} statistics")
        return queries

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    def generate_chart(
        self,
        chart_data: Dict[str, Any],
        chart_type: str = "bar_comparison",
        title: str = "",
        subtitle: str = "",
        chart_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Generate a chart PNG and return metadata.

        If ``chart_data`` carries a non-empty ``"source"``, it is stamped onto
        the image as an attribution.

        Returns:
            {"path": str, "chart_id": str, "filename": str}
            Returns {"path": "", "chart_id": str, "filename": ""} on failure.
        """
        resolved_id = chart_id or uuid.uuid4().hex[:8]
        out_path = str(self.get_chart_preview_path(resolved_id))
        normalized_type = _normalize_chart_type(chart_type)

        logger.info(f"[ChartService] Generating chart: type={normalized_type}, id={resolved_id}")

        try:
            result_path = self._render_chart(normalized_type, chart_data, out_path, title, subtitle)

            if not result_path or not os.path.exists(result_path):
                logger.warning(f"[ChartService] Chart rendering returned empty path or file missing for type={normalized_type}")
                return {"path": "", "chart_id": resolved_id, "filename": ""}

            # "source" may be missing, None or non-string (LLM output). A bad
            # value must not turn an already-rendered chart into a failure.
            source = str(chart_data.get("source") or "").strip()
            if source:
                _add_source_overlay(result_path, source)

            filename = Path(result_path).name
            logger.info(f"[ChartService] Chart generated: id={resolved_id}, path={result_path}")
            return {"path": result_path, "chart_id": resolved_id, "filename": filename}

        except Exception as e:
            logger.error(f"[ChartService] Chart generation failed: {e}")
            return {"path": "", "chart_id": resolved_id, "filename": ""}

    def _render_chart(self, chart_type: str, chart_data: Dict[str, Any],
                      out_path: str, title: str, subtitle: str) -> str:
        """Dispatch to the renderer for ``chart_type``.

        Returns the PNG path, or ``""`` when nothing was rendered. Unknown
        types fall back to a bar chart.
        """
        if chart_type in ("bar_comparison", "bar_chart_comparison"):
            # make_bar_chart itself handles both the before/after payload and
            # the single-series {"labels", "values"} payload.
            return make_bar_chart(chart_data, out_path, title, subtitle=subtitle)

        if chart_type == "bar_horizontal":
            return make_horizontal_bar(chart_data, out_path, title)

        if chart_type == "line_trend":
            return make_line_trend(chart_data, out_path, title)

        if chart_type == "pie":
            return make_pie_chart(chart_data, out_path, title)

        if chart_type == "stacked_bar":
            return make_stacked_bar(chart_data, out_path, title)

        if chart_type in ("bullet", "bullet_points"):
            # Fall back to "labels" when "bullet_points" is missing *or* empty.
            bullet_points = chart_data.get("bullet_points") or chart_data.get("labels") or []
            if isinstance(bullet_points, str):
                # A bare string would otherwise render one bullet per character.
                bullet_points = [bullet_points]
            return make_bullet_overlay(bullet_points, out_path) if bullet_points else ""

        logger.warning(f"[ChartService] Unknown chart type: {chart_type}, falling back to bar_comparison")
        return make_bar_chart(chart_data, out_path, title, subtitle=subtitle)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def infer_chart_from_text(self, text: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Use LLM to infer chart_type and chart_data from text.

        Returns:
            {"chart_type": str, "chart_data": dict, "title": str}
            Falls back to bullet_points with key sentences extracted from text.
        """
        text = text or ""
        try:
            prompt = CHART_INFERENCE_USER_PROMPT.format(text=text[:3000])
            result = llm_text_gen(
                prompt=prompt,
                system_prompt=CHART_INFERENCE_SYSTEM_PROMPT,
                json_struct=None,
                max_tokens=2000,
                user_id=user_id,
            )
            parsed = self._parse_llm_json(self._llm_response_text(result))

            chart_type = self._validated_chart_type(parsed.get("chart_type", "bullet_points"))
            chart_data = parsed.get("chart_data")
            if not isinstance(chart_data, dict):
                chart_data = {}
            title = parsed.get("title") or ""

            logger.info(f"[ChartService] Inferred chart: type={chart_type}, title={title}")
            return {"chart_type": chart_type, "chart_data": chart_data, "title": title}

        except Exception as e:
            logger.error(f"[ChartService] Chart inference failed: {e}")
            return {
                "chart_type": "bullet_points",
                "chart_data": {"bullet_points": self._key_sentences(text) or ["No data extracted"]},
                "title": "Key Points",
            }

    async def _analyze_chart_potential(
        self,
        text: str,
        section_heading: Optional[str] = None,
        section_key_points: Optional[List[str]] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Stage 1: Analyze whether text has enough data for a chart.
        If not, suggest Exa search queries to find relevant data.

        Returns:
            {"has_data": bool, "data_description": str, "suggested_chart_type": str|None,
             "search_queries": [...], "warnings": [...]}
        """
        text = text or ""
        key_points_text = ""
        if section_key_points:
            key_points_text = "\n\nKey points:\n" + "\n".join(f"- {p}" for p in section_key_points[:5])

        prompt = CHART_ANALYSIS_USER_PROMPT.format(
            section_heading=section_heading or "Blog Section",
            key_points_section=key_points_text,
            text=text[:3000],
        )

        try:
            result = llm_text_gen(
                prompt=prompt,
                system_prompt=CHART_ANALYSIS_SYSTEM_PROMPT,
                json_struct=None,
                max_tokens=1500,
                user_id=user_id,
            )
            parsed = self._parse_llm_json(self._llm_response_text(result))

            has_data = self._as_bool(parsed.get("has_data", False))
            data_description = parsed.get("data_description", "")
            suggested_chart_type = self._validated_chart_type(parsed.get("suggested_chart_type"), default=None)
            search_queries = self._as_query_list(parsed.get("search_queries"))

            logger.info(f"[ChartService] Chart analysis: has_data={has_data}, queries={search_queries}")
            return {
                "has_data": has_data,
                "data_description": data_description,
                "suggested_chart_type": suggested_chart_type,
                "search_queries": search_queries,
                "warnings": [],
            }

        except Exception as e:
            logger.error(f"[ChartService] Chart analysis failed: {e}")
            return {
                "has_data": False,
                "data_description": f"Analysis failed: {e}",
                "suggested_chart_type": None,
                "search_queries": self._fallback_search_queries(text, section_heading),
                "warnings": [f"Chart analysis LLM call failed: {e}"],
            }

    async def _search_for_chart_data(
        self,
        queries: List[str],
        section_heading: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Stage 2: Use Exa search to find relevant statistics and data for chart creation.

        Runs at most the first 3 queries with 3 results each. Results are
        de-duplicated by URL, and each snippet is truncated to 500 characters.
        ``section_heading`` is currently unused.

        Returns:
            {"research": str, "warnings": list[str]}
        """
        queries = self._as_query_list(queries)
        if not queries:
            return {"research": "", "warnings": []}

        warnings = []
        try:
            from services.blog_writer.research.exa_provider import ExaResearchProvider

            provider = ExaResearchProvider()
            all_results = []
            selected = queries[:3]
            search_errors = 0

            for query in selected:
                try:
                    results = await provider.simple_search(
                        query=query,
                        num_results=3,
                        user_id=user_id,
                    )
                    all_results.extend(results or [])
                except Exception as e:
                    search_errors += 1
                    logger.warning(f"[ChartService] Exa search for '{query}' failed: {e}")

            if search_errors == len(selected):
                warnings.append("All Exa search queries failed — external data search unavailable. Chart may lack supporting data.")

            research_parts = []
            seen_urls = set()
            for r in all_results:
                url = r.get("url") or ""
                # De-duplicate only real URLs; treating "" as a URL would drop
                # every URL-less result after the first one.
                if url:
                    if url in seen_urls:
                        continue
                    seen_urls.add(url)
                title = r.get("title") or "Untitled"
                snippet = (r.get("text") or "")[:500]
                if snippet:
                    research_parts.append(f"- {title} ({url}): {snippet}")

            return {"research": "\n".join(research_parts), "warnings": warnings}

        except ImportError:
            msg = "Exa provider not available — skipping external data search."
            logger.warning(f"[ChartService] {msg}")
            warnings.append(msg)
            return {"research": "", "warnings": warnings}
        except Exception as e:
            msg = f"Chart data search failed: {e}"
            logger.error(f"[ChartService] {msg}")
            warnings.append(msg)
            return {"research": "", "warnings": warnings}

    async def _synthesize_chart_from_research(
        self,
        text: str,
        research: str,
        section_heading: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Stage 3: Generate chart spec from text + research data using LLM.

        Returns:
            {"chart_type": str, "chart_data": dict, "title": str}
            A non-empty LLM "source" attribution is stored as chart_data["source"].
            On failure, falls back to bullet_points of key sentences titled with
            the section heading.
        """
        text = text or ""
        try:
            prompt = CHART_SYNTHESIS_USER_PROMPT.format(
                text=text[:2000],
                research=(research or "")[:3000],
            )

            result = llm_text_gen(
                prompt=prompt,
                system_prompt=CHART_SYNTHESIS_SYSTEM_PROMPT,
                json_struct=None,
                max_tokens=2000,
                user_id=user_id,
            )
            parsed = self._parse_llm_json(self._llm_response_text(result))

            chart_type = self._validated_chart_type(parsed.get("chart_type", "bullet_points"))
            chart_data = parsed.get("chart_data")
            if not isinstance(chart_data, dict):
                chart_data = {}
            title = parsed.get("title") or ""
            source = parsed.get("source", "")

            if source:
                chart_data["source"] = source

            logger.info(f"[ChartService] Synthesized chart: type={chart_type}, title={title}")
            return {"chart_type": chart_type, "chart_data": chart_data, "title": title}

        except Exception as e:
            logger.error(f"[ChartService] Chart synthesis failed: {e}")
            return {
                "chart_type": "bullet_points",
                "chart_data": {"bullet_points": self._key_sentences(text) or ["No data available"]},
                "title": section_heading or "Key Points",
            }

    async def infer_chart_with_research(
        self,
        text: str,
        section_heading: Optional[str] = None,
        section_key_points: Optional[List[str]] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        3-stage chart inference pipeline:
        1. Analyze text for chart potential — does it have data? If not, what to search for?
        2. If no data, search Exa for relevant statistics.
        3. Synthesize chart spec from text + research data.

        Falls back to text-only inference when the search yields nothing.

        Returns:
            {"chart_type": str, "chart_data": dict, "title": str, "warnings": list[str]}
        """
        text = text or ""
        warnings: List[str] = []
        logger.info(f"[ChartService] infer_chart_with_research: heading={section_heading}, text_len={len(text)}, user={user_id}")

        # Stage 1: Analyze
        analysis = await self._analyze_chart_potential(
            text=text,
            section_heading=section_heading,
            section_key_points=section_key_points,
            user_id=user_id,
        )
        warnings.extend(analysis.get("warnings", []))

        suggested_type = analysis.get("suggested_chart_type")
        if analysis.get("has_data") and suggested_type:
            # Text has enough data — do direct inference
            logger.info("[ChartService] Text has sufficient data, using direct inference")
            result = self.infer_chart_from_text(text, user_id=user_id)
            chart_data = result.get("chart_data")
            # Promote the analyst's suggestion only when the payload is not
            # bullet-shaped. Forcing e.g. "pie" onto {"bullet_points": [...]}
            # (the inference fallback) would make the chart unrenderable.
            if (result.get("chart_type") == "bullet_points"
                    and isinstance(chart_data, dict)
                    and not chart_data.get("bullet_points")):
                result["chart_type"] = suggested_type
            result["warnings"] = warnings
            return result

        # Stage 2: Search for data
        search_queries = (
            self._as_query_list(analysis.get("search_queries"))
            or self._fallback_search_queries(text, section_heading)
        )

        logger.info(f"[ChartService] Searching Exa for chart data, queries: {search_queries}")
        search_result = await self._search_for_chart_data(
            queries=search_queries,
            section_heading=section_heading,
            user_id=user_id,
        )
        research = search_result.get("research", "")
        warnings.extend(search_result.get("warnings", []))

        if not research:
            logger.warning("[ChartService] No research data found, falling back to text-only inference")
            result = self.infer_chart_from_text(text, user_id=user_id)
            result["warnings"] = warnings
            return result

        # Stage 3: Synthesize chart from text + research
        logger.info("[ChartService] Synthesizing chart from text + research data")
        result = await self._synthesize_chart_from_research(
            text=text,
            research=research,
            section_heading=section_heading,
            user_id=user_id,
        )
        result["warnings"] = warnings
        return result

    async def generate_chart_from_text(
        self,
        text: str,
        user_id: Optional[str] = None,
        chart_id: Optional[str] = None,
        section_heading: Optional[str] = None,
        section_key_points: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        End-to-end: analyze text, optionally research data, then infer and render chart.

        Uses the 3-stage pipeline (analyze → search → synthesize) for richer charts
        with real data from Exa when the original text lacks statistics.

        Returns:
            {"path": str, "chart_id": str, "filename": str, "chart_type": str,
             "chart_data": dict, "title": str, "warnings": list[str]}
        """
        inference = await self.infer_chart_with_research(
            text=text,
            section_heading=section_heading,
            section_key_points=section_key_points,
            user_id=user_id,
        )
        result = self.generate_chart(
            chart_data=inference["chart_data"],
            chart_type=inference["chart_type"],
            title=inference["title"],
            chart_id=chart_id,
        )
        result["chart_type"] = inference["chart_type"]
        result["chart_data"] = inference["chart_data"]
        result["title"] = inference["title"]
        result["warnings"] = inference.get("warnings", [])
        return result
|
|
|
|
|
|
# Cached service instances. The key records how the output directory was
# resolved, ("dir", path), ("user", user_id) or ("default", ""), so an
# explicit output_dir can never collide with a user_id of the same text.
_chart_service_instances: Dict[tuple, ChartService] = {}


def get_chart_service(output_dir: Optional[str] = None, user_id: Optional[str] = None) -> ChartService:
    """Get or create the cached ChartService for an output directory or user.

    ``output_dir`` takes precedence over ``user_id``, mirroring
    ``ChartService.__init__``. With neither, a single shared default instance
    is returned.
    """
    if output_dir:
        cache_key = ("dir", str(output_dir))
    elif user_id:
        cache_key = ("user", str(user_id))
    else:
        cache_key = ("default", "")

    service = _chart_service_instances.get(cache_key)
    if service is None:
        service = ChartService(output_dir=output_dir, user_id=user_id)
        _chart_service_instances[cache_key] = service
    return service