Added image generation to blog writer
This commit is contained in:
@@ -7,6 +7,7 @@ content creation, SEO analysis, and publishing.
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import (
|
||||
@@ -29,6 +30,7 @@ from models.blog_models import (
|
||||
HallucinationCheckResponse,
|
||||
)
|
||||
from services.blog_writer.blog_service import BlogWriterService
|
||||
from services.blog_writer.seo.blog_seo_recommendation_applier import BlogSEORecommendationApplier
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from models.blog_models import MediumBlogGenerateRequest
|
||||
@@ -37,6 +39,44 @@ from models.blog_models import MediumBlogGenerateRequest
|
||||
router = APIRouter(prefix="/api/blog", tags=["AI Blog Writer"])
|
||||
|
||||
service = BlogWriterService()
|
||||
recommendation_applier = BlogSEORecommendationApplier()
|
||||
# ---------------------------
|
||||
# SEO Recommendation Endpoints
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class RecommendationItem(BaseModel):
    """One actionable SEO recommendation supplied by the client.

    All four fields are required strings.
    """

    category: str = Field(
        default=...,
        description="Recommendation category, e.g. Structure",
    )
    priority: str = Field(
        default=...,
        description="Priority level: High | Medium | Low",
    )
    recommendation: str = Field(
        default=...,
        description="Action to perform",
    )
    impact: str = Field(
        default=...,
        description="Expected impact or rationale",
    )
|
||||
|
||||
|
||||
class SEOApplyRecommendationsRequest(BaseModel):
    """Request body for the ``/seo/apply-recommendations`` endpoint.

    ``title``, ``sections`` and ``recommendations`` are required; the
    remaining context fields default to empty containers or ``None``.
    """

    # --- required -----------------------------------------------------
    title: str = Field(
        default=...,
        description="Current blog title",
    )
    sections: List[Dict[str, Any]] = Field(
        default=...,
        description="Array of sections with id, heading, content",
    )

    # --- optional context (fresh empty container per instance) --------
    outline: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Outline structure for context",
    )
    research: Dict[str, Any] = Field(
        default_factory=dict,
        description="Research data used for the blog",
    )

    # --- required -----------------------------------------------------
    recommendations: List[RecommendationItem] = Field(
        default=...,
        description="Actionable recommendations to apply",
    )

    # --- optional overrides -------------------------------------------
    persona: Dict[str, Any] = Field(
        default_factory=dict,
        description="Persona settings if available",
    )
    tone: str | None = Field(None, description="Desired tone override")
    audience: str | None = Field(None, description="Target audience override")
|
||||
|
||||
|
||||
@router.post("/seo/apply-recommendations")
async def apply_seo_recommendations(request: SEOApplyRecommendationsRequest) -> Dict[str, Any]:
    """Run the SEO recommendation applier on the submitted blog content.

    Returns the applier's result dict when it reports ``success``.

    Raises:
        HTTPException: 500 when the applier reports failure (detail is the
            applier's ``error`` value, or a generic message) or when any
            unexpected exception occurs (detail is the exception text).
    """
    try:
        payload = request.dict()
        outcome = await recommendation_applier.apply_recommendations(payload)
        if outcome.get("success"):
            return outcome
        failure_detail = outcome.get("error", "Failed to apply recommendations")
        raise HTTPException(status_code=500, detail=failure_detail)
    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as exc:
        logger.error(f"Failed to apply SEO recommendations: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@@ -92,7 +132,7 @@ async def start_outline_generation(request: BlogOutlineRequest) -> Dict[str, Any
|
||||
async def get_outline_status(task_id: str) -> Dict[str, Any]:
|
||||
"""Get the status of an outline generation operation."""
|
||||
try:
|
||||
status = task_manager.get_task_status(task_id)
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@@ -164,6 +204,50 @@ async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/content/start")
async def start_content_generation(request: Dict[str, Any]) -> Dict[str, Any]:
    """Start full content generation and return a task id for polling.

    Accepts a payload compatible with MediumBlogGenerateRequest to minimize
    duplication. Recognised keys: ``title``, ``sections``, ``persona``,
    ``tone``, ``audience``, ``globalTargetWords`` and ``researchKeywords``
    (falling back to ``keywords``).

    Returns:
        ``{"task_id": <id>, "status": "started"}``.

    Raises:
        HTTPException: 422 when the payload cannot be mapped onto the request
            models (client error); 500 for any other failure.
    """
    try:
        # Map dict to MediumBlogGenerateRequest for reuse
        from pydantic import ValidationError
        from models.blog_models import MediumBlogGenerateRequest, MediumSectionOutline, PersonaInfo

        try:
            # `or []` also covers an explicit `"sections": null` in the payload.
            sections = [MediumSectionOutline(**s) for s in request.get("sections") or []]
            persona_data = request.get("persona")
            persona = PersonaInfo(**persona_data) if persona_data else None
            req = MediumBlogGenerateRequest(
                title=request.get("title", "Untitled Blog"),
                sections=sections,
                persona=persona,
                tone=request.get("tone"),
                audience=request.get("audience"),
                globalTargetWords=request.get("globalTargetWords", 1000),
                researchKeywords=request.get("researchKeywords") or request.get("keywords"),
            )
        except (ValidationError, TypeError) as exc:
            # A malformed body is the caller's mistake, not a server fault:
            # TypeError comes from `**` on a non-mapping, ValidationError
            # from the models rejecting field values.
            logger.warning(f"Invalid content generation payload: {exc}")
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        task_id = task_manager.start_content_generation_task(req)
        return {"task_id": task_id, "status": "started"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to start content generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/content/status/{task_id}")
async def content_generation_status(task_id: str) -> Dict[str, Any]:
    """Return the current status of a content generation task.

    Raises:
        HTTPException: 404 when the task manager has no status for
            ``task_id``; 500 when the lookup itself fails.
    """
    try:
        snapshot = await task_manager.get_task_status(task_id)
        if snapshot is not None:
            return snapshot
        raise HTTPException(status_code=404, detail="Task not found")
    except HTTPException:
        # Keep the 404 intact instead of wrapping it in a 500.
        raise
    except Exception as exc:
        logger.error(f"Failed to get content generation status for {task_id}: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/section/{section_id}/continuity")
|
||||
async def get_section_continuity(section_id: str) -> Dict[str, Any]:
|
||||
"""Fetch last computed continuity metrics for a section (if available)."""
|
||||
@@ -342,7 +426,7 @@ async def start_medium_generation(request: MediumBlogGenerateRequest):
|
||||
async def medium_generation_status(task_id: str):
|
||||
"""Poll status for medium blog generation task."""
|
||||
try:
|
||||
status = task_manager.get_task_status(task_id)
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return status
|
||||
@@ -366,7 +450,7 @@ async def start_blog_rewrite(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def rewrite_status(task_id: str):
|
||||
"""Poll status for blog rewrite task."""
|
||||
try:
|
||||
status = service.task_manager.get_task_status(task_id)
|
||||
status = await service.task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return status
|
||||
|
||||
@@ -133,6 +133,16 @@ class TaskManager:
|
||||
task_id = self.create_task("medium_generation")
|
||||
asyncio.create_task(self._run_medium_generation_task(task_id, request))
|
||||
return task_id
|
||||
|
||||
def start_content_generation_task(self, request: "MediumBlogGenerateRequest") -> str:
    """Start content generation (full blog via sections) with provider parity.

    Internally reuses the medium generator pipeline for now, but the task is
    tracked under the distinct task_type 'content_generation' and follows the
    same polling contract.

    Args:
        request: Generation request, forwarded unchanged to
            ``_run_medium_generation_task``.

    Returns:
        The id of the newly created task, to be used for status polling.

    Note:
        Must be called while an event loop is running, because the work is
        scheduled with ``asyncio.create_task``.
    """
    task_id = self.create_task("content_generation")
    background = asyncio.create_task(self._run_medium_generation_task(task_id, request))

    # The event loop only keeps weak references to tasks, so a task whose
    # handle is dropped can be garbage-collected before it finishes. Hold a
    # strong reference until completion, then release it automatically.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(background)
    background.add_done_callback(pending.discard)

    return task_id
|
||||
|
||||
async def _run_research_task(self, task_id: str, request: BlogResearchRequest):
|
||||
"""Background task to run research and update status with progress messages."""
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Dict, Any, List
|
||||
from ..models.story_models import FacebookStoryRequest, FacebookStoryResponse
|
||||
from .base_service import FacebookWriterBaseService
|
||||
try:
|
||||
from ...services.llm_providers.text_to_image_generation.gen_gemini_images import (
|
||||
generate_gemini_images_base64,
|
||||
)
|
||||
from ...services.llm_providers.main_image_generation import generate_image
|
||||
from base64 import b64encode
|
||||
except Exception:
|
||||
generate_gemini_images_base64 = None # type: ignore
|
||||
generate_image = None # type: ignore
|
||||
b64encode = None # type: ignore
|
||||
|
||||
|
||||
class FacebookStoryService(FacebookWriterBaseService):
|
||||
@@ -50,22 +50,29 @@ class FacebookStoryService(FacebookWriterBaseService):
|
||||
# Generate visual suggestions and engagement tips
|
||||
visual_suggestions = self._generate_visual_suggestions(actual_story_type, request.visual_options)
|
||||
engagement_tips = self._generate_engagement_tips("story")
|
||||
# Optional: generate one story image (9:16) using Gemini
|
||||
# Optional: generate one story image (9:16) using unified image generation
|
||||
images_base64: List[str] = []
|
||||
try:
|
||||
if generate_gemini_images_base64 is not None:
|
||||
if generate_image is not None and b64encode is not None:
|
||||
img_prompt = request.visual_options.background_image_prompt or (
|
||||
f"Facebook story background for {request.business_type}. "
|
||||
f"Style: {actual_tone}. Type: {actual_story_type}. Vertical mobile 9:16, high contrast, legible overlay space."
|
||||
)
|
||||
images_base64 = generate_gemini_images_base64(
|
||||
img_prompt,
|
||||
enhance_prompt=False,
|
||||
aspect_ratio="9:16",
|
||||
max_retries=2,
|
||||
initial_retry_delay=1.0,
|
||||
) or []
|
||||
except Exception:
|
||||
# Generate image using unified system (9:16 aspect ratio = 1080x1920)
|
||||
result = generate_image(
|
||||
prompt=img_prompt,
|
||||
options={
|
||||
"provider": "gemini", # Facebook stories use Gemini
|
||||
"width": 1080,
|
||||
"height": 1920,
|
||||
}
|
||||
)
|
||||
if result and result.image_bytes:
|
||||
# Convert bytes to base64
|
||||
image_b64 = b64encode(result.image_bytes).decode('utf-8')
|
||||
images_base64 = [image_b64]
|
||||
except Exception as e:
|
||||
# Log error but continue without images
|
||||
images_base64 = []
|
||||
|
||||
return FacebookStoryResponse(
|
||||
|
||||
217
backend/api/images.py
Normal file
217
backend/api/images.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
logger = get_service_logger("api.images")
|
||||
|
||||
|
||||
class ImageGenerateRequest(BaseModel):
    """Request body for ``POST /api/images/generate``.

    Only ``prompt`` is required; every other field is optional.
    """

    # Text description of the image to generate.
    prompt: str
    negative_prompt: Optional[str] = None
    # When given, must be exactly one of: gemini, huggingface, stability.
    provider: Optional[str] = Field(default=None, pattern="^(gemini|huggingface|stability)$")
    model: Optional[str] = None
    # Output size; each side is validated to 64..2048 and defaults to 1024.
    width: Optional[int] = Field(1024, ge=64, le=2048)
    height: Optional[int] = Field(1024, ge=64, le=2048)
    # Generator tuning knobs, not validated beyond their type.
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None
|
||||
|
||||
|
||||
class ImageGenerateResponse(BaseModel):
    """Response body for ``POST /api/images/generate``."""

    success: bool = Field(default=True)
    # Generated image bytes, base64-encoded as UTF-8 text.
    image_base64: str
    width: int
    height: int
    provider: str
    model: Optional[str] = Field(default=None)
    seed: Optional[int] = Field(default=None)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
    """Generate one image and return it base64-encoded.

    At most two attempts are made. A second attempt happens only when the
    first one fails and the request named a provider; in that case the
    provider is cleared on ``req`` so the generation facade can choose one.

    Raises:
        HTTPException: 500 with a generic, client-friendly message whenever
            generation ultimately fails (the real error is only logged).
    """

    def _render_once() -> ImageGenerateResponse:
        # Reads `req` at call time, so a cleared provider is picked up on retry.
        options = {
            "negative_prompt": req.negative_prompt,
            "provider": req.provider,
            "model": req.model,
            "width": req.width,
            "height": req.height,
            "guidance_scale": req.guidance_scale,
            "steps": req.steps,
            "seed": req.seed,
        }
        rendered = generate_image(prompt=req.prompt, options=options)
        encoded = base64.b64encode(rendered.image_bytes).decode("utf-8")
        return ImageGenerateResponse(
            image_base64=encoded,
            width=rendered.width,
            height=rendered.height,
            provider=rendered.provider,
            model=rendered.model,
            seed=rendered.seed,
        )

    try:
        failure: Optional[Exception] = None
        attempt = 0
        while attempt < 2:  # simple single retry
            try:
                return _render_once()
            except Exception as inner:
                failure = inner
                logger.error(f"Image generation attempt {attempt+1} failed: {inner}")
                # Only the first failure of an explicitly requested provider
                # earns a retry; clear the provider to let the facade decide.
                can_remap = attempt == 0 and req.provider
                if not can_remap:
                    break
                req.provider = None
            attempt += 1
        raise failure or RuntimeError("Unknown image generation error")
    except Exception as e:
        logger.error(f"Image generation failed: {e}")
        # Provide a clean, actionable message to the client
        raise HTTPException(
            status_code=500,
            detail="Image generation service is temporarily unavailable or the connection was reset. Please try again."
        )
|
||||
|
||||
|
||||
class PromptSuggestion(BaseModel):
    """A single suggested image prompt; only ``prompt`` is required."""

    prompt: str
    negative_prompt: Optional[str] = Field(default=None)
    width: Optional[int] = Field(default=None)
    height: Optional[int] = Field(default=None)
    # Short text intended to be placed on the image, if any.
    overlay_text: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class ImagePromptSuggestRequest(BaseModel):
    """Request body for ``POST /api/images/suggest-prompts``.

    Every field is optional.
    """

    # When given, must be exactly one of: gemini, huggingface, stability.
    provider: Optional[str] = Field(default=None, pattern="^(gemini|huggingface|stability)$")
    title: Optional[str] = Field(default=None)
    # Free-form context dictionaries; their schema is not validated here.
    section: Optional[Dict[str, Any]] = Field(default=None)
    research: Optional[Dict[str, Any]] = Field(default=None)
    persona: Optional[Dict[str, Any]] = Field(default=None)
    include_overlay: Optional[bool] = Field(default=True)
|
||||
|
||||
|
||||
class ImagePromptSuggestResponse(BaseModel):
    """Response body for ``POST /api/images/suggest-prompts``."""

    suggestions: list[PromptSuggestion] = Field(default=...)
|
||||
|
||||
|
||||
def _join_terms(values: Any, limit: int) -> str:
    """Render at most *limit* entries of *values* as a comma-separated string.

    A lone string counts as one entry; anything that is not a string, list
    or tuple renders as an empty string. Entries are passed through ``str``
    so non-string items cannot break the join.
    """
    if isinstance(values, str):
        values = [values]
    elif not isinstance(values, (list, tuple)):
        values = []
    return ", ".join(str(item) for item in values[:limit])


@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
    """Ask the text LLM for image prompts tailored to one blog section.

    Builds a context prompt from the section, research and persona data in
    ``req`` and requests 3-5 suggestions as JSON matching a fixed schema.

    The provider defaults to "gemini" when the ``GPT_PROVIDER`` environment
    variable starts with "gemini" (case-insensitive), otherwise "huggingface".

    Returns:
        The suggestions that could be parsed. Malformed entries are skipped
        with a warning. If the LLM returns plain text instead of a dict, that
        text becomes a single suggestion.

    Raises:
        HTTPException: 500 with the error text on any unexpected failure.
    """
    try:
        env_default = "gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface"
        provider = (req.provider or env_default).lower()

        section = req.section or {}
        research = req.research or {}
        persona = req.persona or {}

        title = str(req.title or section.get("heading") or "").strip()
        subheads_line = _join_terms(section.get("subheadings"), 5)
        key_points_line = _join_terms(section.get("key_points"), 5)

        keywords = section.get("keywords") or []
        if not keywords:
            # Fall back to research keywords, but only if they have the
            # expected mapping shape; anything else is ignored.
            research_keywords = research.get("keywords")
            if isinstance(research_keywords, dict):
                keywords = (
                    research_keywords.get("primary_keywords")
                    or research_keywords.get("primary")
                    or []
                )
        keywords_line = _join_terms(keywords, 8)

        # `or` (not a .get default) so explicit null/empty values also fall
        # back, keeping a literal "None" out of the prompt.
        audience = persona.get("audience") or "content creators and digital marketers"
        industry = persona.get("industry") or research.get("domain") or "your industry"
        tone = persona.get("tone") or "professional, trustworthy"

        schema = {
            "type": "object",
            "properties": {
                "suggestions": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "prompt": {"type": "string"},
                            "negative_prompt": {"type": "string"},
                            "width": {"type": "number"},
                            "height": {"type": "number"},
                            "overlay_text": {"type": "string"},
                        },
                        "required": ["prompt"]
                    },
                    "minItems": 3,
                    "maxItems": 5
                }
            },
            "required": ["suggestions"]
        }

        system = (
            "You are an expert image prompt engineer for text-to-image models. "
            "Given blog section context, craft 3-5 hyper-personalized prompts optimized for the specified provider. "
            "Return STRICT JSON matching the provided schema, no extra text."
        )

        provider_guidance = {
            "huggingface": "Photorealistic Flux 1 Krea Dev; include camera/lighting cues (e.g., 50mm, f/2.8, rim light).",
            "gemini": "Editorial, brand-safe, crisp edges, balanced lighting; avoid artifacts.",
            "stability": "SDXL coherent details, sharp focus, cinematic contrast; readable text if present."
        }.get(provider, "")

        best_practices = (
            "Best Practices: one clear focal subject; clean, uncluttered background; rule-of-thirds or center-weighted composition; "
            "text-safe margins if overlay text is included; neutral lighting if unsure; realistic skin tones; avoid busy patterns; "
            "no brand logos or watermarks; no copyrighted characters; avoid low-res, blur, noise, banding, oversaturation, over-sharpening; "
            "ensure hands and text are coherent if present; prefer 1024px+ on shortest side for quality."
        )

        # Harvest a few concise facts from research if available
        # (tries the "key_facts" key first, then "highlights").
        top_stats = research.get("key_facts") or research.get("highlights") or []
        if isinstance(top_stats, list):
            facts = [str(x) for x in top_stats[:3]]
        elif isinstance(top_stats, dict):
            facts = [f"{k}: {v}" for k, v in list(top_stats.items())[:3]]
        else:
            facts = []
        facts_line = ", ".join(facts) if facts else ""

        # An unset (None) include_overlay is treated as "overlay allowed".
        if req.include_overlay is None or req.include_overlay:
            overlay_hint = "Include an on-image short title or fact if it improves communication; ensure clean, high-contrast safe area for text."
        else:
            overlay_hint = "Do not include on-image text."

        prompt = f"""
        Provider: {provider}
        Title: {title}
        Subheadings: {subheads_line}
        Key Points: {key_points_line}
        Keywords: {keywords_line}
        Research Facts: {facts_line}
        Audience: {audience}
        Industry: {industry}
        Tone: {tone}

        Craft prompts that visually reflect this exact section (not generic blog topic). {provider_guidance}
        {best_practices}
        {overlay_hint}
        Include a suitable negative_prompt where helpful. Suggest width/height when relevant (e.g., 1024x1024 or 1920x1080).
        If including on-image text, return it in overlay_text (short: <= 8 words).
        """

        raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
        data = raw if isinstance(raw, dict) else {}
        candidates = data.get("suggestions") or []
        # basic fallback if provider returns string
        if not candidates and isinstance(raw, str) and raw.strip():
            candidates = [{"prompt": raw}]
        if not isinstance(candidates, list):
            candidates = []

        suggestions = []
        for item in candidates:
            if isinstance(item, str):
                item = {"prompt": item}
            if not isinstance(item, dict):
                continue
            try:
                suggestions.append(PromptSuggestion(**item))
            except (TypeError, ValueError) as item_error:
                # One malformed entry should not discard the usable ones.
                logger.warning(f"Skipping malformed prompt suggestion: {item_error}")

        return ImagePromptSuggestResponse(suggestions=suggestions)
    except Exception as e:
        logger.error(f"Prompt suggestion failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -42,6 +42,7 @@ from routers.linkedin import router as linkedin_router
|
||||
# Import LinkedIn image generation router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
from api.brainstorm import router as brainstorm_router
|
||||
from api.images import router as images_router
|
||||
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
@@ -279,6 +280,7 @@ async def batch_analyze_urls_endpoint(urls: list[str]):
|
||||
# Include platform analytics router
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
app.include_router(images_router)
|
||||
|
||||
# Setup frontend serving using modular utilities
|
||||
frontend_serving.setup_frontend_serving()
|
||||
|
||||
@@ -186,6 +186,8 @@ class BlogSEOMetadataRequest(BaseModel):
|
||||
title: Optional[str] = None
|
||||
keywords: List[str] = []
|
||||
research_data: Optional[Dict[str, Any]] = None
|
||||
outline: Optional[List[Dict[str, Any]]] = None # Add outline structure
|
||||
seo_analysis: Optional[Dict[str, Any]] = None # Add SEO analysis results
|
||||
|
||||
|
||||
class BlogSEOMetadataResponse(BaseModel):
|
||||
|
||||
@@ -21,10 +21,9 @@ httpx>=0.27.2,<0.28.0
|
||||
|
||||
# AI/ML dependencies
|
||||
openai>=1.3.0
|
||||
anthropic>=0.7.0
|
||||
mistralai>=0.0.12
|
||||
google-genai>=1.0.0
|
||||
google-ai-generativelanguage>=0.6.18,<0.7.0
|
||||
|
||||
|
||||
google-api-python-client>=2.100.0
|
||||
google-auth>=2.23.0
|
||||
google-auth-oauthlib>=1.0.0
|
||||
@@ -53,6 +52,7 @@ nltk>=3.8.0
|
||||
|
||||
# Image and audio processing for Stability AI
|
||||
Pillow>=10.0.0
|
||||
huggingface_hub>=0.24.0
|
||||
scikit-learn>=1.3.0
|
||||
|
||||
# Testing dependencies
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""
|
||||
EnhancedContentGenerator - thin orchestrator combining URL selection and Gemini provider.
|
||||
EnhancedContentGenerator - thin orchestrator for section generation.
|
||||
|
||||
Provides Draft vs Polished modes and optional URL Context usage.
|
||||
Provider parity:
|
||||
- Uses main_text_generation.llm_text_gen to respect GPT_PROVIDER (Gemini/HF)
|
||||
- No direct provider coupling here; Google grounding remains in research only
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from .source_url_manager import SourceURLManager
|
||||
from .context_memory import ContextMemory
|
||||
from .transition_generator import TransitionGenerator
|
||||
@@ -15,24 +17,37 @@ from .flow_analyzer import FlowAnalyzer
|
||||
|
||||
class EnhancedContentGenerator:
|
||||
def __init__(self):
|
||||
self.provider = GeminiGroundedProvider()
|
||||
self.url_manager = SourceURLManager()
|
||||
self.memory = ContextMemory(max_entries=12)
|
||||
self.transitioner = TransitionGenerator()
|
||||
self.flow = FlowAnalyzer()
|
||||
|
||||
async def generate_section(self, section: Any, research: Any, mode: str = "polished") -> Dict[str, Any]:
|
||||
urls = self.url_manager.pick_relevant_urls(section, research)
|
||||
prev_summary = self.memory.build_previous_sections_summary(limit=2)
|
||||
prompt = self._build_prompt(section, research, prev_summary)
|
||||
result = await self.provider.generate_grounded_content(
|
||||
prompt=prompt,
|
||||
content_type="linkedin_article",
|
||||
temperature=0.6 if mode == "polished" else 0.8,
|
||||
max_tokens=2048,
|
||||
urls=urls,
|
||||
mode=mode,
|
||||
)
|
||||
urls = self.url_manager.pick_relevant_urls(section, research)
|
||||
prompt = self._build_prompt(section, research, prev_summary, urls)
|
||||
# Provider-agnostic text generation (respect GPT_PROVIDER & circuit-breaker)
|
||||
content_text: str = ""
|
||||
try:
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=None,
|
||||
system_prompt=None,
|
||||
)
|
||||
if isinstance(ai_resp, dict) and ai_resp.get("text"):
|
||||
content_text = ai_resp.get("text", "")
|
||||
elif isinstance(ai_resp, str):
|
||||
content_text = ai_resp
|
||||
else:
|
||||
# Fallback best-effort extraction
|
||||
content_text = str(ai_resp or "")
|
||||
except Exception as e:
|
||||
content_text = ""
|
||||
|
||||
result = {
|
||||
"content": content_text,
|
||||
"sources": [{"title": u.get("title", ""), "url": u.get("url", "")} for u in urls] if urls else [],
|
||||
}
|
||||
# Generate transition and compute intelligent flow metrics
|
||||
previous_text = prev_summary
|
||||
current_text = result.get("content", "")
|
||||
@@ -56,19 +71,22 @@ class EnhancedContentGenerator:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _build_prompt(self, section: Any, research: Any, prev_summary: str) -> str:
|
||||
def _build_prompt(self, section: Any, research: Any, prev_summary: str, urls: list) -> str:
|
||||
heading = getattr(section, 'heading', 'Section')
|
||||
key_points = getattr(section, 'key_points', [])
|
||||
keywords = getattr(section, 'keywords', [])
|
||||
target_words = getattr(section, 'target_words', 300)
|
||||
url_block = "\n".join([f"- {u.get('title','')} ({u.get('url','')})" for u in urls]) if urls else "(no specific URLs provided)"
|
||||
|
||||
return (
|
||||
f"You are writing the blog section '{heading}'.\n\n"
|
||||
f"Context summary: {prev_summary}\n"
|
||||
f"Key points: {', '.join(key_points)}\n"
|
||||
f"Keywords: {', '.join(keywords)}\n"
|
||||
f"Target word count: {target_words}.\n"
|
||||
"Use only factual info from provided sources; add short transition, then body."
|
||||
f"Context summary (previous sections): {prev_summary}\n\n"
|
||||
f"Authoring requirements:\n"
|
||||
f"- Target word count: ~{target_words}\n"
|
||||
f"- Use the following key points: {', '.join(key_points)}\n"
|
||||
f"- Include these keywords naturally: {', '.join(keywords)}\n"
|
||||
f"- Cite insights from these sources when relevant (do not output raw URLs):\n{url_block}\n\n"
|
||||
"Write engaging, well-structured markdown with clear paragraphs (2-4 sentences each) separated by double line breaks."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from models.blog_models import (
|
||||
MediumGeneratedSection,
|
||||
ResearchSource,
|
||||
)
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.cache.persistent_content_cache import persistent_content_cache
|
||||
|
||||
|
||||
@@ -176,11 +176,9 @@ class MediumBlogGenerator:
|
||||
f"Sections to write:\n{json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
|
||||
ai_resp = gemini_structured_json_response(
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
temperature=0.2,
|
||||
max_tokens=8192,
|
||||
json_struct=schema,
|
||||
system_prompt=system,
|
||||
)
|
||||
|
||||
|
||||
@@ -275,11 +275,17 @@ class BlogWriterService:
|
||||
# Initialize metadata generator
|
||||
metadata_generator = BlogSEOMetadataGenerator()
|
||||
|
||||
# Generate comprehensive metadata
|
||||
# Extract outline and seo_analysis from request
|
||||
outline = request.outline if hasattr(request, 'outline') else None
|
||||
seo_analysis = request.seo_analysis if hasattr(request, 'seo_analysis') else None
|
||||
|
||||
# Generate comprehensive metadata with full context
|
||||
metadata_results = await metadata_generator.generate_comprehensive_metadata(
|
||||
blog_content=request.content,
|
||||
blog_title=request.title or "Untitled Blog Post",
|
||||
research_data=request.research_data or {}
|
||||
research_data=request.research_data or {},
|
||||
outline=outline,
|
||||
seo_analysis=seo_analysis
|
||||
)
|
||||
|
||||
# Convert to BlogSEOMetadataResponse format
|
||||
|
||||
@@ -40,7 +40,7 @@ Return JSON format:
|
||||
}}"""
|
||||
|
||||
try:
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
optimization_schema = {
|
||||
"type": "object",
|
||||
@@ -64,11 +64,10 @@ Return JSON format:
|
||||
"propertyOrdering": ["outline"]
|
||||
}
|
||||
|
||||
optimized_data = gemini_structured_json_response(
|
||||
optimized_data = llm_text_gen(
|
||||
prompt=optimization_prompt,
|
||||
schema=optimization_schema,
|
||||
temperature=0.3,
|
||||
max_tokens=6000 # Match main outline generator
|
||||
json_struct=optimization_schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
# Handle the new schema format with "outline" wrapper
|
||||
|
||||
@@ -20,7 +20,7 @@ class ResponseProcessor:
|
||||
|
||||
async def generate_with_retry(self, prompt: str, schema: Dict[str, Any], task_id: str = None) -> Dict[str, Any]:
|
||||
"""Generate outline with retry logic for API failures."""
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from api.blog_writer.task_manager import task_manager
|
||||
|
||||
max_retries = 2 # Conservative retry for expensive API calls
|
||||
@@ -29,17 +29,16 @@ class ResponseProcessor:
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
if task_id:
|
||||
await task_manager.update_progress(task_id, f"🤖 Calling Gemini API for outline generation (attempt {attempt + 1}/{max_retries + 1})...")
|
||||
await task_manager.update_progress(task_id, f"🤖 Calling AI API for outline generation (attempt {attempt + 1}/{max_retries + 1})...")
|
||||
|
||||
outline_data = gemini_structured_json_response(
|
||||
outline_data = llm_text_gen(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
temperature=0.3,
|
||||
max_tokens=6000 # Increased further to avoid truncation
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
# Log response for debugging
|
||||
logger.info(f"Gemini response received: {type(outline_data)}")
|
||||
logger.info(f"AI response received: {type(outline_data)}")
|
||||
|
||||
# Check for errors in the response
|
||||
if isinstance(outline_data, dict) and 'error' in outline_data:
|
||||
@@ -47,17 +46,17 @@ class ResponseProcessor:
|
||||
if "503" in error_msg and "overloaded" in error_msg and attempt < max_retries:
|
||||
if task_id:
|
||||
await task_manager.update_progress(task_id, f"⚠️ AI service overloaded, retrying in {retry_delay} seconds...")
|
||||
logger.warning(f"Gemini API overloaded, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1})")
|
||||
logger.warning(f"AI API overloaded, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1})")
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
elif "No valid structured response content found" in error_msg and attempt < max_retries:
|
||||
if task_id:
|
||||
await task_manager.update_progress(task_id, f"⚠️ Invalid response format, retrying in {retry_delay} seconds...")
|
||||
logger.warning(f"Gemini response parsing failed, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1})")
|
||||
logger.warning(f"AI response parsing failed, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1})")
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Gemini structured response error: {outline_data['error']}")
|
||||
logger.error(f"AI structured response error: {outline_data['error']}")
|
||||
raise ValueError(f"AI outline generation failed: {outline_data['error']}")
|
||||
|
||||
# Validate required fields
|
||||
@@ -69,7 +68,7 @@ class ResponseProcessor:
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Invalid outline structure in Gemini response")
|
||||
raise ValueError("Invalid outline structure in AI response")
|
||||
|
||||
# If we get here, the response is valid
|
||||
return outline_data
|
||||
@@ -79,7 +78,7 @@ class ResponseProcessor:
|
||||
if ("503" in error_str or "overloaded" in error_str) and attempt < max_retries:
|
||||
if task_id:
|
||||
await task_manager.update_progress(task_id, f"⚠️ AI service error, retrying in {retry_delay} seconds...")
|
||||
logger.warning(f"Gemini API error, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1}): {error_str}")
|
||||
logger.warning(f"AI API error, retrying in {retry_delay} seconds (attempt {attempt + 1}/{max_retries + 1}): {error_str}")
|
||||
await asyncio.sleep(retry_delay)
|
||||
continue
|
||||
else:
|
||||
|
||||
@@ -44,7 +44,7 @@ class SectionEnhancer:
|
||||
"""
|
||||
|
||||
try:
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
enhancement_schema = {
|
||||
"type": "object",
|
||||
@@ -58,11 +58,10 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
enhanced_data = gemini_structured_json_response(
|
||||
enhanced_data = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
schema=enhancement_schema,
|
||||
temperature=0.4,
|
||||
max_tokens=1000
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
|
||||
@@ -559,14 +559,11 @@ Analyze the mapping and provide your recommendations.
|
||||
AI validation response
|
||||
"""
|
||||
try:
|
||||
from services.llm_providers.gemini_provider import gemini_text_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
response = gemini_text_response(
|
||||
response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
n=1,
|
||||
max_tokens=2000,
|
||||
json_struct=None,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
|
||||
@@ -10,13 +10,13 @@ import re
|
||||
import textstat
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from services.seo_analyzer import (
|
||||
ContentAnalyzer, KeywordAnalyzer,
|
||||
URLStructureAnalyzer, AIInsightGenerator
|
||||
)
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
class BlogContentSEOAnalyzer:
|
||||
@@ -24,11 +24,13 @@ class BlogContentSEOAnalyzer:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the blog content SEO analyzer"""
|
||||
# Service-specific logger (no global reconfiguration)
|
||||
global logger
|
||||
logger = get_service_logger("blog_content_seo_analyzer")
|
||||
self.content_analyzer = ContentAnalyzer()
|
||||
self.keyword_analyzer = KeywordAnalyzer()
|
||||
self.url_analyzer = URLStructureAnalyzer()
|
||||
self.ai_insights = AIInsightGenerator()
|
||||
self.gemini_provider = gemini_structured_json_response
|
||||
|
||||
logger.info("BlogContentSEOAnalyzer initialized")
|
||||
|
||||
@@ -598,7 +600,7 @@ class BlogContentSEOAnalyzer:
|
||||
return recommendations
|
||||
|
||||
async def _run_ai_analysis(self, blog_content: str, keywords_data: Dict[str, Any], non_ai_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Run single AI analysis for structured insights"""
|
||||
"""Run single AI analysis for structured insights (provider-agnostic)"""
|
||||
try:
|
||||
# Prepare context for AI analysis
|
||||
context = {
|
||||
@@ -610,7 +612,6 @@ class BlogContentSEOAnalyzer:
|
||||
# Create AI prompt for structured analysis
|
||||
prompt = self._create_ai_analysis_prompt(context)
|
||||
|
||||
# Get structured response from Gemini
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -653,18 +654,17 @@ class BlogContentSEOAnalyzer:
|
||||
}
|
||||
}
|
||||
|
||||
ai_response = self.gemini_provider(
|
||||
# Provider-agnostic structured response respecting GPT_PROVIDER
|
||||
ai_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
temperature=0.2,
|
||||
max_tokens=8192
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
return ai_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI analysis failed: {e}")
|
||||
# Fail fast - don't return mock data
|
||||
raise e
|
||||
|
||||
def _create_ai_analysis_prompt(self, context: Dict[str, Any]) -> str:
|
||||
|
||||
@@ -12,7 +12,7 @@ from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
class BlogSEOMetadataGenerator:
|
||||
@@ -20,14 +20,15 @@ class BlogSEOMetadataGenerator:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the metadata generator"""
|
||||
self.gemini_provider = gemini_structured_json_response
|
||||
logger.info("BlogSEOMetadataGenerator initialized")
|
||||
|
||||
async def generate_comprehensive_metadata(
|
||||
self,
|
||||
blog_content: str,
|
||||
blog_title: str,
|
||||
research_data: Dict[str, Any]
|
||||
research_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate comprehensive SEO metadata using maximum 2 AI calls
|
||||
@@ -36,6 +37,8 @@ class BlogSEOMetadataGenerator:
|
||||
blog_content: The blog content to analyze
|
||||
blog_title: The blog title
|
||||
research_data: Research data containing keywords and insights
|
||||
outline: Outline structure with sections and headings
|
||||
seo_analysis: SEO analysis results from previous phase
|
||||
|
||||
Returns:
|
||||
Comprehensive metadata including all SEO elements
|
||||
@@ -49,11 +52,15 @@ class BlogSEOMetadataGenerator:
|
||||
|
||||
# Call 1: Generate core SEO metadata (parallel with Call 2)
|
||||
logger.info("Generating core SEO metadata")
|
||||
core_metadata_task = self._generate_core_metadata(blog_content, blog_title, keywords_data)
|
||||
core_metadata_task = self._generate_core_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
)
|
||||
|
||||
# Call 2: Generate social media and structured data (parallel with Call 1)
|
||||
logger.info("Generating social media and structured data")
|
||||
social_metadata_task = self._generate_social_metadata(blog_content, blog_title, keywords_data)
|
||||
social_metadata_task = self._generate_social_metadata(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
)
|
||||
|
||||
# Wait for both calls to complete
|
||||
core_metadata, social_metadata = await asyncio.gather(
|
||||
@@ -105,12 +112,16 @@ class BlogSEOMetadataGenerator:
|
||||
self,
|
||||
blog_content: str,
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any]
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate core SEO metadata (Call 1)"""
|
||||
try:
|
||||
# Create comprehensive prompt for core metadata
|
||||
prompt = self._create_core_metadata_prompt(blog_content, blog_title, keywords_data)
|
||||
prompt = self._create_core_metadata_prompt(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
)
|
||||
|
||||
# Define simplified structured schema for core metadata
|
||||
schema = {
|
||||
@@ -155,17 +166,26 @@ class BlogSEOMetadataGenerator:
|
||||
"required": ["seo_title", "meta_description", "url_slug", "blog_tags", "blog_categories", "social_hashtags", "reading_time", "focus_keyword"]
|
||||
}
|
||||
|
||||
# Get structured response from Gemini
|
||||
ai_response = self.gemini_provider(
|
||||
prompt,
|
||||
schema,
|
||||
temperature=0.3,
|
||||
max_tokens=2048
|
||||
# Get structured response using provider-agnostic llm_text_gen
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
ai_response = ai_response_raw
|
||||
if isinstance(ai_response_raw, str):
|
||||
try:
|
||||
import json
|
||||
ai_response = json.loads(ai_response_raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse JSON response: {ai_response_raw[:200]}...")
|
||||
ai_response = None
|
||||
|
||||
# Check if we got a valid response
|
||||
if not ai_response or not isinstance(ai_response, dict):
|
||||
logger.error("Core metadata generation failed: Invalid response from Gemini")
|
||||
logger.error("Core metadata generation failed: Invalid response from LLM")
|
||||
# Return fallback response
|
||||
primary_keywords = ', '.join(keywords_data.get('primary_keywords', ['content']))
|
||||
word_count = len(blog_content.split())
|
||||
@@ -193,12 +213,16 @@ class BlogSEOMetadataGenerator:
|
||||
self,
|
||||
blog_content: str,
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any]
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate social media and structured data (Call 2)"""
|
||||
try:
|
||||
# Create comprehensive prompt for social metadata
|
||||
prompt = self._create_social_metadata_prompt(blog_content, blog_title, keywords_data)
|
||||
prompt = self._create_social_metadata_prompt(
|
||||
blog_content, blog_title, keywords_data, outline, seo_analysis
|
||||
)
|
||||
|
||||
# Define simplified structured schema for social metadata
|
||||
schema = {
|
||||
@@ -246,17 +270,26 @@ class BlogSEOMetadataGenerator:
|
||||
"required": ["open_graph", "twitter_card", "json_ld_schema"]
|
||||
}
|
||||
|
||||
# Get structured response from Gemini
|
||||
ai_response = self.gemini_provider(
|
||||
prompt,
|
||||
schema,
|
||||
temperature=0.3,
|
||||
max_tokens=2048
|
||||
# Get structured response using provider-agnostic llm_text_gen
|
||||
ai_response_raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
json_struct=schema,
|
||||
system_prompt=None
|
||||
)
|
||||
|
||||
# Handle response: llm_text_gen may return dict (from structured JSON) or str (needs parsing)
|
||||
ai_response = ai_response_raw
|
||||
if isinstance(ai_response_raw, str):
|
||||
try:
|
||||
import json
|
||||
ai_response = json.loads(ai_response_raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse JSON response: {ai_response_raw[:200]}...")
|
||||
ai_response = None
|
||||
|
||||
# Check if we got a valid response
|
||||
if not ai_response or not isinstance(ai_response, dict) or not ai_response.get('open_graph') or not ai_response.get('twitter_card') or not ai_response.get('json_ld_schema'):
|
||||
logger.error("Social metadata generation failed: Invalid or empty response from Gemini")
|
||||
logger.error("Social metadata generation failed: Invalid or empty response from LLM")
|
||||
# Return fallback response
|
||||
return {
|
||||
'open_graph': {
|
||||
@@ -301,11 +334,47 @@ class BlogSEOMetadataGenerator:
|
||||
logger.error(f"Social metadata generation failed: {e}")
|
||||
raise e
|
||||
|
||||
def _extract_content_highlights(self, blog_content: str, max_length: int = 2500) -> str:
|
||||
"""Extract key sections from blog content for prompt context"""
|
||||
try:
|
||||
lines = blog_content.split('\n')
|
||||
|
||||
# Get first paragraph (introduction)
|
||||
intro = ""
|
||||
for line in lines[:20]:
|
||||
if line.strip() and not line.strip().startswith('#'):
|
||||
intro += line.strip() + " "
|
||||
if len(intro) > 300:
|
||||
break
|
||||
|
||||
# Get section headings
|
||||
headings = [line.strip() for line in lines if line.strip().startswith('##')][:6]
|
||||
|
||||
# Get conclusion if available
|
||||
conclusion = ""
|
||||
for line in reversed(lines[-20:]):
|
||||
if line.strip() and not line.strip().startswith('#'):
|
||||
conclusion = line.strip() + " " + conclusion
|
||||
if len(conclusion) > 300:
|
||||
break
|
||||
|
||||
highlights = f"INTRODUCTION: {intro[:300]}...\n\n"
|
||||
highlights += f"SECTION HEADINGS: {' | '.join([h.replace('##', '').strip() for h in headings])}\n\n"
|
||||
if conclusion:
|
||||
highlights += f"CONCLUSION: {conclusion[:300]}..."
|
||||
|
||||
return highlights[:max_length]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract content highlights: {e}")
|
||||
return blog_content[:2000] + "..."
|
||||
|
||||
def _create_core_metadata_prompt(
|
||||
self,
|
||||
blog_content: str,
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any]
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""Create high-quality prompt for core metadata generation"""
|
||||
|
||||
@@ -314,30 +383,106 @@ class BlogSEOMetadataGenerator:
|
||||
search_intent = keywords_data.get('search_intent', 'informational')
|
||||
target_audience = keywords_data.get('target_audience', 'general')
|
||||
industry = keywords_data.get('industry', 'general')
|
||||
|
||||
# Calculate word count for reading time estimation
|
||||
word_count = len(blog_content.split())
|
||||
|
||||
# Extract outline structure
|
||||
outline_context = ""
|
||||
if outline:
|
||||
headings = [s.get('heading', '') for s in outline if s.get('heading')]
|
||||
outline_context = f"""
|
||||
OUTLINE STRUCTURE:
|
||||
- Total sections: {len(outline)}
|
||||
- Section headings: {', '.join(headings[:8])}
|
||||
- Content hierarchy: Well-structured with {len(outline)} main sections
|
||||
"""
|
||||
|
||||
# Extract SEO analysis insights
|
||||
seo_context = ""
|
||||
if seo_analysis:
|
||||
overall_score = seo_analysis.get('overall_score', seo_analysis.get('seo_score', 0))
|
||||
category_scores = seo_analysis.get('category_scores', {})
|
||||
applied_recs = seo_analysis.get('applied_recommendations', [])
|
||||
|
||||
seo_context = f"""
|
||||
SEO ANALYSIS RESULTS:
|
||||
- Overall SEO Score: {overall_score}/100
|
||||
- Category Scores: Structure {category_scores.get('structure', category_scores.get('Structure', 0))}, Keywords {category_scores.get('keywords', category_scores.get('Keywords', 0))}, Readability {category_scores.get('readability', category_scores.get('Readability', 0))}
|
||||
- Applied Recommendations: {len(applied_recs)} SEO optimizations have been applied
|
||||
- Content Quality: Optimized for search engines with keyword focus
|
||||
"""
|
||||
|
||||
# Get more content context (key sections instead of just first 1000 chars)
|
||||
content_preview = self._extract_content_highlights(blog_content)
|
||||
|
||||
prompt = f"""
|
||||
Generate SEO metadata for this blog post.
|
||||
Generate comprehensive, personalized SEO metadata for this blog post.
|
||||
|
||||
BLOG TITLE: {blog_title}
|
||||
BLOG CONTENT: {blog_content[:1000]}...
|
||||
=== BLOG CONTENT CONTEXT ===
|
||||
TITLE: {blog_title}
|
||||
CONTENT PREVIEW (key sections): {content_preview}
|
||||
WORD COUNT: {word_count} words
|
||||
READING TIME ESTIMATE: {max(1, word_count // 200)} minutes
|
||||
|
||||
{outline_context}
|
||||
|
||||
=== KEYWORD & AUDIENCE DATA ===
|
||||
PRIMARY KEYWORDS: {primary_keywords}
|
||||
SEMANTIC KEYWORDS: {semantic_keywords}
|
||||
WORD COUNT: {word_count}
|
||||
SEARCH INTENT: {search_intent}
|
||||
TARGET AUDIENCE: {target_audience}
|
||||
INDUSTRY: {industry}
|
||||
|
||||
Generate:
|
||||
1. SEO TITLE (50-60 characters) - include primary keyword
|
||||
2. META DESCRIPTION (150-160 characters) - include CTA
|
||||
3. URL SLUG (lowercase, hyphens, 3-5 words)
|
||||
4. BLOG TAGS (5-8 relevant tags)
|
||||
5. BLOG CATEGORIES (2-3 categories)
|
||||
6. SOCIAL HASHTAGS (5-10 hashtags with #)
|
||||
7. READING TIME (calculate from {word_count} words)
|
||||
8. FOCUS KEYWORD (primary keyword for SEO)
|
||||
{seo_context}
|
||||
|
||||
Make it compelling and SEO-optimized.
|
||||
=== METADATA GENERATION REQUIREMENTS ===
|
||||
1. SEO TITLE (50-60 characters, must include primary keyword):
|
||||
- Front-load primary keyword
|
||||
- Make it compelling and click-worthy
|
||||
- Include power words if appropriate for {target_audience} audience
|
||||
- Optimized for {search_intent} search intent
|
||||
|
||||
2. META DESCRIPTION (150-160 characters, must include CTA):
|
||||
- Include primary keyword naturally in first 120 chars
|
||||
- Add compelling call-to-action (e.g., "Learn more", "Discover how", "Get started")
|
||||
- Highlight value proposition for {target_audience} audience
|
||||
- Use {industry} industry-specific terminology where relevant
|
||||
|
||||
3. URL SLUG (lowercase, hyphens, 3-5 words):
|
||||
- Include primary keyword
|
||||
- Remove stop words
|
||||
- Keep it concise and readable
|
||||
|
||||
4. BLOG TAGS (5-8 relevant tags):
|
||||
- Mix of primary, semantic, and long-tail keywords
|
||||
- Industry-specific tags for {industry}
|
||||
- Audience-relevant tags for {target_audience}
|
||||
|
||||
5. BLOG CATEGORIES (2-3 categories):
|
||||
- Based on content structure and {industry} industry standards
|
||||
- Reflect main themes from outline sections
|
||||
|
||||
6. SOCIAL HASHTAGS (5-10 hashtags with #):
|
||||
- Include primary keyword as hashtag
|
||||
- Industry-specific hashtags for {industry}
|
||||
- Trending/relevant hashtags for {target_audience}
|
||||
|
||||
7. READING TIME (calculate from {word_count} words):
|
||||
- Average reading speed: 200 words/minute
|
||||
- Round to nearest minute
|
||||
|
||||
8. FOCUS KEYWORD (primary keyword for SEO):
|
||||
- Select the most important primary keyword
|
||||
- Should match the main topic and search intent
|
||||
|
||||
=== QUALITY REQUIREMENTS ===
|
||||
- All metadata must be unique, not generic
|
||||
- Incorporate insights from SEO analysis if provided
|
||||
- Reflect the actual content structure from outline
|
||||
- Use language appropriate for {target_audience} audience
|
||||
- Optimize for {search_intent} search intent
|
||||
- Make descriptions compelling and action-oriented
|
||||
|
||||
Generate metadata that is personalized, compelling, and SEO-optimized.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
@@ -345,7 +490,9 @@ Make it compelling and SEO-optimized.
|
||||
self,
|
||||
blog_content: str,
|
||||
blog_title: str,
|
||||
keywords_data: Dict[str, Any]
|
||||
keywords_data: Dict[str, Any],
|
||||
outline: Optional[List[Dict[str, Any]]] = None,
|
||||
seo_analysis: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""Create high-quality prompt for social metadata generation"""
|
||||
|
||||
@@ -353,49 +500,68 @@ Make it compelling and SEO-optimized.
|
||||
search_intent = keywords_data.get('search_intent', 'informational')
|
||||
target_audience = keywords_data.get('target_audience', 'general')
|
||||
industry = keywords_data.get('industry', 'general')
|
||||
|
||||
current_date = datetime.now().isoformat()
|
||||
|
||||
# Add outline and SEO context similar to core metadata prompt
|
||||
outline_context = ""
|
||||
if outline:
|
||||
headings = [s.get('heading', '') for s in outline if s.get('heading')]
|
||||
outline_context = f"\nOUTLINE SECTIONS: {', '.join(headings[:6])}\n"
|
||||
|
||||
seo_context = ""
|
||||
if seo_analysis:
|
||||
overall_score = seo_analysis.get('overall_score', seo_analysis.get('seo_score', 0))
|
||||
seo_context = f"\nSEO SCORE: {overall_score}/100 (optimized content)\n"
|
||||
|
||||
content_preview = self._extract_content_highlights(blog_content, 1500)
|
||||
|
||||
prompt = f"""
|
||||
Generate social media metadata for this blog post.
|
||||
Generate engaging social media metadata for this blog post.
|
||||
|
||||
BLOG TITLE: {blog_title}
|
||||
BLOG CONTENT: {blog_content[:800]}...
|
||||
PRIMARY KEYWORDS: {primary_keywords}
|
||||
=== CONTENT ===
|
||||
TITLE: {blog_title}
|
||||
CONTENT: {content_preview}
|
||||
{outline_context}
|
||||
{seo_context}
|
||||
KEYWORDS: {primary_keywords}
|
||||
TARGET AUDIENCE: {target_audience}
|
||||
INDUSTRY: {industry}
|
||||
CURRENT DATE: {current_date}
|
||||
|
||||
Generate:
|
||||
=== GENERATION REQUIREMENTS ===
|
||||
|
||||
1. OPEN GRAPH (Facebook/LinkedIn):
|
||||
- title: 60 chars max
|
||||
- description: 160 chars max
|
||||
- image: image URL
|
||||
- title: 60 chars max, include primary keyword, compelling for {target_audience}
|
||||
- description: 160 chars max, include CTA and value proposition
|
||||
- image: Suggest an appropriate image URL (placeholder if none available)
|
||||
- type: "article"
|
||||
- site_name: site name
|
||||
- url: canonical URL
|
||||
- site_name: Use appropriate site name for {industry} industry
|
||||
- url: Generate canonical URL structure
|
||||
|
||||
2. TWITTER CARD:
|
||||
- card: "summary_large_image"
|
||||
- title: 70 chars max
|
||||
- description: 200 chars max with hashtags
|
||||
- image: image URL
|
||||
- site: @sitename
|
||||
- creator: @author
|
||||
- title: 70 chars max, optimized for Twitter audience
|
||||
- description: 200 chars max with relevant hashtags inline
|
||||
- image: Match Open Graph image
|
||||
- site: @yourwebsite (placeholder, user should update)
|
||||
- creator: @author (placeholder, user should update)
|
||||
|
||||
3. JSON-LD SCHEMA:
|
||||
3. JSON-LD SCHEMA (Article):
|
||||
- @context: "https://schema.org"
|
||||
- @type: "Article"
|
||||
- headline: article title
|
||||
- description: article description
|
||||
- author: {{"@type": "Person", "name": "Author Name"}}
|
||||
- publisher: {{"@type": "Organization", "name": "Site Name"}}
|
||||
- datePublished: ISO date
|
||||
- dateModified: ISO date
|
||||
- mainEntityOfPage: canonical URL
|
||||
- keywords: array of keywords
|
||||
- wordCount: word count
|
||||
- headline: Article title (optimized)
|
||||
- description: Article description (150-200 chars)
|
||||
- author: {{"@type": "Person", "name": "Author Name"}} (placeholder)
|
||||
- publisher: {{"@type": "Organization", "name": "Site Name", "logo": {{"@type": "ImageObject", "url": "logo-url"}}}}
|
||||
- datePublished: {current_date}
|
||||
- dateModified: {current_date}
|
||||
- mainEntityOfPage: {{"@type": "WebPage", "@id": "canonical-url"}}
|
||||
- keywords: Array of primary and semantic keywords
|
||||
- wordCount: {len(blog_content.split())}
|
||||
- articleSection: Primary category based on content
|
||||
- inLanguage: "en-US"
|
||||
|
||||
Make it engaging and SEO-optimized.
|
||||
Make it engaging, personalized for {target_audience}, and optimized for {industry} industry.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Blog SEO Recommendation Applier
|
||||
|
||||
Applies actionable SEO recommendations to existing blog content using the
|
||||
provider-agnostic `llm_text_gen` dispatcher. Ensures GPT_PROVIDER parity.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
# Module-level logger for this service, obtained via get_service_logger.
logger = get_service_logger("blog_seo_recommendation_applier")
|
||||
|
||||
|
||||
class BlogSEORecommendationApplier:
|
||||
"""Apply actionable SEO recommendations to blog content."""
|
||||
|
||||
def __init__(self):
|
||||
logger.debug("Initialized BlogSEORecommendationApplier")
|
||||
|
||||
async def apply_recommendations(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply recommendations and return updated content."""
|
||||
|
||||
title = payload.get("title", "Untitled Blog")
|
||||
sections: List[Dict[str, Any]] = payload.get("sections", [])
|
||||
outline = payload.get("outline", [])
|
||||
research = payload.get("research", {})
|
||||
recommendations = payload.get("recommendations", [])
|
||||
persona = payload.get("persona", {})
|
||||
tone = payload.get("tone")
|
||||
audience = payload.get("audience")
|
||||
|
||||
if not sections:
|
||||
return {"success": False, "error": "No sections provided for recommendation application"}
|
||||
|
||||
if not recommendations:
|
||||
logger.warning("apply_recommendations called without recommendations")
|
||||
return {"success": True, "title": title, "sections": sections, "applied": []}
|
||||
|
||||
prompt = self._build_prompt(
|
||||
title=title,
|
||||
sections=sections,
|
||||
outline=outline,
|
||||
research=research,
|
||||
recommendations=recommendations,
|
||||
persona=persona,
|
||||
tone=tone,
|
||||
audience=audience,
|
||||
)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"sections": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string"},
|
||||
"heading": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"notes": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["id", "heading", "content"],
|
||||
},
|
||||
},
|
||||
"applied_recommendations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {"type": "string"},
|
||||
"summary": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["sections"],
|
||||
}
|
||||
|
||||
logger.info("Applying SEO recommendations via llm_text_gen")
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
llm_text_gen,
|
||||
prompt,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
error_msg = result.get("error", "Unknown error") if result else "No response from text generator"
|
||||
logger.error(f"SEO recommendation application failed: {error_msg}")
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
raw_sections = result.get("sections", []) or []
|
||||
normalized_sections: List[Dict[str, Any]] = []
|
||||
|
||||
# Build lookup table from updated sections using their identifiers
|
||||
updated_map: Dict[str, Dict[str, Any]] = {}
|
||||
for updated in raw_sections:
|
||||
section_id = str(
|
||||
updated.get("id")
|
||||
or updated.get("section_id")
|
||||
or updated.get("heading")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
if not section_id:
|
||||
continue
|
||||
|
||||
heading = (
|
||||
updated.get("heading")
|
||||
or updated.get("title")
|
||||
or section_id
|
||||
)
|
||||
|
||||
content_text = updated.get("content", "")
|
||||
if isinstance(content_text, list):
|
||||
content_text = "\n\n".join(str(p).strip() for p in content_text if p)
|
||||
|
||||
updated_map[section_id] = {
|
||||
"id": section_id,
|
||||
"heading": heading,
|
||||
"content": str(content_text).strip(),
|
||||
"notes": updated.get("notes", []),
|
||||
}
|
||||
|
||||
if not updated_map and raw_sections:
|
||||
logger.warning("Updated sections missing identifiers; falling back to positional mapping")
|
||||
|
||||
for index, original in enumerate(sections):
|
||||
fallback_id = str(
|
||||
original.get("id")
|
||||
or original.get("section_id")
|
||||
or f"section_{index + 1}"
|
||||
).strip()
|
||||
|
||||
mapped = updated_map.get(fallback_id)
|
||||
|
||||
if not mapped and raw_sections:
|
||||
# Fall back to positional match if identifier lookup failed
|
||||
candidate = raw_sections[index] if index < len(raw_sections) else {}
|
||||
heading = (
|
||||
candidate.get("heading")
|
||||
or candidate.get("title")
|
||||
or original.get("heading")
|
||||
or original.get("title")
|
||||
or f"Section {index + 1}"
|
||||
)
|
||||
content_text = candidate.get("content") or original.get("content", "")
|
||||
if isinstance(content_text, list):
|
||||
content_text = "\n\n".join(str(p).strip() for p in content_text if p)
|
||||
mapped = {
|
||||
"id": fallback_id,
|
||||
"heading": heading,
|
||||
"content": str(content_text).strip(),
|
||||
"notes": candidate.get("notes", []),
|
||||
}
|
||||
|
||||
if not mapped:
|
||||
# Fallback to original content if nothing else available
|
||||
mapped = {
|
||||
"id": fallback_id,
|
||||
"heading": original.get("heading") or original.get("title") or f"Section {index + 1}",
|
||||
"content": str(original.get("content", "")).strip(),
|
||||
"notes": original.get("notes", []),
|
||||
}
|
||||
|
||||
normalized_sections.append(mapped)
|
||||
|
||||
applied = result.get("applied_recommendations", [])
|
||||
|
||||
logger.info("SEO recommendations applied successfully")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"title": result.get("title", title),
|
||||
"sections": normalized_sections,
|
||||
"applied": applied,
|
||||
}
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
*,
|
||||
title: str,
|
||||
sections: List[Dict[str, Any]],
|
||||
outline: List[Dict[str, Any]],
|
||||
research: Dict[str, Any],
|
||||
recommendations: List[Dict[str, Any]],
|
||||
persona: Dict[str, Any],
|
||||
tone: str | None,
|
||||
audience: str | None,
|
||||
) -> str:
|
||||
"""Construct prompt for applying recommendations."""
|
||||
|
||||
sections_str = []
|
||||
for section in sections:
|
||||
sections_str.append(
|
||||
f"ID: {section.get('id', 'section')}, Heading: {section.get('heading', 'Untitled')}\n"
|
||||
f"Current Content:\n{section.get('content', '')}\n"
|
||||
)
|
||||
|
||||
outline_str = "\n".join(
|
||||
[
|
||||
f"- {item.get('heading', 'Section')} (Target words: {item.get('target_words', 'N/A')})"
|
||||
for item in outline
|
||||
]
|
||||
)
|
||||
|
||||
research_summary = research.get("keyword_analysis", {}) if research else {}
|
||||
primary_keywords = ", ".join(research_summary.get("primary", [])[:10]) or "None"
|
||||
|
||||
recommendations_str = []
|
||||
for rec in recommendations:
|
||||
recommendations_str.append(
|
||||
f"Category: {rec.get('category', 'General')} | Priority: {rec.get('priority', 'Medium')}\n"
|
||||
f"Recommendation: {rec.get('recommendation', '')}\n"
|
||||
f"Impact: {rec.get('impact', '')}\n"
|
||||
)
|
||||
|
||||
persona_str = (
|
||||
f"Persona: {persona}\n"
|
||||
if persona
|
||||
else "Persona: (not provided)\n"
|
||||
)
|
||||
|
||||
style_guidance = []
|
||||
if tone:
|
||||
style_guidance.append(f"Desired tone: {tone}")
|
||||
if audience:
|
||||
style_guidance.append(f"Target audience: {audience}")
|
||||
style_str = "\n".join(style_guidance) if style_guidance else "Maintain current tone and audience alignment."
|
||||
|
||||
prompt = f"""
|
||||
You are an expert SEO content strategist. Update the blog content to apply the actionable recommendations.
|
||||
|
||||
Current Title: {title}
|
||||
|
||||
Primary Keywords (for context): {primary_keywords}
|
||||
|
||||
Outline Overview:
|
||||
{outline_str or 'No outline supplied'}
|
||||
|
||||
Existing Sections:
|
||||
{''.join(sections_str)}
|
||||
|
||||
Actionable Recommendations to Apply:
|
||||
{''.join(recommendations_str)}
|
||||
|
||||
{persona_str}
|
||||
{style_str}
|
||||
|
||||
Instructions:
|
||||
1. Carefully apply the recommendations while preserving factual accuracy and research alignment.
|
||||
2. Keep section identifiers (IDs) unchanged so the frontend can map updates correctly.
|
||||
3. Improve clarity, flow, and SEO optimization per the guidance.
|
||||
4. Return updated sections in the requested JSON format.
|
||||
5. Provide a short summary of which recommendations were addressed.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# Names exported by `from <module> import *`: only the applier class.
__all__ = ["BlogSEORecommendationApplier"]
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from io import BytesIO
|
||||
|
||||
# Import existing infrastructure
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
from ...llm_providers.text_to_image_generation.gen_gemini_images import generate_gemini_image
|
||||
from ...llm_providers.main_image_generation import generate_image
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -270,41 +270,57 @@ class LinkedInImageGenerator:
|
||||
|
||||
async def _generate_with_gemini(self, prompt: str, aspect_ratio: str) -> Dict[str, Any]:
    """
    Generate an image using the unified image generation infrastructure.

    Args:
        prompt: Enhanced prompt for image generation
        aspect_ratio: Desired aspect ratio ("1:1", "16:9", "4:3" or "9:16");
            any other value falls back to 1024x1024

    Returns:
        On success: a dict with 'success' True, 'image_data' (raw bytes),
        'image_path' (always None), 'width', 'height', 'provider' and 'model'.
        On failure: a dict with 'success' False and an 'error' message.
    """
    try:
        # Map aspect ratio to dimensions (LinkedIn-optimized)
        aspect_map = {
            "1:1": (1024, 1024),
            "16:9": (1920, 1080),
            "4:3": (1366, 1024),
            "9:16": (1080, 1920),  # Portrait for stories
        }
        width, height = aspect_map.get(aspect_ratio, (1024, 1024))

        # The provider is pinned explicitly, so the GPT_PROVIDER-based
        # selection inside generate_image() is not consulted here.
        result = generate_image(
            prompt=prompt,
            options={
                "provider": "gemini",  # LinkedIn uses Gemini by default
                # Not every instance defines `model`; fall back to None.
                "model": getattr(self, "model", None),
                "width": width,
                "height": height,
            },
        )

        if result and result.image_bytes:
            return {
                'success': True,
                'image_data': result.image_bytes,
                'image_path': None,  # No file path, using bytes directly
                'width': result.width,
                'height': result.height,
                'provider': result.provider,
                'model': result.model,
            }

        return {
            'success': False,
            'error': 'Image generation returned no result',
        }

    except Exception as e:
        # Boundary handler: callers expect an error dict, never an exception.
        logger.error(f"Error in image generation: {str(e)}")
        return {
            'success': False,
            'error': f"Image generation failed: {str(e)}",
        }
|
||||
|
||||
async def _process_generated_image(
|
||||
|
||||
15
backend/services/llm_providers/image_generation/__init__.py
Normal file
15
backend/services/llm_providers/image_generation/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
]
|
||||
|
||||
|
||||
37
backend/services/llm_providers/image_generation/base.py
Normal file
37
backend/services/llm_providers/image_generation/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Protocol
|
||||
|
||||
|
||||
@dataclass
class ImageGenerationOptions:
    """Provider-agnostic request parameters for a single image generation call.

    Only ``prompt`` is required; every other field has a default so callers
    can supply just the settings they care about.
    """

    # Text description of the image to generate.
    prompt: str
    # Text describing what the image should avoid, when the provider supports it.
    negative_prompt: Optional[str] = None
    # Requested output size in pixels.
    width: int = 1024
    height: int = 1024
    # Sampling controls; None leaves the choice to the provider.
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None
    # Provider-specific model identifier; None selects the provider default.
    model: Optional[str] = None
    # Free-form bag for settings that have no dedicated field.
    extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class ImageGenerationResult:
    """Outcome of one image generation call, independent of the provider used."""

    # Encoded image payload as returned (or re-encoded) by the provider.
    image_bytes: bytes
    # Dimensions of the produced image in pixels.
    width: int
    height: int
    # Short name of the backend that produced the image.
    provider: str
    # Model identifier reported by the provider, when known.
    model: Optional[str] = None
    # Seed associated with the generation, when one was supplied.
    seed: Optional[int] = None
    # Optional provider-specific details.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
    """Structural interface that every image generation backend implements."""

    def generate(self, options: "ImageGenerationOptions") -> "ImageGenerationResult":
        """Produce one image for ``options`` and return it as a result object."""
        ...
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.gemini")
|
||||
|
||||
|
||||
class GeminiImageProvider(ImageGenerationProvider):
    """Google Gemini/Imagen backed image generation (placeholder).

    NOTE: Implementation should call the actual Gemini Images API used in the codebase.
    Here we keep a minimal interface and expect the underlying client to be wired
    similarly to other providers and return a PIL image or raw bytes.
    """

    def __init__(self) -> None:
        # Kept on the instance so the real API call can use it once wired in.
        self.api_key = os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            logger.warning("GOOGLE_API_KEY not set. Gemini image generation may fail at runtime.")
        logger.info("GeminiImageProvider initialized")

    def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
        """Return a transparent PNG of the requested size.

        Placeholder to be replaced by a real Gemini/Imagen call: no remote
        request is made and the prompt is not used. Width and height are
        clamped to at least 1 pixel.

        Args:
            options: Generation request; only ``width``, ``height``, ``model``
                and ``seed`` are read.

        Returns:
            Result carrying the PNG bytes, the canvas dimensions,
            ``provider="gemini"``, the requested model (falling back to the
            ``GEMINI_IMAGE_MODEL`` environment variable) and the seed.
        """
        canvas_size = (max(1, options.width), max(1, options.height))
        img = Image.new("RGBA", canvas_size, (0, 0, 0, 0))
        with io.BytesIO() as buf:
            img.save(buf, format="PNG")
            png = buf.getvalue()

        return ImageGenerationResult(
            image_bytes=png,
            width=img.width,
            height=img.height,
            provider="gemini",
            # Honour an explicitly requested model, as the other providers do.
            model=options.model or os.getenv("GEMINI_IMAGE_MODEL"),
            seed=options.seed,
        )
|
||||
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from PIL import Image
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.huggingface")
|
||||
|
||||
|
||||
DEFAULT_HF_MODEL = os.getenv(
|
||||
"HF_IMAGE_MODEL",
|
||||
"black-forest-labs/FLUX.1-Krea-dev",
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceImageProvider(ImageGenerationProvider):
    """Image generation through Hugging Face Inference Providers (fal-ai by default).

    API doc: https://huggingface.co/docs/inference-providers/en/tasks/text-to-image
    """

    def __init__(self, api_key: Optional[str] = None, provider: str = "fal-ai") -> None:
        """Create the inference client.

        Args:
            api_key: Hugging Face token; falls back to the ``HF_TOKEN``
                environment variable.
            provider: Inference provider name handed to ``InferenceClient``.

        Raises:
            RuntimeError: If no token is supplied and ``HF_TOKEN`` is unset.
        """
        token = api_key or os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN is required for Hugging Face image generation")
        self.api_key = token
        self.provider = provider
        self.client = InferenceClient(provider=self.provider, api_key=self.api_key)
        logger.info("HuggingFaceImageProvider initialized (provider=%s)", self.provider)

    def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
        """Generate one image and return it PNG-encoded.

        Optional sampling settings are forwarded only when the caller set
        them; the negative prompt is forwarded only when non-empty.
        """
        model = options.model or DEFAULT_HF_MODEL

        # (keyword, value, include?) in the order the keywords are forwarded.
        candidates = (
            ("guidance_scale", options.guidance_scale, options.guidance_scale is not None),
            ("num_inference_steps", options.steps, options.steps is not None),
            ("negative_prompt", options.negative_prompt, bool(options.negative_prompt)),
            ("seed", options.seed, options.seed is not None),
        )
        params: Dict[str, Any] = {key: value for key, value, wanted in candidates if wanted}

        # The HF InferenceClient returns a PIL Image
        logger.debug("HF generate: model=%s width=%s height=%s params=%s", model, options.width, options.height, params)
        img: Image.Image = self.client.text_to_image(
            options.prompt,
            model=model,
            width=options.width,
            height=options.height,
            **params,
        )

        with io.BytesIO() as png_buffer:
            img.save(png_buffer, format="PNG")
            encoded = png_buffer.getvalue()

        return ImageGenerationResult(
            image_bytes=encoded,
            width=img.width,
            height=img.height,
            provider="huggingface",
            model=model,
            seed=options.seed,
            metadata={"provider": self.provider},
        )
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.stability")
|
||||
|
||||
|
||||
DEFAULT_STABILITY_MODEL = os.getenv("STABILITY_MODEL", "stable-diffusion-xl-1024-v1-0")
|
||||
|
||||
|
||||
class StabilityImageProvider(ImageGenerationProvider):
    """Stability AI Images API provider (simple text-to-image).

    This uses the v1 text-to-image endpoint format. Adjust to match your existing
    Stability integration if different.
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """Store the API key, falling back to ``STABILITY_API_KEY``.

        A missing key only logs a warning here; ``generate`` refuses to run
        without one.
        """
        self.api_key = api_key or os.getenv("STABILITY_API_KEY")
        if not self.api_key:
            logger.warning("STABILITY_API_KEY not set. Stability generation may fail at runtime.")
        logger.info("StabilityImageProvider initialized")

    def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
        """Request one image from the v1 text-to-image endpoint.

        Args:
            options: Generation request. ``guidance_scale`` defaults to 7.0
                and ``steps`` to 30 when unset; ``seed`` is sent only when
                provided.

        Returns:
            Result holding the decoded image bytes and the dimensions read
            back from the image itself.

        Raises:
            RuntimeError: If no API key is configured or the response
                contains no image artifact.
            requests.HTTPError: If the API answers with an error status.
        """
        import base64

        # Fail early instead of sending a meaningless "Bearer None" header.
        if not self.api_key:
            raise RuntimeError("STABILITY_API_KEY is required for Stability image generation")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload: Dict[str, Any] = {
            "text_prompts": [
                {"text": options.prompt, "weight": 1.0},
            ],
            # `is None` so that an explicit guidance scale of 0 is respected.
            "cfg_scale": 7.0 if options.guidance_scale is None else options.guidance_scale,
            "steps": options.steps or 30,
            "width": options.width,
            "height": options.height,
        }
        # Omit the seed rather than sending JSON null when none was requested.
        if options.seed is not None:
            payload["seed"] = options.seed
        if options.negative_prompt:
            payload["text_prompts"].append({"text": options.negative_prompt, "weight": -1.0})

        model = options.model or DEFAULT_STABILITY_MODEL
        url = f"https://api.stability.ai/v1/generation/{model}/text-to-image"

        logger.debug("Stability generate: model=%s payload_keys=%s", model, list(payload.keys()))
        resp = requests.post(url, headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
        data = resp.json()

        # Expecting data["artifacts"][0]["base64"]
        artifacts = data.get("artifacts") or []
        b64 = artifacts[0].get("base64") if artifacts else None
        if not b64:
            raise RuntimeError("Stability API response contained no image artifact")
        image_bytes = base64.b64decode(b64)

        # Read the real dimensions from the image and release the decoder promptly.
        with Image.open(io.BytesIO(image_bytes)) as img:
            width, height = img.size

        return ImageGenerationResult(
            image_bytes=image_bytes,
            width=width,
            height=height,
            provider="stability",
            model=model,
            seed=options.seed,
        )
|
||||
|
||||
|
||||
73
backend/services/llm_providers/main_image_generation.py
Normal file
73
backend/services/llm_providers/main_image_generation.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.facade")
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
if gpt_provider.startswith("hf"):
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider(provider_name: str):
    """Instantiate the provider registered under ``provider_name``.

    Raises:
        ValueError: If the name matches no known provider.
    """
    factories = {
        "huggingface": HuggingFaceImageProvider,
        "gemini": GeminiImageProvider,
        "stability": StabilityImageProvider,
    }
    factory = factories.get(provider_name)
    if factory is None:
        raise ValueError(f"Unknown image provider: {provider_name}")
    return factory()
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult:
    """Generate an image with whichever provider the options or environment select.

    Args:
        prompt: Text description of the image.
        options: Optional settings. Recognised keys: "provider",
            "negative_prompt", "width", "height", "guidance_scale", "steps",
            "seed" and "model". The whole dict is also passed through as
            ``ImageGenerationOptions.extra``.

    Returns:
        The result produced by the selected provider.

    Raises:
        ValueError: If the selected provider name is unknown.
    """
    opts = options or {}
    provider_name = _select_provider(opts.get("provider"))

    image_options = ImageGenerationOptions(
        prompt=prompt,
        negative_prompt=opts.get("negative_prompt"),
        # `or` also covers an explicit None, which int() would reject.
        width=int(opts.get("width") or 1024),
        height=int(opts.get("height") or 1024),
        guidance_scale=opts.get("guidance_scale"),
        steps=opts.get("steps"),
        seed=opts.get("seed"),
        model=opts.get("model"),
        extra=opts,
    )

    # Normalize obvious model/provider mismatches
    model_lower = (image_options.model or "").lower()
    hf_only_prefixes = ("black-forest-labs/", "runwayml/", "stabilityai/flux")
    if provider_name == "stability" and model_lower.startswith(hf_only_prefixes):
        logger.info("Remapping provider to huggingface for model=%s", image_options.model)
        provider_name = "huggingface"

    if provider_name == "huggingface" and not image_options.model:
        # Resolve the default here so the log line below shows the real model,
        # while still honouring the HF_IMAGE_MODEL override.
        image_options.model = os.getenv("HF_IMAGE_MODEL", "black-forest-labs/FLUX.1-Krea-dev")

    logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
    provider = _get_provider(provider_name)
    return provider.generate(image_options)
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
    """
    Generates images using the DALL-E 3 model based on a given text prompt.

    Args:
        img_prompt (str): Text prompt to generate the image.
        image_dir (str): Directory where the generated image will be saved.
        size (str, optional): Size of the generated images. Defaults to "1024x1024".
        quality (str, optional): Quality of the generated images. Defaults to "hd".
        n (int, optional): Number of images to generate. Defaults to 1.

    Returns:
        str: Path to the saved image, or None when the image was generated
        but could not be saved (the failure is logged).

    Raises:
        SystemExit: If an error occurs in image generation.
    """
    # Local import: the module only imports the `OpenAI` client class, so the
    # bare name `openai` is not defined at module level.
    from openai import OpenAIError

    # NOTE(review): every Exception below is turned into SystemExit, so the
    # @retry decorator presumably never gets to retry - confirm this is intended.
    try:
        logger.info("Generating Dall-e-3 image for the blog.")
        client = OpenAI()

        img_generation_response = client.images.generate(
            model="dall-e-3",
            prompt=img_prompt,
            size=size,
            quality=quality,
            n=n
        )
        # Save the generated image locally.
        try:
            return save_generated_image(img_generation_response, image_dir)
        except Exception as err:
            # Best effort: a save failure is logged and reported as "no path".
            logger.error(f"Failed to Save generated image: {err}")
            return None

    except OpenAIError as e:
        # Not every OpenAI error carries an HTTP status; read it defensively.
        status = getattr(e, "status_code", getattr(e, "http_status", None))
        logger.error(f"Dalle-3 image generation error: HTTP Status {status}, Error: {e}")
        sys.exit("Exiting due to Dalle-3 image generation error.")

    except Exception as e:
        logger.error(f"Failed to generate images with Dalle3: {e}")
        sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -1,53 +0,0 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
    """
    Generates images using the DALL-E 3 model based on a given text prompt.

    Args:
        img_prompt (str): Text prompt to generate the image.
        image_dir (str): Directory where the generated image will be saved.
        size (str, optional): Size of the generated images. Defaults to "1024x1024".
        quality (str, optional): Quality of the generated images. Defaults to "hd".
        n (int, optional): Number of images to generate. Defaults to 1.

    Returns:
        str: Path to the saved image.

    Raises:
        SystemExit: If an error occurs in image generation or saving.
    """
    # Local import: the module only imports the `OpenAI` client class, so the
    # bare name `openai` is not defined at module level.
    from openai import OpenAIError

    # NOTE(review): every Exception below is turned into SystemExit, so the
    # @retry decorator presumably never gets to retry - confirm this is intended.
    try:
        logger.info("Generating Dall-e-3 image for the blog.")
        client = OpenAI()

        img_generation_response = client.images.generate(
            model="dall-e-3",
            prompt=img_prompt,
            size=size,
            quality=quality,
            n=n
        )

        return save_generated_image(img_generation_response, image_dir)

    except OpenAIError as e:
        # Not every OpenAI error carries an HTTP status; read it defensively.
        status = getattr(e, "status_code", getattr(e, "http_status", None))
        logger.error(f"Dalle-3 image generation error: HTTP Status {status}, Error: {e}")
        sys.exit("Exiting due to Dalle-3 image generation error.")

    except Exception as e:
        logger.error(f"Failed to generate images with Dalle3: {e}")
        sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -1,583 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import base64
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import logging
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
logging.getLogger('gemini_image_generator').warning(
|
||||
"Google genai library not available. Install with: pip install google-generativeai"
|
||||
)
|
||||
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
|
||||
# Imagen fallback configuration
|
||||
IMAGEN_FALLBACK_CONFIG = {
|
||||
'enabled': os.getenv('IMAGEN_FALLBACK_ENABLED', 'true').lower() == 'true', # Master switch for Imagen fallback
|
||||
'auto_fallback': os.getenv('IMAGEN_AUTO_FALLBACK', 'true').lower() == 'true', # Automatically fall back on Gemini failures
|
||||
'preferred_model': os.getenv('IMAGEN_MODEL', 'imagen-4.0-generate-001'), # Fast model for quick generation
|
||||
'fallback_aspect_ratios': {
|
||||
'1:1': '1:1',
|
||||
'3:4': '3:4',
|
||||
'4:3': '4:3',
|
||||
'9:16': '9:16',
|
||||
'16:9': '16:9'
|
||||
},
|
||||
'max_images': int(os.getenv('IMAGEN_MAX_IMAGES', '1')), # Generate 1 image for LinkedIn posts
|
||||
}
|
||||
|
||||
# Log configuration on startup
|
||||
logger.info(f"🔄 Imagen fallback configuration: {IMAGEN_FALLBACK_CONFIG}")
|
||||
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# Follow Google AI best practices for detailed prompts and iterative refinement.
|
||||
|
||||
# Generate images using Gemini
|
||||
# Gemini 2.0 Flash Experimental supports the ability to output text and inline images.
|
||||
# This lets you use Gemini to conversationally edit images or generate outputs with interwoven text (for example, generating a blog post with text and images in a single turn).
|
||||
# Note: Make sure to include responseModalities: ["Text", "Image"] in your generation configuration for text and image output with gemini-2.0-flash-exp-image-generation. Image only is not allowed.
|
||||
|
||||
|
||||
class AIPromptGenerator:
    """
    Generates enhanced AI image prompts based on user keywords,
    following the guidelines of the Imagen documentation.

    Modifiers are picked at random, so repeated calls with the same keywords
    generally produce different prompts.
    """

    def __init__(self):
        self.photography_styles = ["photo", "photograph"]
        self.art_styles = ["painting", "sketch", "drawing", "illustration", "digital art", "render"]
        self.art_techniques = ["technical pencil drawing", "charcoal drawing", "color pencil drawing", "pastel painting", "digital art", "art deco (poster)", "impressionist painting", "renaissance painting", "pop art"]
        self.camera_proximity = ["close-up", "zoomed out", "taken from far away"]
        self.camera_position = ["aerial", "from below"]
        self.lighting = ["natural lighting", "dramatic lighting", "warm lighting", "cold lighting", "studio lighting", "golden hour lighting"]
        self.camera_settings = ["motion blur", "soft focus", "bokeh", "portrait"]
        self.lens_types = ["35mm lens", "50mm lens", "fisheye lens", "wide angle lens", "macro lens", "telephoto lens"]
        self.film_types = ["black and white film", "polaroid"]
        self.materials = ["made of cheese", "made of paper", "made of neon tubes", "metallic", "glass", "wooden", "stone"]
        self.shapes = ["in the shape of a bird", "angular", "curved", "geometric"]
        self.quality_modifiers_general = ["high-quality", "beautiful", "stylized", "detailed", "epic", "grand"]
        self.quality_modifiers_photo = ["4K", "HDR", "studio photo", "professional photo", "photorealistic"]
        self.quality_modifiers_art = ["by a professional artist", "intricate details", "masterpiece"]
        self.aspect_ratios = ["1:1 aspect ratio", "4:3 aspect ratio", "3:4 aspect ratio", "16:9 aspect ratio", "9:16 aspect ratio"]
        self.photorealistic_modifiers = {
            "portraits": ["prime lens", "zoom lens", "24-35mm", "black and white film", "film noir", "shallow depth of field", "duotone (mention two colors)"],
            "objects": ["macro lens", "60-105mm", "high detail", "precise focusing", "controlled lighting"],
            "motion": ["telephoto zoom lens", "100-400mm", "fast shutter speed", "action shot", "movement tracking"],
            "wide-angle": ["wide-angle lens", "10-24mm", "long exposure", "sharp focus", "smooth water or clouds", "astro photography"]
        }

    def generate_prompt(self, keywords):
        """
        Generates an enhanced AI image prompt based on user-provided keywords.

        Args:
            keywords (list): A list of keywords describing the desired image.

        Returns:
            str: An enhanced AI image prompt, or "A beautiful image." when no
            keywords are given.
        """
        if not keywords:
            return "A beautiful image."

        prompt_parts = [" ".join(keywords)]

        # Add context and background (optional)
        context_options = ["in a detailed background", "outdoors", "indoors", "in a studio", "with a blurred background"]
        if random.random() < 0.6:  # Add context with a probability
            prompt_parts.append(random.choice(context_options))

        # Add style (optional)
        style_options = self.photography_styles + [f"{art} of" for art in self.art_styles]
        if random.random() < 0.7:
            prompt_parts.insert(0, random.choice(style_options))
            if prompt_parts[0].startswith(("painting of", "sketch of", "drawing of")):
                if random.random() < 0.5:
                    prompt_parts.append(f"in the style of {random.choice(self.art_techniques)}")

        # The leading part does not change from here on, so classify it once.
        is_photo = any(style in prompt_parts[0] for style in self.photography_styles)

        # Add photography modifiers (if photography style is chosen)
        if is_photo:
            photo_modifiers = (
                (0.4, self.camera_proximity),
                (0.3, self.camera_position),
                (0.5, self.lighting),
                (0.3, self.camera_settings),
                (0.2, self.lens_types),
                (0.1, self.film_types),
            )
            for probability, choices in photo_modifiers:
                if random.random() < probability:
                    prompt_parts.append(random.choice(choices))

        # Add shapes and materials (optional)
        if random.random() < 0.3:
            prompt_parts.append(random.choice(self.materials))
        if random.random() < 0.2:
            prompt_parts.append(random.choice(self.shapes))

        # Add quality modifiers (optional)
        if random.random() < 0.6:
            specific = self.quality_modifiers_photo if is_photo else self.quality_modifiers_art
            # Build a NEW list: extending self.quality_modifiers_general in
            # place would grow it on every call and mix photo/art modifiers.
            quality_options = list(dict.fromkeys(self.quality_modifiers_general + specific))
            prompt_parts.append(random.choice(quality_options))

        # Add aspect ratio (optional)
        if random.random() < 0.2:
            prompt_parts.append(random.choice(self.aspect_ratios))

        return ", ".join(prompt_parts)

    def generate_photorealistic_prompt(self, keywords, focus=""):
        """
        Generates an enhanced AI image prompt specifically for photorealistic images.

        Args:
            keywords (list): A list of keywords describing the desired image.
            focus (str, optional): The focus of the photorealistic image (e.g., "portraits", "objects", "motion", "wide-angle"). Defaults to "".

        Returns:
            str: An enhanced photorealistic AI image prompt, or
            "A photorealistic image." when no keywords are given.
        """
        if not keywords:
            return "A photorealistic image."

        prompt_parts = ["A photo of", "photorealistic", " ".join(keywords)]

        # Unknown or empty focus values simply add no focus-specific modifiers.
        modifiers = self.photorealistic_modifiers.get(focus) if focus else None
        if modifiers:
            num_modifiers = random.randint(1, min(3, len(modifiers)))
            prompt_parts.extend(random.sample(modifiers, num_modifiers))

        # Add general quality modifiers
        if random.random() < 0.5:
            prompt_parts.append(random.choice(self.quality_modifiers_photo))

        # Add lighting
        if random.random() < 0.4:
            prompt_parts.append(random.choice(self.lighting))

        return ", ".join(prompt_parts)
|
||||
|
||||
def _ensure_client() -> Optional[object]:
    """Return a Gemini client, or None when it cannot be created.

    None is returned when no "gemini" API key is configured, when the
    google genai library failed to import, or when constructing the client
    raises; each case is logged.
    """
    api_key = APIKeyManager().get_api_key("gemini")
    key_missing = not api_key
    library_missing = genai is None

    if key_missing:
        logger.warning("No Gemini API key found")
    if library_missing:
        logger.warning("Google Generative AI library not available")
    if key_missing or library_missing:
        return None

    try:
        logger.info("Creating Gemini client...")
        # The API key is passed directly to the Client constructor
        gemini_client = genai.Client(api_key=api_key)
        logger.info("Gemini client created successfully")
        return gemini_client
    except Exception as e:
        import traceback

        logger.error(f"Failed to create Gemini client: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None
|
||||
|
||||
|
||||
def _generate_imagen_images_base64(prompt: str, aspect_ratio: str = "1:1") -> List[str]:
    """
    Generate images using Imagen API as a fallback method.

    This function implements the Imagen API following the official documentation:
    https://ai.google.dev/gemini-api/docs/imagen

    Args:
        prompt: Text prompt for image generation
        aspect_ratio: Desired aspect ratio (1:1, 3:4, 4:3, 9:16, 16:9);
            unknown values fall back to "1:1"

    Returns:
        List of base64-encoded images; empty when the library or API key is
        missing, when the API returns no images, or when any error occurs.
    """
    logger = logging.getLogger('gemini_image_generator')
    logger.info("🔄 Falling back to Imagen API for image generation")

    # Report a missing library explicitly instead of failing on `genai.Client`.
    if genai is None:
        logger.error("Google genai library not available for Imagen fallback")
        return []

    try:
        # Get API key for Imagen (can use same Gemini API key)
        api_key_manager = APIKeyManager()
        api_key = api_key_manager.get_api_key("gemini")  # Imagen uses same API key

        if not api_key:
            logger.error("No API key available for Imagen fallback")
            return []

        # Create Imagen client
        client = genai.Client(api_key=api_key)

        # Map aspect ratio to Imagen format using configuration
        imagen_aspect_ratio = IMAGEN_FALLBACK_CONFIG['fallback_aspect_ratios'].get(aspect_ratio, "1:1")

        # Optimize prompt for Imagen (remove Gemini-specific formatting)
        imagen_prompt = _optimize_prompt_for_imagen(prompt)

        logger.info(f"Generating Imagen images with prompt: {imagen_prompt[:100]}...")
        logger.info(f"Using aspect ratio: {imagen_aspect_ratio}")
        logger.info(f"Using model: {IMAGEN_FALLBACK_CONFIG['preferred_model']}")

        # Generate images using configured Imagen model
        # Note: sample_image_size is not supported in current library version
        config_params = {
            'number_of_images': IMAGEN_FALLBACK_CONFIG['max_images'],
            'aspect_ratio': imagen_aspect_ratio,
        }

        response = client.models.generate_images(
            model=IMAGEN_FALLBACK_CONFIG['preferred_model'],
            prompt=imagen_prompt,
            config=types.GenerateImagesConfig(**config_params)
        )

        # Extract base64 images from response
        images_b64: List[str] = []
        for generated_image in getattr(response, 'generated_images', None) or []:
            image = getattr(generated_image, 'image', None)
            image_bytes = getattr(image, 'image_bytes', None)
            if not image_bytes:
                # Skip entries without payload rather than emitting "None".
                continue
            if isinstance(image_bytes, bytes):
                images_b64.append(base64.b64encode(image_bytes).decode('utf-8'))
            else:
                # If already base64 string
                images_b64.append(str(image_bytes))

        if images_b64:
            logger.info(f"✅ Imagen fallback successful! Generated {len(images_b64)} images")
            return images_b64

        logger.warning("Imagen fallback returned no images")
        return []

    except Exception as e:
        # Fallback path is best effort: log the failure and report "no images".
        logger.error(f"❌ Imagen fallback failed: {e}")
        import traceback
        logger.error(f"Imagen error traceback: {traceback.format_exc()}")
        return []
|
||||
|
||||
|
||||
def _optimize_prompt_for_imagen(prompt: str) -> str:
    """
    Adapt a Gemini-style prompt so it can be sent to the Imagen API.

    The section labels added by the Gemini code path are stripped (the text
    that follows each label is kept), whitespace runs are collapsed to single
    spaces, style keywords are appended for LinkedIn/professional and
    technology themes, and the result is capped in length.

    Based on Imagen prompt guide: https://ai.google.dev/gemini-api/docs/imagen

    Args:
        prompt: Prompt text, possibly containing Gemini section labels.

    Returns:
        The cleaned prompt; at most 400 characters plus a trailing "...".
    """
    # Strip the Gemini-only section labels.
    for label in ('\n\nEnhanced prompt:', '\n\nAspect ratio:'):
        prompt = prompt.replace(label, '')

    # Collapse newlines, tabs and repeated spaces into single spaces.
    prompt = ' '.join(prompt.split())

    # Collect theme-specific suffixes; neither suffix contains a trigger word
    # of the other theme, so both checks can use the same lowered text.
    lowered = prompt.lower()
    suffixes = []
    if 'professional' in lowered and 'linkedin' in lowered:
        suffixes.append(", high quality, professional photography, business appropriate")
    if 'digital transformation' in lowered or 'technology' in lowered:
        suffixes.append(", modern, innovative, clean design, corporate aesthetic")
    prompt += ''.join(suffixes)

    # Imagen accepts up to 480 tokens; cut at 400 characters to leave headroom.
    if len(prompt) > 400:
        return prompt[:400] + "..."
    return prompt
|
||||
|
||||
|
||||
def generate_gemini_images_base64(
    prompt: str,
    *,
    keywords: Optional[list] = None,
    style: Optional[str] = None,
    focus: Optional[str] = None,
    enhance_prompt: bool = True,
    aspect_ratio: str = "9:16",
    max_retries: int = 2,
    initial_retry_delay: float = 1.0,
    enable_imagen_fallback: bool = True,
) -> List[str]:
    """
    Generate images for a prompt and return them as base64 strings.

    Gemini is tried first. Imagen is used instead when the Gemini client is
    unavailable, Gemini returns no image parts, or Gemini raises an error,
    provided both ``enable_imagen_fallback`` and
    ``IMAGEN_FALLBACK_CONFIG['enabled']`` are truthy. Errors whose message
    contains "503" are retried with a doubling delay before falling back.

    Docs:
    - Gemini: https://ai.google.dev/gemini-api/docs/image-generation
    - Imagen: https://ai.google.dev/gemini-api/docs/imagen

    Args:
        prompt: Base text prompt.
        keywords: When given (and ``enhance_prompt`` is true), used to build
            an extra "Enhanced prompt" section via ``AIPromptGenerator``.
        style: "photorealistic" selects the photorealistic prompt builder,
            but only when ``focus`` is also given.
        focus: Focus argument passed to the photorealistic prompt builder.
        enhance_prompt: Whether to add the keyword-based enhancement.
        aspect_ratio: Appended to the prompt as a text hint and forwarded to
            the Imagen fallback.
        max_retries: Maximum number of retries after a 503 error.
        initial_retry_delay: Seconds slept before the first retry.
        enable_imagen_fallback: Per-call switch for the Imagen fallback.

    Returns:
        List of base64-encoded images; empty when nothing could be generated.
    """
    logger = logging.getLogger('gemini_image_generator')
    logger.info("Generating image (base64) with Gemini (with Imagen fallback)")

    if enhance_prompt and keywords:
        generator = AIPromptGenerator()
        if style == "photorealistic" and focus:
            enhanced = generator.generate_photorealistic_prompt(keywords, focus)
        else:
            enhanced = generator.generate_prompt(keywords)
        prompt = f"{prompt}\n\nEnhanced prompt: {enhanced}"

    # The Gemini call below takes no ratio parameter, so the ratio is passed as text.
    if aspect_ratio:
        prompt = f"{prompt}\n\nAspect ratio: {aspect_ratio}"

    def _imagen_or_empty(announcement=None):
        # Hand the request to Imagen when the fallback is enabled, otherwise
        # give up with an empty result. The announcement is only logged when
        # the fallback really runs.
        if not (enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']):
            return []
        if announcement is not None:
            logger.info(announcement)
        return _generate_imagen_images_base64(prompt, aspect_ratio)

    client = _ensure_client()
    if client is None:
        logger.warning("Gemini client not available or API key missing")
        return _imagen_or_empty("Falling back to Imagen API")

    # Substrings of an error message that trigger an immediate fallback
    # instead of a retry.
    fallback_markers = [
        'quota', 'resource_exhausted', 'rate_limit', 'billing', 'api_key', '403', '429'
    ]

    attempt = 0
    backoff = initial_retry_delay
    while attempt <= max_retries:
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash-exp-image-generation",
                contents=[prompt],
            )

            encoded_images: List[str] = []
            for part in response.candidates[0].content.parts:
                if getattr(part, 'inline_data', None) is None:
                    continue
                payload = part.inline_data.data
                # Raw bytes are encoded; anything else is taken to be base64 text already.
                if isinstance(payload, bytes):
                    encoded_images.append(base64.b64encode(payload).decode('utf-8'))
                else:
                    encoded_images.append(str(payload))

            if not encoded_images:
                logger.warning("Gemini returned no images, falling back to Imagen")
                return _imagen_or_empty()

            logger.info(f"✅ Gemini generated {len(encoded_images)} images successfully")
            return encoded_images

        except Exception as e:
            msg = str(e)
            logger.warning(f"Gemini image gen error: {msg}")

            lowered_msg = msg.lower()
            if any(marker in lowered_msg for marker in fallback_markers):
                logger.info("Gemini quota/API error detected, falling back to Imagen")
                return _imagen_or_empty()

            # Service-unavailable errors are retried with exponential backoff.
            if "503" in msg and attempt < max_retries:
                time.sleep(backoff)
                backoff *= 2
                attempt += 1
                continue

            # Any other error (or a 503 with no retries left) goes to Imagen.
            return _imagen_or_empty("Final fallback to Imagen due to Gemini error")

    # Only reached when the loop body never ran (negative max_retries).
    return _imagen_or_empty("All Gemini retries exhausted, falling back to Imagen")
|
||||
|
||||
|
||||
def generate_gemini_image(
    prompt,
    keywords=None,
    style=None,
    focus=None,
    enhance_prompt=True,
    max_retries=2,
    initial_retry_delay=1.0,
    aspect_ratio="9:16",
    enable_imagen_fallback=True,
):
    """
    Generate one image, write it to disk as PNG and return the file name.

    Backward-compatible wrapper around ``generate_gemini_images_base64``
    (which includes the Imagen fallback). Only the first generated image is
    kept. Prefer ``generate_gemini_images_base64`` in new code paths.

    Returns:
        The PNG file name (relative to the current working directory), or
        None when no image was generated or saving failed.
    """
    logger = logging.getLogger('gemini_image_generator')

    generation_options = {
        "keywords": keywords,
        "style": style,
        "focus": focus,
        "enhance_prompt": enhance_prompt,
        "aspect_ratio": aspect_ratio,
        "max_retries": max_retries,
        "initial_retry_delay": initial_retry_delay,
        "enable_imagen_fallback": enable_imagen_fallback,
    }
    images = generate_gemini_images_base64(prompt, **generation_options)
    if not images:
        return None

    # Legacy callers expect a file, so decode the first image for saving.
    first_b64 = images[0]
    picture = Image.open(BytesIO(base64.b64decode(first_b64)))

    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # NOTE: the file-name prefix is picked from the prompt text, not from
    # which API actually produced the image.
    lowered_prompt = prompt.lower()
    if 'imagen' in lowered_prompt or 'fallback' in lowered_prompt:
        prefix = 'imagen-fallback-image'
    else:
        prefix = 'gemini-native-image'
    out_name = f'{prefix}-{timestamp}.png'

    try:
        picture.save(out_name)
        # Also feed the image through the shared save pipeline.
        save_generated_image({"artifacts": [{"base64": first_b64}]})
    except Exception as e:
        logger.error(f"❌ Failed to save image: {e}")
        return None

    logger.info(f"✅ Image saved successfully: {out_name}")
    return out_name
|
||||
|
||||
|
||||
def edit_image(image_path, prompt, max_retries=2, initial_retry_delay=1.0):
    """
    Edit an existing image with Gemini using a text instruction.

    The image at ``image_path`` is sent together with ``prompt`` to the
    ``gemini-2.0-flash-exp-image-generation`` model. Each image part in the
    response is written to ``edited-<basename of image_path>`` in the current
    working directory and also passed to ``save_generated_image``.

    Example prompt: "Edit this image to make it look like a cartoon"

    Args:
        image_path (str): The path to the image to edit.
        prompt (str): The instruction describing the edit.
        max_retries (int, optional): Maximum number of retries after an error
            whose message contains "503". Defaults to 2.
        initial_retry_delay (float, optional): Seconds to wait before the
            first retry; doubled after every retry. Defaults to 1.0.

    Returns:
        str | None: File name of the edited image. None when the Gemini
        client is unavailable, the request fails, or the response contains no
        image part. A failure while saving is logged but the name is still
        returned.
    """
    import PIL.Image

    retry_count = 0
    retry_delay = initial_retry_delay

    # The context manager closes the source file once all attempts are done.
    with PIL.Image.open(image_path) as image:
        while retry_count <= max_retries:
            try:
                client = _ensure_client()
                if client is None:
                    return None

                logger.info("Sending request to Gemini API for image editing")
                response = client.models.generate_content(
                    model="gemini-2.0-flash-exp-image-generation",
                    contents=[prompt, image],
                    config=types.GenerateContentConfig(
                        response_modalities=['Text', 'Image']
                    )
                )
                logger.info("Received response from Gemini API for image editing")

                edited_img_name = None
                for part in response.candidates[0].content.parts:
                    if getattr(part, 'inline_data', None) is None:
                        continue

                    logger.info("Received edited image data from Gemini")
                    edited_image = Image.open(BytesIO(part.inline_data.data))

                    # Save the edited image next to the current working directory.
                    edited_img_name = f'edited-{os.path.basename(image_path)}'
                    try:
                        logger.info(f"Saving edited image to: {edited_img_name}")
                        edited_image.save(edited_img_name)

                        # Re-read the saved file (closing the handle afterwards)
                        # and wrap it in the structure save_generated_image expects.
                        with open(edited_img_name, "rb") as saved_file:
                            encoded = base64.b64encode(saved_file.read()).decode('utf-8')
                        save_generated_image({"artifacts": [{"base64": encoded}]})
                    except Exception as err:
                        logger.error(f"Failed to save edited image: {err}")

                logger.info(f"Image editing completed. Edited image name: {edited_img_name}")
                return edited_img_name

            except Exception as err:
                error_message = str(err)
                logger.error(f"Error in edit_image: {err}")
                # Only transient 503 errors are retried.
                if "503" in error_message and retry_count < max_retries:
                    retry_count += 1
                    logger.info(f"Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
                    time.sleep(retry_delay)
                    # Exponential backoff
                    retry_delay *= 2
                else:
                    return None

    # Only reached when the loop body never ran (negative max_retries).
    return None
|
||||
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Ensure you sign up for an account to obtain an API key:
|
||||
# https://platform.stability.ai/
|
||||
# Your API key can be found here after account creation:
|
||||
# https://platform.stability.ai/account/keys
|
||||
|
||||
import os
|
||||
import requests
|
||||
import base64
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
def save_generated_image(data):
    """Placeholder for saving a generated image; does nothing and returns None."""
    # Implementation for saving image
    return None
|
||||
|
||||
def generate_stable_diffusion_image(prompt):
    """
    Generate a 1024x1024 image with the Stability AI text-to-image REST API.

    The API key is read through ``APIKeyManager`` ("stability"); the host
    comes from the ``API_HOST`` environment variable and defaults to
    https://api.stability.ai. The JSON response is passed to
    ``save_generated_image`` and every returned artifact is opened with the
    default image viewer.

    Args:
        prompt (str): Text prompt describing the image.

    Returns:
        Whatever ``save_generated_image`` returns for the response, or None
        when no API key is configured.

    Raises:
        Exception: If the API answers with a non-200 status code.
        requests.exceptions.RequestException: On connection errors or when
            the request times out.
    """
    engine_id = "stable-diffusion-xl-1024-v1-0"
    api_host = os.getenv('API_HOST', 'https://api.stability.ai')

    # Use APIKeyManager instead of direct environment variable access
    api_key_manager = APIKeyManager()
    api_key = api_key_manager.get_api_key("stability")

    if api_key is None:
        st.warning("Missing Stability API key. Please configure it in the onboarding process.")
        return None

    response = requests.post(
        f"{api_host}/v1/generation/{engine_id}/text-to-image",
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}"
        },
        json={
            "text_prompts": [
                {
                    "text": prompt
                }
            ],
            "cfg_scale": 7,
            "height": 1024,
            "width": 1024,
            "samples": 1,
            "steps": 30,
        },
        # Without a timeout requests waits indefinitely on a stalled
        # connection; image generation is slow, so the limit is generous.
        timeout=120,
    )

    if response.status_code != 200:
        raise Exception("Non-200 response: " + str(response.text))

    data = response.json()
    img_path = save_generated_image(data)

    for image in data["artifacts"]:
        # Decode the base64 payload and show it with the default viewer.
        img_data = base64.b64decode(image["base64"])
        img = Image.open(BytesIO(img_data))
        img.show()

    return img_path
|
||||
@@ -1,51 +0,0 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
|
||||
def gen_new_from_given_img(img_path, image_dir, num_img=1, img_size="1024x1024", response_format="url"):
    """
    Generate variations of an existing image with OpenAI's image variation API.

    The source image is first flattened onto a white background and written
    back to ``img_path`` as PNG (the original file is overwritten), then
    uploaded to the variation endpoint.

    Args:
        img_path (str): Path to the original image file.
        image_dir (str): Directory where the generated images will be saved.
        num_img (int, optional): Number of image variations to generate. Defaults to 1.
        img_size (str, optional): Size of the generated images. Defaults to "1024x1024".
        response_format (str, optional): Format in which the generated images are returned. Defaults to "url".

    Returns:
        The value returned by ``save_generated_image`` for the API response.

    Raises:
        SystemExit: If any error occurs during preparation, the API call or saving.
    """
    try:
        logger.info(f"Starting image variation generation for: {img_path}")

        # Flatten transparency onto white and rewrite the file as PNG.
        png = Image.open(img_path).convert('RGBA')
        background = Image.new('RGBA', png.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, png)
        alpha_composite.save(img_path, 'PNG', quality=80)
        logger.info("Image prepared for variation generation.")

        client = OpenAI()
        # Binary mode must not be given an encoding (open() raises ValueError
        # otherwise); the context manager closes the handle after the upload.
        with open(img_path, "rb") as image_file:
            variation_response = client.images.create_variation(
                image=image_file,
                n=num_img,
                size=img_size,
                response_format=response_format
            )

        # Saving the generated image
        # NOTE(review): save_generated_image is not imported in the visible
        # part of this module — confirm where it comes from.
        generated_image_path = save_generated_image(variation_response, image_dir)
        logger.info(f"Image variation generated and saved to: {generated_image_path}")
        return generated_image_path

    except Exception as e:
        logger.error(f"Error occurred during image variation generation: {e}")
        sys.exit(f"Exiting due to critical error: {e}")
|
||||
@@ -1,162 +0,0 @@
|
||||
#########################################################
|
||||
#
|
||||
# This module will generate images for the blogs using APIs
|
||||
# from Dall-E and other free resources. Given a prompt, the
|
||||
# images will be stored in local directory.
|
||||
# Required: openai API key.
|
||||
#
|
||||
#########################################################
|
||||
|
||||
# imports
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import streamlit as st
|
||||
|
||||
import openai # OpenAI Python library to make API calls
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("text_to_image_generation")
|
||||
|
||||
#from .gen_dali2_images
|
||||
from .gen_dali3_images import generate_dalle3_images
|
||||
from .gen_stabl_diff_img import generate_stable_diffusion_image
|
||||
from ..text_generation.main_text_generation import llm_text_gen
|
||||
from .gen_gemini_images import generate_gemini_image
|
||||
|
||||
def generate_image(user_prompt, title=None, description=None, tags=None, content=None, aspect_ratio="16:9"):
    """
    Create an image for the given prompt and return where it was stored.

    An enhanced image prompt is first built from the user prompt and the
    optional context, then sent to the configured image engine (currently
    hard-coded to Gemini).

    Args:
        user_prompt (str): Text description of the desired image. Nothing is
            generated when it is empty.
        title (str, optional): Title of the related content.
        description (str, optional): Description or summary of the content.
        tags (list, optional): Tags related to the content.
        content (str, optional): Content body or excerpt.
        aspect_ratio (str): Desired aspect ratio, e.g. "16:9", "4:3" or
            "1:1". Defaults to "16:9".

    Returns:
        The value returned by the selected engine's generator (a stored image
        path), or None when no prompt was given or generation failed.
    """
    # FIXME: Need to remove default value to match sidebar input.
    image_engine = 'Gemini-AI'
    image_stored_at = None

    if not user_prompt:
        logger.error("Skipping Image creation, No prompt provided.")
        return None

    try:
        # Use enhanced prompt generator with all available parameters
        img_prompt = generate_enhanced_img_prompt(user_prompt, title, description, tags, content)

        # Engines without an aspect-ratio argument get the ratio as a text hint.
        hinted_prompt = img_prompt
        if aspect_ratio:
            hinted_prompt = img_prompt + f"\n\nAspect ratio: {aspect_ratio}"

        if 'Dalle3' in image_engine:
            logger.info(f"Calling Dalle3 text-to-image with prompt: {hinted_prompt}")
            image_stored_at = generate_dalle3_images(hinted_prompt)
        elif 'Stability-AI' in image_engine:
            logger.info(f"Calling Stable diffusion text-to-image with prompt: \n{hinted_prompt}")
            image_stored_at = generate_stable_diffusion_image(hinted_prompt)
        elif 'Gemini-AI' in image_engine:
            # generate_gemini_image appends its own "Aspect ratio" hint from
            # the aspect_ratio argument, so pass the un-hinted prompt to avoid
            # sending the hint twice.
            logger.info(f"Calling Gemini text-to-image with prompt: \n{img_prompt}")
            image_stored_at = generate_gemini_image(img_prompt, aspect_ratio=aspect_ratio)
        return image_stored_at
    except Exception as err:
        logger.error(f"Failed to generate Image: {err}")
        st.warning(f"Failed to generate Image: {err}")
        return None
|
||||
|
||||
|
||||
def generate_img_prompt(user_prompt):
    """
    Ask the text LLM to turn free-form user text into an image-generation prompt.

    Args:
        user_prompt (str): The user text the image should be based on.

    Returns:
        The response of ``llm_text_gen`` for the instruction prompt.
    """
    instruction = f"""
As an expert prompt generator for AI text to image models and artist, I will provide you with 'user text' for creating images.
Your task is to create a prompt for a highly relevant image from given 'user text'.
\n
Choose from various art styles, utilize light & shadow effects etc.
Make sure to avoid common image generation mistakes.
Reply with only one answer, no descrition and in plaintext.
Make sure your prompt is detailed and creative descriptions that will inspire unique and interesting images from the AI.

\n\nuser text:
'''{user_prompt}'''"""

    return llm_text_gen(instruction)
|
||||
|
||||
|
||||
def generate_enhanced_img_prompt(user_prompt, title=None, description=None, tags=None, content=None):
    """
    Build a context-aware image-generation prompt with the text LLM.

    The user prompt is combined with any title, description and tags (first
    five only) plus a short content excerpt, and the LLM is asked to write a
    detailed image prompt from that context. If the LLM call or the logging
    of its result fails, ``generate_img_prompt`` is used instead.

    Args:
        user_prompt (str): Base prompt from the user
        title (str, optional): Blog title or content title
        description (str, optional): Blog or content description/summary
        tags (list, optional): List of tags related to the content
        content (str, optional): Actual content or excerpt

    Returns:
        str: Enhanced prompt for image generation
    """
    # Context block: the base prompt followed by whichever labelled fields exist.
    context_lines = [user_prompt]
    for label, value in (("Title", title), ("Description", description)):
        if value:
            context_lines.append(f"{label}: {value}")
    if tags:
        # Limit to 5 tags to avoid too much noise
        context_lines.append("Tags: " + ", ".join(tags[:5]))
    combined_context = "\n".join(context_lines)

    # Keep only the first 300 characters of the content to limit prompt size.
    if not content:
        content_excerpt = ""
    elif len(content) > 300:
        content_excerpt = content[:300] + "..."
    else:
        content_excerpt = content

    prompt = f"""
As an expert prompt engineer for AI image generation models, create a detailed, creative prompt
for generating a high-quality, relevant image based on the following context:

{combined_context}

Additional content excerpt:
{content_excerpt}

Your task is to:
1. Analyze the context and content to understand the main theme and subject
2. Create a rich, detailed prompt for image generation (50-75 words)
3. Include specific visual details, art style, mood, lighting, composition
4. Make sure the prompt is highly relevant to the original context
5. Avoid prohibited content or anything that violates image generation guidelines

Reply with ONLY the final prompt. No explanations or other text.
"""

    try:
        enhanced_prompt = llm_text_gen(prompt)
        logger.info(f"Generated enhanced image prompt: {enhanced_prompt[:100]}...")
    except Exception as e:
        logger.error(f"Error generating enhanced prompt: {e}")
        # Fall back to the simple prompt generation if enhanced fails
        return generate_img_prompt(user_prompt)
    return enhanced_prompt
|
||||
@@ -1,39 +0,0 @@
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import requests
|
||||
from PIL import Image
|
||||
import logging
|
||||
|
||||
def save_generated_image(img_generation_response):
    """
    Save generated images for blog, ensuring unique names for SEO.

    Every artifact in the response is base64-decoded and written to its own
    file inside the directory named by the ``IMG_SAVE_DIR`` environment
    variable (default: ``generated_images``). File names are based on the
    current timestamp; a numeric suffix is added when a name is already
    taken, so neither additional artifacts nor calls within the same second
    overwrite an existing file.

    Args:
        img_generation_response (dict): Mapping with an "artifacts" list
            whose items each hold the image data under the "base64" key.

    Returns:
        str | None: Path of the file written for the first artifact, or None
        when the response has no artifacts or saving fails.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Get image save directory with fallback to a local directory
    image_save_dir = os.getenv('IMG_SAVE_DIR', 'generated_images')

    # Create the directory if it doesn't exist
    if not os.path.exists(image_save_dir):
        logger.info(f"Creating image save directory: {image_save_dir}")
        os.makedirs(image_save_dir, exist_ok=True)

    name_stem = f"generated_image_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}"
    written_paths = []

    try:
        for image in img_generation_response["artifacts"]:
            # Decode before touching the file system so a bad artifact does
            # not leave an empty file behind.
            image_bytes = base64.b64decode(image["base64"])

            suffix = 0
            while True:
                file_name = f"{name_stem}.webp" if suffix == 0 else f"{name_stem}_{suffix}.webp"
                candidate = os.path.join(image_save_dir, file_name)
                try:
                    # "x" mode refuses to overwrite an existing file.
                    with open(candidate, "xb") as f:
                        f.write(image_bytes)
                    break
                except FileExistsError:
                    suffix += 1
            written_paths.append(candidate)
    except Exception as e:
        logger.error(f"Error saving image: {e}")
        return None

    if not written_paths:
        logger.error("Error saving image: response contained no artifacts")
        return None

    generated_image_filepath = written_paths[0]
    logger.info(f"Saved image at path: {generated_image_filepath}")

    return generated_image_filepath
|
||||
@@ -1,16 +1,13 @@
|
||||
import os
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import requests
|
||||
from loguru import logger
|
||||
import time
|
||||
import random
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except Exception:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,17 +26,10 @@ class WritingAssistantService:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.exa_api_key = os.getenv("EXA_API_KEY")
|
||||
self.gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
|
||||
if not self.exa_api_key:
|
||||
logger.warning("EXA_API_KEY not configured; writing assistant will fail")
|
||||
|
||||
if not (GOOGLE_GENAI_AVAILABLE and self.gemini_api_key):
|
||||
logger.warning("Gemini not available; writing assistant will fail")
|
||||
self.gemini_client = None
|
||||
else:
|
||||
self.gemini_client = genai.Client(api_key=self.gemini_api_key)
|
||||
|
||||
self.http_timeout_seconds = 15
|
||||
|
||||
# COST CONTROL: Daily usage limits
|
||||
@@ -151,9 +141,6 @@ class WritingAssistantService:
|
||||
raise
|
||||
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
|
||||
if not self.gemini_client:
|
||||
raise Exception("Gemini client not available")
|
||||
|
||||
# Build compact sources context block
|
||||
source_blocks: List[str] = []
|
||||
for i, s in enumerate(sources[:5]):
|
||||
@@ -164,12 +151,12 @@ class WritingAssistantService:
|
||||
)
|
||||
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
|
||||
|
||||
# Based on Exa demo guidance for completion-only behavior and inline citations
|
||||
# Provider-agnostic behavior: short continuation with one inline citation hint
|
||||
system_prompt = (
|
||||
"You are an essay-completion bot that completes a sentence or continues prose. "
|
||||
"You are an assistive writing continuation bot. "
|
||||
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
|
||||
"Continue in the same tone and topic as the stub. Prefer concrete, current facts from the provided sources. "
|
||||
"Include exactly one brief, verifiable citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
|
||||
"Match tone and topic. Prefer concrete, current facts from the provided sources. "
|
||||
"Include exactly one brief citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
@@ -179,17 +166,20 @@ class WritingAssistantService:
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.gemini_client.models.generate_content(
|
||||
model="gemini-1.5-flash", contents=f"{system_prompt}\n\n{user_prompt}"
|
||||
),
|
||||
)
|
||||
suggestion = (resp.text or "").strip()
|
||||
# Inter-call jitter to reduce burst rate limits
|
||||
time.sleep(random.uniform(0.05, 0.15))
|
||||
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=user_prompt,
|
||||
json_struct=None,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
if isinstance(ai_resp, dict) and ai_resp.get("text"):
|
||||
suggestion = (ai_resp.get("text", "") or "").strip()
|
||||
else:
|
||||
suggestion = (str(ai_resp or "")).strip()
|
||||
if not suggestion:
|
||||
raise Exception("Gemini returned empty suggestion")
|
||||
raise Exception("Assistive writer returned empty suggestion")
|
||||
# naive confidence from number of sources present
|
||||
confidence = 0.7 if sources else 0.5
|
||||
return suggestion, confidence
|
||||
|
||||
Reference in New Issue
Block a user