Added image generation to blog writer
This commit is contained in:
@@ -7,6 +7,7 @@ content creation, SEO analysis, and publishing.
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -29,6 +30,7 @@ from models.blog_models import (
|
||||
HallucinationCheckResponse,
|
||||
)
|
||||
from services.blog_writer.blog_service import BlogWriterService
|
||||
from services.blog_writer.seo.blog_seo_recommendation_applier import BlogSEORecommendationApplier
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from models.blog_models import MediumBlogGenerateRequest
|
||||
@@ -37,6 +39,44 @@ from models.blog_models import MediumBlogGenerateRequest
|
||||
router = APIRouter(prefix="/api/blog", tags=["AI Blog Writer"])
|
||||
|
||||
service = BlogWriterService()
|
||||
recommendation_applier = BlogSEORecommendationApplier()
|
||||
# ---------------------------
|
||||
# SEO Recommendation Endpoints
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class RecommendationItem(BaseModel):
|
||||
category: str = Field(..., description="Recommendation category, e.g. Structure")
|
||||
priority: str = Field(..., description="Priority level: High | Medium | Low")
|
||||
recommendation: str = Field(..., description="Action to perform")
|
||||
impact: str = Field(..., description="Expected impact or rationale")
|
||||
|
||||
|
||||
class SEOApplyRecommendationsRequest(BaseModel):
|
||||
title: str = Field(..., description="Current blog title")
|
||||
sections: List[Dict[str, Any]] = Field(..., description="Array of sections with id, heading, content")
|
||||
outline: List[Dict[str, Any]] = Field(default_factory=list, description="Outline structure for context")
|
||||
research: Dict[str, Any] = Field(default_factory=dict, description="Research data used for the blog")
|
||||
recommendations: List[RecommendationItem] = Field(..., description="Actionable recommendations to apply")
|
||||
persona: Dict[str, Any] = Field(default_factory=dict, description="Persona settings if available")
|
||||
tone: str | None = Field(default=None, description="Desired tone override")
|
||||
audience: str | None = Field(default=None, description="Target audience override")
|
||||
|
||||
|
||||
@router.post("/seo/apply-recommendations")
|
||||
async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]:
|
||||
"""Apply actionable SEO recommendations and return updated content."""
|
||||
try:
|
||||
result = await recommendation_applier.apply_recommendations(request.dict())
|
||||
if not result.get("success"):
|
||||
raise HTTPException(status_code=500, detail=result.get("error", "Failed to apply recommendations"))
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply SEO recommendations: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@@ -92,7 +132,7 @@ async def start_outline_generation(request: BlogOutlineRequest) -> Dict[str, Any
|
||||
async def get_outline_status(task_id: str) -> Dict[str, Any]:
|
||||
"""Get the status of an outline generation operation."""
|
||||
try:
|
||||
status = task_manager.get_task_status(task_id)
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@@ -164,6 +204,50 @@ async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/content/start")
|
||||
async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Start full content generation and return a task id for polling.
|
||||
|
||||
Accepts a payload compatible with MediumBlogGenerateRequest to minimize duplication.
|
||||
"""
|
||||
try:
|
||||
# Map dict to MediumBlogGenerateRequest for reuse
|
||||
from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo
|
||||
sections = [MediumSectionOutline(**s) for s in request.get("sections", [])]
|
||||
persona = None
|
||||
if request.get("persona"):
|
||||
persona = PersonaInfo(**request.get("persona"))
|
||||
req = MediumBlogGenerateRequest(
|
||||
title=request.get("title", "Untitled Blog"),
|
||||
sections=sections,
|
||||
persona=persona,
|
||||
tone=request.get("tone"),
|
||||
audience=request.get("audience"),
|
||||
globalTargetWords=request.get("globalTargetWords", 1000),
|
||||
researchKeywords=request.get("researchKeywords") or request.get("keywords"),
|
||||
)
|
||||
task_id = task_manager.start_content_generation_task(req)
|
||||
return {"task_id": task_id, "status": "started"}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start content generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/content/status/{task_id}")
|
||||
async def content_generation_status(task_id: str) -> Dict[str, Any]:
|
||||
"""Poll status for content generation task."""
|
||||
try:
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return status
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get content generation status for {task_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/section/{section_id}/continuity")
|
||||
async def get_section_continuity(section_id: str) -> Dict[str, Any]:
|
||||
"""Fetch last computed continuity metrics for a section (if available)."""
|
||||
@@ -342,7 +426,7 @@ async def start_medium_generation(request: MediumBlogGenerateRequest):
|
||||
async def medium_generation_status(task_id: str):
|
||||
"""Poll status for medium blog generation task."""
|
||||
try:
|
||||
status = task_manager.get_task_status(task_id)
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return status
|
||||
@@ -366,7 +450,7 @@ async def start_blog_rewrite(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def rewrite_status(task_id: str):
|
||||
"""Poll status for blog rewrite task."""
|
||||
try:
|
||||
status = service.task_manager.get_task_status(task_id)
|
||||
status = await service.task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return status
|
||||
|
||||
@@ -133,6 +133,16 @@ class TaskManager:
|
||||
task_id = self.create_task("medium_generation")
|
||||
asyncio.create_task(self._run_medium_generation_task(task_id, request))
|
||||
return task_id
|
||||
|
||||
def start_content_generation_task(self, request: MediumBlogGenerateRequest) -> str:
|
||||
"""Start content generation (full blog via sections) with provider parity.
|
||||
|
||||
Internally reuses medium generator pipeline for now but tracked under
|
||||
distinct task_type 'content_generation' and same polling contract.
|
||||
"""
|
||||
task_id = self.create_task("content_generation")
|
||||
asyncio.create_task(self._run_medium_generation_task(task_id, request))
|
||||
return task_id
|
||||
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest):
|
||||
"""Background task to run research and update status with progress messages."""
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Dict, Any, List
|
||||
from ..models.story_models import FacebookStoryRequest, FacebookStoryResponse
|
||||
from .base_service import FacebookWriterBaseService
|
||||
try:
|
||||
from ...services.llm_providers.text_to_image_generation.gen_gemini_images import (
|
||||
generate_gemini_images_base64,
|
||||
)
|
||||
from ...services.llm_providers.main_image_generation import generate_image
|
||||
from base64 import b64encode
|
||||
except Exception:
|
||||
generate_gemini_images_base64 = None # type: ignore
|
||||
generate_image = None # type: ignore
|
||||
b64encode = None # type: ignore
|
||||
|
||||
|
||||
class FacebookStoryService(FacebookWriterBaseService):
|
||||
@@ -50,22 +50,29 @@ class FacebookStoryService(FacebookWriterBaseService):
|
||||
# Generate visual suggestions and engagement tips
|
||||
visual_suggestions = self._generate_visual_suggestions(actual_story_type, request.visual_options)
|
||||
engagement_tips = self._generate_engagement_tips("story")
|
||||
# Optional: generate one story image (9:16) using Gemini
|
||||
# Optional: generate one story image (9:16) using unified image generation
|
||||
images_base64: List[str] = []
|
||||
try:
|
||||
if generate_gemini_images_base64 is not None:
|
||||
if generate_image is not None and b64encode is not None:
|
||||
img_prompt = request.visual_options.background_image_prompt or (
|
||||
f"Facebook story background for {request.business_type}. "
|
||||
f"Style: {actual_tone}. Type: {actual_story_type}. Vertical mobile 9:16, high contrast, legible overlay space."
|
||||
)
|
||||
images_base64 = generate_gemini_images_base64(
|
||||
img_prompt,
|
||||
enhance_prompt=False,
|
||||
aspect_ratio="9:16",
|
||||
max_retries=2,
|
||||
initial_retry_delay=1.0,
|
||||
) or []
|
||||
except Exception:
|
||||
# Generate image using unified system (9:16 aspect ratio = 1080x1920)
|
||||
result = generate_image(
|
||||
prompt=img_prompt,
|
||||
options={
|
||||
"provider": "gemini", # Facebook stories use Gemini
|
||||
"width": 1080,
|
||||
"height": 1920,
|
||||
}
|
||||
)
|
||||
if result and result.image_bytes:
|
||||
# Convert bytes to base64
|
||||
image_b64 = b64encode(result.image_bytes).decode('utf-8')
|
||||
images_base64 = [image_b64]
|
||||
except Exception as e:
|
||||
# Log error but continue without images
|
||||
images_base64 = []
|
||||
|
||||
return FacebookStoryResponse(
|
||||
|
||||
217
backend/api/images.py
Normal file
217
backend/api/images.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
logger = get_service_logger("api.images")
|
||||
|
||||
|
||||
class ImageGenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
provider: Optional[str] = Field(None, pattern="^(gemini|huggingface|stability)$")
|
||||
model: Optional[str] = None
|
||||
width: Optional[int] = Field(default=1024, ge=64, le=2048)
|
||||
height: Optional[int] = Field(default=1024, ge=64, le=2048)
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class ImageGenerateResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
try:
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(2): # simple single retry
|
||||
try:
|
||||
result = generate_image(
|
||||
prompt=req.prompt,
|
||||
options={
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"provider": req.provider,
|
||||
"model": req.model,
|
||||
"width": req.width,
|
||||
"height": req.height,
|
||||
"guidance_scale": req.guidance_scale,
|
||||
"steps": req.steps,
|
||||
"seed": req.seed,
|
||||
},
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
seed=result.seed,
|
||||
)
|
||||
except Exception as inner:
|
||||
last_error = inner
|
||||
logger.error(f"Image generation attempt {attempt+1} failed: {inner}")
|
||||
# On first failure, try provider auto-remap by clearing provider to let facade decide
|
||||
if attempt == 0 and req.provider:
|
||||
req.provider = None
|
||||
continue
|
||||
break
|
||||
raise last_error or RuntimeError("Unknown image generation error")
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation failed: {e}")
|
||||
# Provide a clean, actionable message to the client
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Image generation service is temporarily unavailable or the connection was reset. Please try again."
|
||||
)
|
||||
|
||||
|
||||
class PromptSuggestion(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
overlay_text: Optional[str] = None
|
||||
|
||||
|
||||
class ImagePromptSuggestRequest(BaseModel):
|
||||
provider: Optional[str] = Field(None, pattern="^(gemini|huggingface|stability)$")
|
||||
title: Optional[str] = None
|
||||
section: Optional[Dict[str, Any]] = None
|
||||
research: Optional[Dict[str, Any]] = None
|
||||
persona: Optional[Dict[str, Any]] = None
|
||||
include_overlay: Optional[bool] = True
|
||||
|
||||
|
||||
class ImagePromptSuggestResponse(BaseModel):
|
||||
suggestions: list[PromptSuggestion]
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
|
||||
try:
|
||||
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
|
||||
section = req.section or {}
|
||||
title = (req.title or section.get("heading") or "").strip()
|
||||
subheads = section.get("subheadings", []) or []
|
||||
key_points = section.get("key_points", []) or []
|
||||
keywords = section.get("keywords", []) or []
|
||||
if not keywords and req.research:
|
||||
keywords = (
|
||||
req.research.get("keywords", {}).get("primary_keywords")
|
||||
or req.research.get("keywords", {}).get("primary")
|
||||
or []
|
||||
)
|
||||
|
||||
persona = req.persona or {}
|
||||
audience = persona.get("audience", "content creators and digital marketers")
|
||||
industry = persona.get("industry", req.research.get("domain") if req.research else "your industry")
|
||||
tone = persona.get("tone", "professional, trustworthy")
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"suggestions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {"type": "string"},
|
||||
"negative_prompt": {"type": "string"},
|
||||
"width": {"type": "number"},
|
||||
"height": {"type": "number"},
|
||||
"overlay_text": {"type": "string"},
|
||||
},
|
||||
"required": ["prompt"]
|
||||
},
|
||||
"minItems": 3,
|
||||
"maxItems": 5
|
||||
}
|
||||
},
|
||||
"required": ["suggestions"]
|
||||
}
|
||||
|
||||
system = (
|
||||
"You are an expert image prompt engineer for text-to-image models. "
|
||||
"Given blog section context, craft 3-5 hyper-personalized prompts optimized for the specified provider. "
|
||||
"Return STRICT JSON matching the provided schema, no extra text."
|
||||
)
|
||||
|
||||
provider_guidance = {
|
||||
"huggingface": "Photorealistic Flux 1 Krea Dev; include camera/lighting cues (e.g., 50mm, f/2.8, rim light).",
|
||||
"gemini": "Editorial, brand-safe, crisp edges, balanced lighting; avoid artifacts.",
|
||||
"stability": "SDXL coherent details, sharp focus, cinematic contrast; readable text if present."
|
||||
}.get(provider, "")
|
||||
|
||||
best_practices = (
|
||||
"Best Practices: one clear focal subject; clean, uncluttered background; rule-of-thirds or center-weighted composition; "
|
||||
"text-safe margins if overlay text is included; neutral lighting if unsure; realistic skin tones; avoid busy patterns; "
|
||||
"no brand logos or watermarks; no copyrighted characters; avoid low-res, blur, noise, banding, oversaturation, over-sharpening; "
|
||||
"ensure hands and text are coherent if present; prefer 1024px+ on shortest side for quality."
|
||||
)
|
||||
|
||||
# Harvest a few concise facts from research if available
|
||||
facts: list[str] = []
|
||||
try:
|
||||
if req.research:
|
||||
# try common shapes used in research service
|
||||
top_stats = req.research.get("key_facts") or req.research.get("highlights") or []
|
||||
if isinstance(top_stats, list):
|
||||
facts = [str(x) for x in top_stats[:3]]
|
||||
elif isinstance(top_stats, dict):
|
||||
facts = [f"{k}: {v}" for k, v in list(top_stats.items())[:3]]
|
||||
except Exception:
|
||||
facts = []
|
||||
|
||||
facts_line = ", ".join(facts) if facts else ""
|
||||
|
||||
overlay_hint = "Include an on-image short title or fact if it improves communication; ensure clean, high-contrast safe area for text." if (req.include_overlay is None or req.include_overlay) else "Do not include on-image text."
|
||||
|
||||
prompt = f"""
|
||||
Provider: {provider}
|
||||
Title: {title}
|
||||
Subheadings: {', '.join(subheads[:5])}
|
||||
Key Points: {', '.join(key_points[:5])}
|
||||
Keywords: {', '.join([str(k) for k in keywords[:8]])}
|
||||
Research Facts: {facts_line}
|
||||
Audience: {audience}
|
||||
Industry: {industry}
|
||||
Tone: {tone}
|
||||
|
||||
Craft prompts that visually reflect this exact section (not generic blog topic). {provider_guidance}
|
||||
{best_practices}
|
||||
{overlay_hint}
|
||||
Include a suitable negative_prompt where helpful. Suggest width/height when relevant (e.g., 1024x1024 or 1920x1080).
|
||||
If including on-image text, return it in overlay_text (short: <= 8 words).
|
||||
"""
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
|
||||
data = raw if isinstance(raw, dict) else {}
|
||||
suggestions = data.get("suggestions") or []
|
||||
# basic fallback if provider returns string
|
||||
if not suggestions and isinstance(raw, str):
|
||||
suggestions = [{"prompt": raw}]
|
||||
|
||||
return ImagePromptSuggestResponse(suggestions=[PromptSuggestion(**s) for s in suggestions])
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt suggestion failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user