ALwrity Version 0.5.0 (Fastapi + React )
This commit is contained in:
22
backend/services/llm_providers/__init__.py
Normal file
22
backend/services/llm_providers/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""LLM Providers Service for ALwrity Backend.
|
||||
|
||||
This service handles all LLM (Language Model) provider integrations,
|
||||
migrated from the legacy lib/gpt_providers functionality.
|
||||
"""
|
||||
|
||||
from .main_text_generation import llm_text_gen
|
||||
from .openai_provider import openai_chatgpt, test_openai_api_key
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response, test_gemini_api_key
|
||||
from .anthropic_provider import anthropic_text_response
|
||||
from .deepseek_provider import deepseek_text_response
|
||||
|
||||
__all__ = [
|
||||
"llm_text_gen",
|
||||
"openai_chatgpt",
|
||||
"test_openai_api_key",
|
||||
"gemini_text_response",
|
||||
"gemini_structured_json_response",
|
||||
"test_gemini_api_key",
|
||||
"anthropic_text_response",
|
||||
"deepseek_text_response"
|
||||
]
|
||||
98
backend/services/llm_providers/anthropic_provider.py
Normal file
98
backend/services/llm_providers/anthropic_provider.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Anthropic Provider Service for ALwrity Backend.
|
||||
|
||||
This service handles Anthropic API integrations,
|
||||
migrated from the legacy lib/gpt_providers/text_generation/anthropic_text_gen.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from loguru import logger
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
# Import APIKeyManager
|
||||
from ..api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
anthropic = None
|
||||
logger.warning("Anthropic library not available. Install with: pip install anthropic")
|
||||
|
||||
async def test_anthropic_api_key(api_key: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Test if the provided Anthropic API key is valid.
|
||||
|
||||
Args:
|
||||
api_key (str): The Anthropic API key to test
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing (is_valid, message)
|
||||
"""
|
||||
if not anthropic:
|
||||
return False, "Anthropic library not available"
|
||||
|
||||
try:
|
||||
# Create Anthropic client with the provided key
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
|
||||
# Try to generate a simple response as a test
|
||||
response = client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
# If we get here, the key is valid
|
||||
return True, "Anthropic API key is valid"
|
||||
|
||||
except anthropic.AuthenticationError:
|
||||
return False, "Invalid Anthropic API key"
|
||||
except anthropic.RateLimitError:
|
||||
return False, "Rate limit exceeded. Please try again later."
|
||||
except Exception as e:
|
||||
return False, f"Error testing Anthropic API key: {str(e)}"
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def anthropic_text_response(prompt: str, model: str = "claude-3-5-sonnet-20241022",
|
||||
temperature: float = 0.7, max_tokens: int = 4000,
|
||||
system_prompt: str = None) -> str:
|
||||
"""Get response from Anthropic Claude."""
|
||||
if not anthropic:
|
||||
logger.error("Anthropic library not available")
|
||||
return "Anthropic library not available. Please install anthropic package."
|
||||
|
||||
try:
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("anthropic")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Anthropic API key not found. Please configure it in the onboarding process.")
|
||||
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
|
||||
# Prepare messages
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = client.messages.create(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
logger.info(f"[anthropic_text_response] Generated response with {len(response.content[0].text)} characters")
|
||||
return response.content[0].text
|
||||
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from Anthropic: {err}. Retrying.")
|
||||
raise
|
||||
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Gemini Audio Text Generation Module
|
||||
|
||||
This module provides a comprehensive interface for working with audio files using Google's Gemini API.
|
||||
It supports various audio processing capabilities including transcription, summarization, and analysis.
|
||||
|
||||
Key Features:
|
||||
------------
|
||||
1. Audio Transcription: Convert speech in audio files to text
|
||||
2. Audio Summarization: Generate concise summaries of audio content
|
||||
3. Segment Analysis: Analyze specific time segments of audio files
|
||||
4. Timestamped Transcription: Generate transcriptions with timestamps
|
||||
5. Token Counting: Count tokens in audio files
|
||||
6. Format Support: Information about supported audio formats
|
||||
|
||||
Supported Audio Formats:
|
||||
----------------------
|
||||
- WAV (audio/wav)
|
||||
- MP3 (audio/mp3)
|
||||
- AIFF (audio/aiff)
|
||||
- AAC (audio/aac)
|
||||
- OGG Vorbis (audio/ogg)
|
||||
- FLAC (audio/flac)
|
||||
|
||||
Technical Details:
|
||||
----------------
|
||||
- Each second of audio is represented as 32 tokens
|
||||
- Maximum supported length of audio data in a single prompt is 9.5 hours
|
||||
- Audio files are downsampled to 16 Kbps data resolution
|
||||
- Multi-channel audio is combined into a single channel
|
||||
|
||||
Usage:
|
||||
------
|
||||
```python
|
||||
from lib.gpt_providers.audio_to_text_generation.gemini_audio_text import transcribe_audio, summarize_audio
|
||||
|
||||
# Basic transcription
|
||||
transcript = transcribe_audio("path/to/audio.mp3")
|
||||
print(transcript)
|
||||
|
||||
# Summarization
|
||||
summary = summarize_audio("path/to/audio.mp3")
|
||||
print(summary)
|
||||
|
||||
# Analyze specific segment
|
||||
segment_analysis = analyze_audio_segment("path/to/audio.mp3", "02:30", "03:29")
|
||||
print(segment_analysis)
|
||||
```
|
||||
|
||||
Requirements:
|
||||
------------
|
||||
- GEMINI_API_KEY environment variable must be set
|
||||
- google-generativeai Python package
|
||||
- python-dotenv for environment variable management
|
||||
- loguru for logging
|
||||
|
||||
Dependencies:
|
||||
------------
|
||||
- google.genai
|
||||
- dotenv
|
||||
- loguru
|
||||
- os, sys, base64, typing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
logger.add(sys.stdout,
|
||||
colorize=True,
|
||||
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
|
||||
)
|
||||
|
||||
|
||||
def load_environment():
|
||||
"""Loads environment variables from a .env file."""
|
||||
load_dotenv()
|
||||
logger.info("Environment variables loaded successfully.")
|
||||
|
||||
|
||||
def configure_google_api():
|
||||
"""
|
||||
Configures the Google Gemini API with the API key from environment variables.
|
||||
|
||||
Raises:
|
||||
ValueError: If the GEMINI_API_KEY environment variable is not set.
|
||||
"""
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_message = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
genai.configure(api_key=api_key)
|
||||
logger.info("Google Gemini API configured successfully.")
|
||||
|
||||
|
||||
def transcribe_audio(audio_file_path: str, prompt: str = "Transcribe the following audio:") -> Optional[str]:
|
||||
"""
|
||||
Transcribes audio using Google's Gemini model.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file to be transcribed.
|
||||
prompt (str, optional): The prompt to guide the transcription. Defaults to "Transcribe the following audio:".
|
||||
|
||||
Returns:
|
||||
str: The transcribed text from the audio.
|
||||
Returns None if transcription fails.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the audio file is not found.
|
||||
"""
|
||||
try:
|
||||
# Load environment variables and configure the Google API
|
||||
load_environment()
|
||||
configure_google_api()
|
||||
|
||||
logger.info(f"Attempting to transcribe audio file: {audio_file_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(audio_file_path):
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize a Gemini model appropriate for audio understanding
|
||||
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
||||
|
||||
# Upload the audio file
|
||||
try:
|
||||
audio_file = genai.upload_file(audio_file_path)
|
||||
logger.info(f"Audio file uploaded successfully: {audio_file=}")
|
||||
except FileNotFoundError:
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio file: {e}")
|
||||
return None
|
||||
|
||||
# Generate the transcription
|
||||
try:
|
||||
response = model.generate_content([
|
||||
prompt,
|
||||
audio_file
|
||||
])
|
||||
|
||||
# Check for valid response and extract text
|
||||
if response and hasattr(response, 'text'):
|
||||
transcript = response.text
|
||||
logger.info(f"Transcription successful:\n{transcript}")
|
||||
return transcript
|
||||
else:
|
||||
logger.warning("Transcription failed: Invalid or empty response from API.")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during transcription: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def summarize_audio(audio_file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Summarizes the content of an audio file using Google's Gemini model.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file to be summarized.
|
||||
|
||||
Returns:
|
||||
str: A summary of the audio content.
|
||||
Returns None if summarization fails.
|
||||
"""
|
||||
return transcribe_audio(audio_file_path, prompt="Please summarize the audio content:")
|
||||
|
||||
|
||||
def analyze_audio_segment(audio_file_path: str, start_time: str, end_time: str) -> Optional[str]:
|
||||
"""
|
||||
Analyzes a specific segment of an audio file using timestamps.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
start_time (str): Start time in MM:SS format.
|
||||
end_time (str): End time in MM:SS format.
|
||||
|
||||
Returns:
|
||||
str: Analysis of the specified audio segment.
|
||||
Returns None if analysis fails.
|
||||
"""
|
||||
prompt = f"Analyze the audio content from {start_time} to {end_time}."
|
||||
return transcribe_audio(audio_file_path, prompt=prompt)
|
||||
|
||||
|
||||
def transcribe_with_timestamps(audio_file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Transcribes audio with timestamps for each segment.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
|
||||
Returns:
|
||||
str: Transcription with timestamps.
|
||||
Returns None if transcription fails.
|
||||
"""
|
||||
return transcribe_audio(audio_file_path, prompt="Transcribe the audio with timestamps for each segment:")
|
||||
|
||||
|
||||
def count_tokens(audio_file_path: str) -> Optional[int]:
|
||||
"""
|
||||
Counts the number of tokens in an audio file.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
|
||||
Returns:
|
||||
int: Number of tokens in the audio file.
|
||||
Returns None if counting fails.
|
||||
"""
|
||||
try:
|
||||
# Load environment variables and configure the Google API
|
||||
load_environment()
|
||||
configure_google_api()
|
||||
|
||||
logger.info(f"Attempting to count tokens in audio file: {audio_file_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(audio_file_path):
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize a Gemini model
|
||||
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
||||
|
||||
# Upload the audio file
|
||||
try:
|
||||
audio_file = genai.upload_file(audio_file_path)
|
||||
logger.info(f"Audio file uploaded successfully: {audio_file=}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio file: {e}")
|
||||
return None
|
||||
|
||||
# Count tokens
|
||||
try:
|
||||
response = model.count_tokens([audio_file])
|
||||
token_count = response.total_tokens
|
||||
logger.info(f"Token count: {token_count}")
|
||||
return token_count
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting tokens: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_formats() -> List[str]:
|
||||
"""
|
||||
Returns a list of supported audio formats.
|
||||
|
||||
Returns:
|
||||
List[str]: List of supported MIME types.
|
||||
"""
|
||||
return [
|
||||
"audio/wav",
|
||||
"audio/mp3",
|
||||
"audio/aiff",
|
||||
"audio/aac",
|
||||
"audio/ogg",
|
||||
"audio/flac"
|
||||
]
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example 1: Basic transcription
|
||||
audio_path = "path/to/your/audio.mp3"
|
||||
transcript = transcribe_audio(audio_path)
|
||||
print(f"Transcript: {transcript}")
|
||||
|
||||
# Example 2: Summarization
|
||||
summary = summarize_audio(audio_path)
|
||||
print(f"Summary: {summary}")
|
||||
|
||||
# Example 3: Analyze specific segment
|
||||
segment_analysis = analyze_audio_segment(audio_path, "02:30", "03:29")
|
||||
print(f"Segment Analysis: {segment_analysis}")
|
||||
|
||||
# Example 4: Transcription with timestamps
|
||||
timestamped_transcript = transcribe_with_timestamps(audio_path)
|
||||
print(f"Timestamped Transcript: {timestamped_transcript}")
|
||||
|
||||
# Example 5: Count tokens
|
||||
token_count = count_tokens(audio_path)
|
||||
print(f"Token Count: {token_count}")
|
||||
|
||||
# Example 6: Get supported formats
|
||||
formats = get_supported_formats()
|
||||
print(f"Supported Formats: {formats}")
|
||||
@@ -0,0 +1,218 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from pytubefix import YouTube
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
from tqdm import tqdm
|
||||
import streamlit as st
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
from .gemini_audio_text import transcribe_audio
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...api_key_manager import APIKeyManager
|
||||
|
||||
|
||||
def progress_function(stream, chunk, bytes_remaining):
|
||||
# Calculate the percentage completion
|
||||
current = ((stream.filesize - bytes_remaining) / stream.filesize)
|
||||
progress_bar.update(current - progress_bar.n) # Update the progress bar
|
||||
|
||||
|
||||
def rename_file_with_underscores(file_path):
|
||||
"""Rename a file by replacing spaces and special characters with underscores.
|
||||
|
||||
Args:
|
||||
file_path (str): The original file path.
|
||||
|
||||
Returns:
|
||||
str: The new file path with underscores.
|
||||
"""
|
||||
# Extract the directory and the filename
|
||||
dir_name, original_filename = os.path.split(file_path)
|
||||
|
||||
# Replace spaces and special characters with underscores in the filename
|
||||
new_filename = re.sub(r'[^\w\-_\.]', '_', original_filename)
|
||||
|
||||
# Create the new file path
|
||||
new_file_path = os.path.join(dir_name, new_filename)
|
||||
|
||||
# Rename the file
|
||||
os.rename(file_path, new_file_path)
|
||||
|
||||
return new_file_path
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def speech_to_text(video_url):
|
||||
"""
|
||||
Transcribes speech to text from a YouTube video URL using OpenAI's Whisper model.
|
||||
|
||||
Args:
|
||||
video_url (str): URL of the YouTube video to transcribe.
|
||||
output_path (str, optional): Directory where the audio file will be saved. Defaults to '.'.
|
||||
|
||||
Returns:
|
||||
str: The transcribed text from the video.
|
||||
|
||||
Raises:
|
||||
SystemExit: If a critical error occurs that prevents successful execution.
|
||||
"""
|
||||
output_path = os.getenv("CONTENT_SAVE_DIR")
|
||||
yt = None
|
||||
audio_file = None
|
||||
with st.status("Started Writing..", expanded=False) as status:
|
||||
try:
|
||||
if video_url.startswith("https://www.youtube.com/") or video_url.startswith("http://www.youtube.com/"):
|
||||
logger.info(f"Accessing YouTube URL: {video_url}")
|
||||
status.update(label=f"Accessing YouTube URL: {video_url}")
|
||||
try:
|
||||
vid_id = video_url.split("=")[1]
|
||||
yt = YouTube(video_url, on_progress_callback=progress_function)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get pytube stream object: {err}")
|
||||
st.stop()
|
||||
|
||||
logger.info(f"Fetching the highest quality audio stream:{yt.title}")
|
||||
status.update(label=f"Fetching the highest quality audio stream: {yt.title}")
|
||||
try:
|
||||
audio_stream = yt.streams.filter(only_audio=True).first()
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to Download Youtube Audio: {err}")
|
||||
st.stop()
|
||||
|
||||
if audio_stream is None:
|
||||
logger.warning("No audio stream found for this video.")
|
||||
st.warning("No audio stream found for this video.")
|
||||
st.stop()
|
||||
|
||||
logger.info(f"Downloading audio for: {yt.title}")
|
||||
status.update(label=f"Downloading audio for: {yt.title}")
|
||||
global progress_bar
|
||||
progress_bar = tqdm(total=1.0, unit='iB', unit_scale=True, desc=yt.title)
|
||||
try:
|
||||
audio_filename = re.sub(r'[^\w\-_\.]', '_', yt.title) + '.mp4'
|
||||
audio_file = audio_stream.download(
|
||||
output_path=os.getenv("CONTENT_SAVE_DIR"),
|
||||
filename=audio_filename)
|
||||
#audio_file = rename_file_with_underscores(audio_file)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to download audio file: {audio_file}")
|
||||
|
||||
progress_bar.close()
|
||||
logger.info(f"Audio downloaded: {yt.title} to {audio_file}")
|
||||
status.update(label=f"Audio downloaded: {yt.title} to {output_path}")
|
||||
# Audio filepath from local directory.
|
||||
elif os.path.exists(audio_input):
|
||||
audio_file = video_url
|
||||
|
||||
# Checking file size
|
||||
max_file_size = 24 * 1024 * 1024 # 24MB
|
||||
file_size = os.path.getsize(audio_file)
|
||||
# Convert file size to MB for logging
|
||||
file_size_MB = file_size / (1024 * 1024) # Convert bytes to MB
|
||||
|
||||
logger.info(f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
|
||||
status.update(label=f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
|
||||
|
||||
if file_size > max_file_size:
|
||||
logger.error("File size exceeds 24MB limit.")
|
||||
# FIXME: We can chunk hour long videos, the code is not tested.
|
||||
#long_video(audio_file)
|
||||
sys.exit("File size limit exceeded.")
|
||||
st.error("Audio File size limit exceeded. File a fixme/issues at ALwrity github.")
|
||||
|
||||
try:
|
||||
print(f"Audio File: {audio_file}")
|
||||
transcript = transcribe_audio(audio_file)
|
||||
print(f"\n\n\n--- Tracribe: {transcript} ----\n\n\n")
|
||||
exit(1)
|
||||
status.update(label=f"Initializing OpenAI client for transcription: {audio_file}")
|
||||
logger.info(f"Initializing OpenAI client for transcription: {audio_file}")
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key not found. Please configure it in the onboarding process.")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
logger.info("Transcribing using OpenAI's Whisper model.")
|
||||
transcript = client.audio.transcriptions.create(
|
||||
model="whisper-1",
|
||||
file=open(audio_file, "rb"),
|
||||
response_format="text"
|
||||
)
|
||||
logger.info(f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
|
||||
status.update(label=f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
|
||||
return transcript, yt.title
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed in Whisper transcription: {e}")
|
||||
st.warning(f"Failed in Openai Whisper transcription: {e}")
|
||||
transcript = transcribe_audio(audio_file)
|
||||
print(f"\n\n\n--- Tracribe: {transcript} ----\n\n\n")
|
||||
return transcript, yt.title
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred during YouTube video processing: {e}")
|
||||
|
||||
finally:
|
||||
try:
|
||||
if os.path.exists(audio_file):
|
||||
os.remove(audio_file)
|
||||
logger.info("Temporary audio file removed.")
|
||||
except PermissionError:
|
||||
st.error(f"Permission error: Cannot remove '{audio_file}'. Please make sure of necessary permissions.")
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred removing audio file: {e}")
|
||||
|
||||
|
||||
def long_video(temp_file_name):
|
||||
"""
|
||||
Transcribes a YouTube video using OpenAI's Whisper API by processing the video in chunks.
|
||||
|
||||
This function handles videos longer than the context limit of the Whisper API by dividing the video into
|
||||
10-minute segments, transcribing each segment individually, and then combining the results.
|
||||
|
||||
Key Changes and Notes:
|
||||
1. Video Splitting: Splits the audio into 10-minute chunks using the moviepy library.
|
||||
2. Chunk Transcription: Each audio chunk is transcribed separately and the results are concatenated.
|
||||
3. Temporary Files for Chunks: Uses temporary files for each audio chunk for transcription.
|
||||
4. Error Handling: Exception handling is included to capture and return any errors during the process.
|
||||
5. Logging: Process steps are logged for debugging and monitoring.
|
||||
6. Cleaning Up: Removes temporary files for both the entire video and individual audio chunks after processing.
|
||||
|
||||
Args:
|
||||
video_url (str): URL of the YouTube video to be transcribed.
|
||||
"""
|
||||
# Extract audio and split into chunks
|
||||
logger.info(f"Processing the YT video: {temp_file_name}")
|
||||
full_audio = mp.AudioFileClip(temp_file_name)
|
||||
duration = full_audio.duration
|
||||
chunk_length = 600 # 10 minutes in seconds
|
||||
chunks = [full_audio.subclip(start, min(start + chunk_length, duration)) for start in range(0, int(duration), chunk_length)]
|
||||
|
||||
combined_transcript = ""
|
||||
for i, chunk in enumerate(chunks):
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio_chunk_file:
|
||||
chunk.write_audiofile(audio_chunk_file.name, codec="mp3")
|
||||
with open(audio_chunk_file.name, "rb", encoding="utf-8") as audio_file:
|
||||
# Transcribe each chunk using OpenAI's Whisper API
|
||||
app.logger.info(f"Transcribing chunk {i+1}/{len(chunks)}")
|
||||
transcript = openai.Audio.transcribe("whisper-1", audio_file)
|
||||
combined_transcript += transcript['text'] + "\n\n"
|
||||
|
||||
# Remove the chunk audio file
|
||||
os.remove(audio_chunk_file.name)
|
||||
|
||||
105
backend/services/llm_providers/deepseek_provider.py
Normal file
105
backend/services/llm_providers/deepseek_provider.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""DeepSeek Provider Service for ALwrity Backend.
|
||||
|
||||
This service handles DeepSeek API integrations,
|
||||
migrated from the legacy lib/gpt_providers/text_generation/deepseek_text_gen.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from loguru import logger
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
# Import APIKeyManager
|
||||
from ..api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
openai = None
|
||||
logger.warning("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
async def test_deepseek_api_key(api_key: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Test if the provided DeepSeek API key is valid.
|
||||
|
||||
Args:
|
||||
api_key (str): The DeepSeek API key to test
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing (is_valid, message)
|
||||
"""
|
||||
if not openai:
|
||||
return False, "OpenAI library not available"
|
||||
|
||||
try:
|
||||
# Create DeepSeek client with the provided key
|
||||
client = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.deepseek.com/v1"
|
||||
)
|
||||
|
||||
# Try to generate a simple response as a test
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=10,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
# If we get here, the key is valid
|
||||
return True, "DeepSeek API key is valid"
|
||||
|
||||
except openai.AuthenticationError:
|
||||
return False, "Invalid DeepSeek API key"
|
||||
except openai.RateLimitError:
|
||||
return False, "Rate limit exceeded. Please try again later."
|
||||
except Exception as e:
|
||||
return False, f"Error testing DeepSeek API key: {str(e)}"
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def deepseek_text_response(prompt: str, model: str = "deepseek-chat",
|
||||
temperature: float = 0.7, max_tokens: int = 4000,
|
||||
system_prompt: str = None) -> str:
|
||||
"""Get response from DeepSeek."""
|
||||
if not openai:
|
||||
logger.error("OpenAI library not available")
|
||||
return "OpenAI library not available. Please install openai package."
|
||||
|
||||
try:
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("deepseek")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("DeepSeek API key not found. Please configure it in the onboarding process.")
|
||||
|
||||
client = openai.OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.deepseek.com/v1"
|
||||
)
|
||||
|
||||
# Prepare messages
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
logger.info(f"[deepseek_text_response] Generated response with {len(response.choices[0].message.content)} characters")
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from DeepSeek: {err}. Retrying.")
|
||||
raise
|
||||
232
backend/services/llm_providers/gemini_provider.py
Normal file
232
backend/services/llm_providers/gemini_provider.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Using Gemini Pro LLM model
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path('../../../.env'))
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
logger.add(sys.stdout,
|
||||
colorize=True,
|
||||
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
# Configure standard logging
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='[%(asctime)s-%(levelname)s-%(module)s-%(lineno)d]- %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_prompt):
|
||||
""" Common functiont to get response from gemini pro Text. """
|
||||
#FIXME: Include : https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/System_instructions_REST.ipynb
|
||||
try:
|
||||
client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to configure Gemini: {err}")
|
||||
logger.info(f"Temp: {temperature}, MaxTokens: {max_tokens}, TopP: {top_p}, N: {n}")
|
||||
# Set up AI model config
|
||||
generation_config = {
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": n,
|
||||
"max_output_tokens": max_tokens,
|
||||
}
|
||||
# FIXME: Expose model_name in main_config
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.5-pro',
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=n,
|
||||
),
|
||||
)
|
||||
|
||||
#logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")
|
||||
return response.text
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
|
||||
|
||||
|
||||
#@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
#def gemini_blog_metadata_json(blog_content):
|
||||
# """ Common functiont to get response from gemini pro Text. """
|
||||
# prompt = f"I will provide you with the content of a blog post. Based on this content, you need to generate the following elements in JSON format:\n\n1. **Blog Title**: A compelling and relevant title that summarizes the blog content.\n2. **Meta Description**: A concise meta description (up to 160 characters) that captures the essence of the blog post and encourages clicks.\n3. **Tags**: A list of 5-10 relevant tags that represent the key topics covered in the blog post.\n4. **Categories**: A list of 1-3 appropriate categories that best describe the blog post's main themes.\n\nOutput your response in the following JSON format:\n\n```json\n{\n \"type\": \"object\",\n \"properties\": {\n \"blog_title\": {\n \"type\": \"string\"\n },\n \"meta_description\": {\n \"type\": \"string\"\n },\n \"tags\": {\n \"type\": \"array\",\n \"items\": {\n \"type\": \"string\"\n }\n },\n \"categories\": {\n \"type\": \"array\",\n \"items\": {\n \"type\": \"string\"\n }\n }\n }\n}\n\n. The Blog Content is given below: \n\n{blog_content}\n\n"
|
||||
#
|
||||
# try:
|
||||
# genai.configure(api_key=os.getenv('GEMINI_API_KEY'))
|
||||
# except Exception as err:
|
||||
# logger.error(f"Failed to configure Gemini: {err}")
|
||||
#
|
||||
# # Create the model
|
||||
# generation_config = {
|
||||
# "temperature": 1,
|
||||
# "top_p": 0.95,
|
||||
# "top_k": 64,
|
||||
# "max_output_tokens": 8192,
|
||||
# "response_schema": content.Schema(
|
||||
# type = content.Type.OBJECT,
|
||||
# properties = {
|
||||
# "response": content.Schema(
|
||||
# type = content.Type.STRING,
|
||||
# ),
|
||||
# },
|
||||
# ),
|
||||
# "response_mime_type": "application/json",
|
||||
# }
|
||||
#
|
||||
# model = genai.GenerativeModel(
|
||||
# model_name="gemini-1.5-flash",
|
||||
# generation_config=generation_config,
|
||||
# # safety_settings = Adjust safety settings
|
||||
# # See https://ai.google.dev/gemini-api/docs/safety-settings
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # text_response = []
|
||||
# response = model.generate_content(prompt)
|
||||
# if response:
|
||||
# logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")
|
||||
# return response.text
|
||||
# except Exception as err:
|
||||
# logger.error(f"Failed to get SEO METADATA from Gemini: {err}. Retrying.")
|
||||
|
||||
async def test_gemini_api_key(api_key: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Test if the provided Gemini API key is valid.
|
||||
|
||||
Args:
|
||||
api_key (str): The Gemini API key to test
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing (is_valid, message)
|
||||
"""
|
||||
try:
|
||||
# Configure Gemini with the provided key
|
||||
genai.configure(api_key=api_key)
|
||||
|
||||
# Try to list models as a simple API test
|
||||
models = genai.list_models()
|
||||
|
||||
# Check if Gemini Pro is available
|
||||
if any(model.name == "gemini-pro" for model in models):
|
||||
return True, "Gemini API key is valid"
|
||||
else:
|
||||
return False, "Gemini Pro model not available with this API key"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Error testing Gemini API key: {str(e)}"
|
||||
|
||||
def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048):
|
||||
"""
|
||||
Generate text using Google's Gemini Pro model.
|
||||
|
||||
Args:
|
||||
prompt (str): The input text to generate completion for
|
||||
temperature (float, optional): Controls randomness. Defaults to 0.7
|
||||
top_p (float, optional): Controls diversity. Defaults to 0.9
|
||||
top_k (int, optional): Controls vocabulary size. Defaults to 40
|
||||
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
|
||||
|
||||
Returns:
|
||||
str: The generated text completion
|
||||
"""
|
||||
try:
|
||||
# Configure the model
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
|
||||
# Generate content
|
||||
response = model.generate_content(
|
||||
prompt,
|
||||
generation_config=genai.types.GenerationConfig(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
max_output_tokens=max_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
# Return the generated text
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gemini Pro text generation: {e}")
|
||||
return str(e)
|
||||
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048, system_prompt=None):
|
||||
"""
|
||||
Generate structured JSON response using Google's Gemini Pro model.
|
||||
|
||||
Args:
|
||||
prompt (str): The input text to generate completion for
|
||||
schema (dict): The JSON schema to follow for the response
|
||||
temperature (float, optional): Controls randomness. Defaults to 0.7
|
||||
top_p (float, optional): Controls diversity. Defaults to 0.9
|
||||
top_k (int, optional): Controls vocabulary size. Defaults to 40
|
||||
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
|
||||
system_prompt (str, optional): System instructions for the model
|
||||
|
||||
Returns:
|
||||
dict: The generated structured JSON response
|
||||
"""
|
||||
try:
|
||||
# Configure the model
|
||||
client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
|
||||
|
||||
# Set up generation config
|
||||
generation_config = {
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"max_output_tokens": max_tokens,
|
||||
}
|
||||
|
||||
# Generate content with structured response
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.5-pro',
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
response_mime_type='application/json',
|
||||
response_schema=schema
|
||||
),
|
||||
)
|
||||
|
||||
# Parse the response
|
||||
try:
|
||||
# First try to get the parsed response
|
||||
if hasattr(response, 'parsed'):
|
||||
return response.parsed
|
||||
|
||||
# If parsed is not available, try to parse the text
|
||||
response_text = response.text
|
||||
return json.loads(response_text)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error parsing JSON response: {e}")
|
||||
return {"error": f"Failed to parse JSON response: {e}", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gemini Pro structured JSON generation: {e}")
|
||||
return {"error": str(e)}
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Gemini Image Description Module
|
||||
|
||||
This module provides functionality to generate text descriptions of images using Google's Gemini API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from dotenv import load_dotenv
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
logger.add(sys.stdout,
|
||||
colorize=True,
|
||||
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
|
||||
)
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
except ImportError:
|
||||
genai = None
|
||||
logger.warning("Google genai library not available. Install with: pip install google-generativeai")
|
||||
|
||||
|
||||
def describe_image(image_path: str, prompt: str = "Describe this image in detail:") -> Optional[str]:
|
||||
"""
|
||||
Describe an image using Google's Gemini API.
|
||||
|
||||
Parameters:
|
||||
image_path (str): Path to the image file.
|
||||
prompt (str): Prompt for describing the image.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The generated description of the image, or None if an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not genai:
|
||||
logger.error("Google genai library not available")
|
||||
return None
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_message = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Check if image file exists
|
||||
if not os.path.exists(image_path):
|
||||
error_message = f"Image file not found: {image_path}"
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize the Gemini client
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Open and process the image
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
logger.info(f"Successfully opened image: {image_path}")
|
||||
except Exception as e:
|
||||
error_message = f"Failed to open image: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
# Generate content description
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.0-flash',
|
||||
contents=[
|
||||
prompt,
|
||||
image
|
||||
]
|
||||
)
|
||||
|
||||
# Extract and return the text
|
||||
description = response.text
|
||||
logger.info(f"Successfully generated description for image: {image_path}")
|
||||
return description
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Failed to generate content: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"An unexpected error occurred: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
|
||||
def analyze_image_with_prompt(image_path: str, prompt: str) -> Optional[str]:
|
||||
"""
|
||||
Analyze an image with a custom prompt using Google's Gemini API.
|
||||
|
||||
Parameters:
|
||||
image_path (str): Path to the image file.
|
||||
prompt (str): Custom prompt for analyzing the image.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The generated analysis of the image, or None if an error occurs.
|
||||
"""
|
||||
return describe_image(image_path, prompt)
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example usage of the function
|
||||
image_path = "path/to/your/image.jpg"
|
||||
description = describe_image(image_path)
|
||||
if description:
|
||||
print(f"Image description: {description}")
|
||||
else:
|
||||
print("Failed to generate image description")
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module provides functionality to analyze images using OpenAI's Vision API.
|
||||
It encodes an image to a base64 string and sends a request to the OpenAI API
|
||||
to interpret the contents of the image, returning a textual description.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import sys
|
||||
import re
|
||||
import base64
|
||||
|
||||
def analyze_and_extract_details_from_image(image_path, api_key):
|
||||
"""
|
||||
Analyzes an image using OpenAI's Vision API and extracts Alt Text, Description, Title, and Caption.
|
||||
|
||||
Args:
|
||||
image_path (str): Path to the image file.
|
||||
api_key (str): Your OpenAI API key.
|
||||
|
||||
Returns:
|
||||
dict: Extracted details including Alt Text, Description, Title, and Caption.
|
||||
"""
|
||||
def encode_image(path):
|
||||
""" Encodes an image to a base64 string. """
|
||||
with open(path, "rb", encoding="utf-8") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
base64_image = encode_image(image_path)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4-vision-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The given image is used in blog content. Analyze the given image and suggest alternative(alt) test, description, title, caption."
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
assistant_message = response.json()['choices'][0]['message']['content']
|
||||
|
||||
# Extracting details using regular expressions
|
||||
alt_text_match = re.search(r'Alt Text: "(.*?)"', assistant_message)
|
||||
description_match = re.search(r'Description: (.*?)\n\n', assistant_message)
|
||||
title_match = re.search(r'Title: "(.*?)"', assistant_message)
|
||||
caption_match = re.search(r'Caption: "(.*?)"', assistant_message)
|
||||
|
||||
return {
|
||||
'alt_text': alt_text_match.group(1) if alt_text_match else None,
|
||||
'description': description_match.group(1) if description_match else None,
|
||||
'title': title_match.group(1) if title_match else None,
|
||||
'caption': caption_match.group(1) if caption_match else None
|
||||
}
|
||||
|
||||
except requests.RequestException as e:
|
||||
sys.exit(f"Error: Failed to communicate with OpenAI API. Error: {e}")
|
||||
except Exception as e:
|
||||
sys.exit(f"Error occurred: {e}")
|
||||
306
backend/services/llm_providers/main_text_generation.py
Normal file
306
backend/services/llm_providers/main_text_generation.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Main Text Generation Service for ALwrity Backend.
|
||||
|
||||
This service provides the main LLM text generation functionality,
|
||||
migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
from ..api_key_manager import APIKeyManager
|
||||
|
||||
from .openai_provider import openai_chatgpt
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from .anthropic_provider import anthropic_text_response
|
||||
from .deepseek_provider import deepseek_text_response
|
||||
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate text from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
|
||||
|
||||
# Initialize API key manager
|
||||
api_key_manager = APIKeyManager()
|
||||
|
||||
# Set default values for LLM parameters
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
frequency_penalty = 0.0
|
||||
presence_penalty = 0.0
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
blog_demographic = "Professional"
|
||||
blog_type = "Informational"
|
||||
blog_language = "English"
|
||||
blog_output_format = "markdown"
|
||||
blog_length = 2000
|
||||
|
||||
# Try to get provider from environment or config
|
||||
try:
|
||||
# Check which providers have API keys available
|
||||
available_providers = []
|
||||
if api_key_manager.get_api_key("openai"):
|
||||
available_providers.append("openai")
|
||||
if api_key_manager.get_api_key("gemini"):
|
||||
available_providers.append("google")
|
||||
if api_key_manager.get_api_key("anthropic"):
|
||||
available_providers.append("anthropic")
|
||||
if api_key_manager.get_api_key("deepseek"):
|
||||
available_providers.append("deepseek")
|
||||
|
||||
# Prefer Google Gemini if available, otherwise use first available
|
||||
if "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif available_providers:
|
||||
gpt_provider = available_providers[0]
|
||||
if gpt_provider == "openai":
|
||||
model = "gpt-4o"
|
||||
elif gpt_provider == "anthropic":
|
||||
model = "claude-3-5-sonnet-20241022"
|
||||
elif gpt_provider == "deepseek":
|
||||
model = "deepseek-chat"
|
||||
else:
|
||||
logger.warning("[llm_text_gen] No API keys found, using mock response")
|
||||
return _get_mock_response(prompt)
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
except Exception as err:
|
||||
logger.warning(f"[llm_text_gen] Error determining provider, using defaults: {err}")
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
if system_prompt is None:
|
||||
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
|
||||
Your expertise spans various writing styles and formats.
|
||||
|
||||
Writing Style Guidelines:
|
||||
- Tone: {blog_tone}
|
||||
- Target Audience: {blog_demographic}
|
||||
- Content Type: {blog_type}
|
||||
- Language: {blog_language}
|
||||
- Output Format: {blog_output_format}
|
||||
- Target Length: {blog_length} words
|
||||
|
||||
Please provide responses that are:
|
||||
- Well-structured and easy to read
|
||||
- Engaging and informative
|
||||
- Tailored to the specified tone and audience
|
||||
- Professional yet accessible
|
||||
- Optimized for the target content type
|
||||
"""
|
||||
else:
|
||||
system_instructions = system_prompt
|
||||
|
||||
# Generate response based on provider
|
||||
try:
|
||||
if gpt_provider == "openai":
|
||||
return openai_chatgpt(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
fp=fp,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "google":
|
||||
if json_struct:
|
||||
return gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=n,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
return gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "anthropic":
|
||||
return anthropic_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "deepseek":
|
||||
return deepseek_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
return _get_mock_response(prompt)
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
# Try to fallback to another provider
|
||||
fallback_providers = ["openai", "anthropic", "deepseek"]
|
||||
for fallback_provider in fallback_providers:
|
||||
if fallback_provider in available_providers and fallback_provider != gpt_provider:
|
||||
try:
|
||||
logger.info(f"[llm_text_gen] Trying fallback provider: {fallback_provider}")
|
||||
if fallback_provider == "openai":
|
||||
return openai_chatgpt(
|
||||
prompt=prompt,
|
||||
model="gpt-4o",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
fp=fp,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif fallback_provider == "anthropic":
|
||||
return anthropic_text_response(
|
||||
prompt=prompt,
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif fallback_provider == "deepseek":
|
||||
return deepseek_text_response(
|
||||
prompt=prompt,
|
||||
model="deepseek-chat",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
|
||||
continue
|
||||
|
||||
# If all providers fail, return mock response
|
||||
logger.warning("[llm_text_gen] All providers failed, using mock response")
|
||||
return _get_mock_response(prompt)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[llm_text_gen] Error during text generation: {str(e)}")
|
||||
return _get_mock_response(prompt)
|
||||
|
||||
def _get_mock_response(prompt: str) -> str:
|
||||
"""Get a mock response when no API keys are available."""
|
||||
logger.warning("[llm_text_gen] Using mock response - no API keys configured")
|
||||
|
||||
# Return a structured mock response for style detection
|
||||
if "style analysis" in prompt.lower() or "writing style" in prompt.lower():
|
||||
return json.dumps({
|
||||
"writing_style": {
|
||||
"tone": "professional",
|
||||
"voice": "active",
|
||||
"complexity": "moderate",
|
||||
"engagement_level": "high"
|
||||
},
|
||||
"content_characteristics": {
|
||||
"sentence_structure": "well-structured",
|
||||
"vocabulary_level": "intermediate",
|
||||
"paragraph_organization": "logical flow",
|
||||
"content_flow": "smooth transitions"
|
||||
},
|
||||
"target_audience": {
|
||||
"demographics": ["professionals", "business users"],
|
||||
"expertise_level": "intermediate",
|
||||
"industry_focus": "technology",
|
||||
"geographic_focus": "global"
|
||||
},
|
||||
"content_type": {
|
||||
"primary_type": "blog",
|
||||
"secondary_types": ["article", "guide"],
|
||||
"purpose": "inform",
|
||||
"call_to_action": "moderate"
|
||||
},
|
||||
"recommended_settings": {
|
||||
"writing_tone": "professional",
|
||||
"target_audience": "business professionals",
|
||||
"content_type": "blog",
|
||||
"creativity_level": "medium",
|
||||
"geographic_location": "global"
|
||||
}
|
||||
})
|
||||
|
||||
# Handle pattern analysis requests
|
||||
if "pattern" in prompt.lower() or "recurring" in prompt.lower():
|
||||
return json.dumps({
|
||||
"patterns": {
|
||||
"sentence_length": "medium",
|
||||
"vocabulary_patterns": ["technical terms", "professional language"],
|
||||
"rhetorical_devices": ["examples", "analogies"],
|
||||
"paragraph_structure": "topic sentence followed by supporting details",
|
||||
"transition_phrases": ["furthermore", "additionally", "however"]
|
||||
},
|
||||
"style_consistency": "high",
|
||||
"unique_elements": ["clear structure", "professional tone", "evidence-based content"]
|
||||
})
|
||||
|
||||
# Handle guidelines generation requests
|
||||
if "guidelines" in prompt.lower() or "recommendations" in prompt.lower():
|
||||
return json.dumps({
|
||||
"guidelines": {
|
||||
"tone_recommendations": ["maintain professional tone", "use clear language"],
|
||||
"structure_guidelines": ["start with introduction", "use headings", "conclude with summary"],
|
||||
"vocabulary_suggestions": ["avoid jargon", "use industry-specific terms appropriately"],
|
||||
"engagement_tips": ["include examples", "use active voice", "ask questions"],
|
||||
"audience_considerations": ["consider technical level", "provide context"]
|
||||
},
|
||||
"best_practices": ["research thoroughly", "cite sources", "update regularly"],
|
||||
"avoid_elements": ["overly technical language", "long paragraphs", "passive voice"],
|
||||
"content_strategy": "focus on providing value while maintaining professional credibility"
|
||||
})
|
||||
|
||||
# Generic mock response for other content generation
|
||||
return "This is a mock response. Please configure API keys for real content generation. To get started, visit the onboarding process and configure your AI provider API keys."
|
||||
|
||||
def check_gpt_provider(gpt_provider: str) -> bool:
|
||||
"""Check if the specified GPT provider is supported."""
|
||||
supported_providers = ["openai", "google", "anthropic", "deepseek"]
|
||||
return gpt_provider in supported_providers
|
||||
|
||||
def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
"""Get API key for the specified provider."""
|
||||
try:
|
||||
api_key_manager = APIKeyManager()
|
||||
provider_mapping = {
|
||||
"openai": "openai",
|
||||
"google": "gemini",
|
||||
"anthropic": "anthropic",
|
||||
"deepseek": "deepseek"
|
||||
}
|
||||
|
||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
||||
return api_key_manager.get_api_key(mapped_provider)
|
||||
except Exception as e:
|
||||
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
||||
return None
|
||||
133
backend/services/llm_providers/openai_provider.py
Normal file
133
backend/services/llm_providers/openai_provider.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""OpenAI Provider Service for ALwrity Backend.
|
||||
|
||||
This service handles OpenAI API integrations,
|
||||
migrated from the legacy lib/gpt_providers/text_generation/openai_text_gen.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import openai
|
||||
import asyncio
|
||||
from typing import Tuple
|
||||
from loguru import logger
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
# Import APIKeyManager
|
||||
from ..api_key_manager import APIKeyManager
|
||||
|
||||
async def test_openai_api_key(api_key: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Test if the provided OpenAI API key is valid.
|
||||
|
||||
Args:
|
||||
api_key (str): The OpenAI API key to test
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing (is_valid, message)
|
||||
"""
|
||||
try:
|
||||
# Create OpenAI client with the provided key
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
# Try to list models as a simple API test
|
||||
models = client.models.list()
|
||||
|
||||
# If we get here, the key is valid
|
||||
return True, "OpenAI API key is valid"
|
||||
|
||||
except openai.AuthenticationError:
|
||||
return False, "Invalid OpenAI API key"
|
||||
except openai.RateLimitError:
|
||||
return False, "Rate limit exceeded. Please try again later."
|
||||
except Exception as e:
|
||||
return False, f"Error testing OpenAI API key: {str(e)}"
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def openai_chatgpt(prompt: str, model: str = "gpt-4o", temperature: float = 0.7,
|
||||
max_tokens: int = 4000, top_p: float = 0.9, n: int = 1,
|
||||
fp: int = 16, system_prompt: str = None) -> str:
|
||||
"""
|
||||
Wrapper function for OpenAI's ChatGPT completion.
|
||||
|
||||
Args:
|
||||
prompt (str): The input text to generate completion for.
|
||||
model (str, optional): Model to be used for the completion. Defaults to "gpt-4o".
|
||||
temperature (float, optional): Controls randomness. Lower values make responses more deterministic. Defaults to 0.7.
|
||||
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 4000.
|
||||
top_p (float, optional): Controls diversity. Defaults to 0.9.
|
||||
n (int, optional): Number of completions to generate. Defaults to 1.
|
||||
fp (int, optional): Frequency penalty. Defaults to 16.
|
||||
system_prompt (str, optional): System prompt for the conversation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The generated text completion.
|
||||
|
||||
Raises:
|
||||
SystemExit: If an API error, connection error, or rate limit error occurs.
|
||||
"""
|
||||
# Wait for 5 seconds to comply with rate limits
|
||||
for _ in range(5):
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
# Create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
full_reply_content = None
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key not found. Please configure it in the onboarding process.")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
# Prepare messages
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
top_p=top_p,
|
||||
stream=True,
|
||||
frequency_penalty=fp,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
# Iterate through the stream of events
|
||||
for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
chunk_message = chunk.choices[0].delta.content # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
|
||||
# Clean None in collected_messages
|
||||
collected_messages = [m for m in collected_messages if m is not None]
|
||||
full_reply_content = ''.join([m for m in collected_messages])
|
||||
|
||||
logger.info(f"[openai_chatgpt] Generated response with {len(full_reply_content)} characters")
|
||||
return full_reply_content
|
||||
|
||||
except openai.APIError as e:
|
||||
logger.error(f"OpenAI API Error: {e}")
|
||||
raise SystemExit from e
|
||||
except openai.RateLimitError as e:
|
||||
logger.error(f"OpenAI Rate Limit Error: {e}")
|
||||
raise SystemExit from e
|
||||
except openai.APIConnectionError as e:
|
||||
logger.error(f"OpenAI API Connection Error: {e}")
|
||||
raise SystemExit from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in OpenAI API call: {e}")
|
||||
raise SystemExit from e
|
||||
@@ -0,0 +1,56 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
|
||||
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
|
||||
"""
|
||||
Generates images using the DALL-E 3 model based on a given text prompt.
|
||||
|
||||
Args:
|
||||
img_prompt (str): Text prompt to generate the image.
|
||||
image_dir (str): Directory where the generated image will be saved.
|
||||
size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
quality (str, optional): Quality of the generated images. Defaults to "hd".
|
||||
n (int, optional): Number of images to generate. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image.
|
||||
|
||||
Raises:
|
||||
SystemExit: If an error occurs in image generation or saving.
|
||||
"""
|
||||
try:
|
||||
logger.info("Generating Dall-e-3 image for the blog.")
|
||||
client = OpenAI()
|
||||
|
||||
img_generation_response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=img_prompt,
|
||||
size=size,
|
||||
quality=quality,
|
||||
n=n
|
||||
)
|
||||
# Save the generated image locally.
|
||||
try:
|
||||
img_path = save_generated_image(img_generation_response, image_dir)
|
||||
return img_path
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to Save generated image: {err}")
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Dalle-3 image generation error: HTTP Status {e.http_status}, Error: {e.error}")
|
||||
sys.exit("Exiting due to Dalle-3 image generation error.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate images with Dalle3: {e}")
|
||||
sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -0,0 +1,53 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
|
||||
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
|
||||
"""
|
||||
Generates images using the DALL-E 3 model based on a given text prompt.
|
||||
|
||||
Args:
|
||||
img_prompt (str): Text prompt to generate the image.
|
||||
image_dir (str): Directory where the generated image will be saved.
|
||||
size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
quality (str, optional): Quality of the generated images. Defaults to "hd".
|
||||
n (int, optional): Number of images to generate. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image.
|
||||
|
||||
Raises:
|
||||
SystemExit: If an error occurs in image generation or saving.
|
||||
"""
|
||||
try:
|
||||
logger.info("Generating Dall-e-3 image for the blog.")
|
||||
client = OpenAI()
|
||||
|
||||
img_generation_response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=img_prompt,
|
||||
size=size,
|
||||
quality=quality,
|
||||
n=n
|
||||
)
|
||||
|
||||
img_path = save_generated_image(img_generation_response, image_dir)
|
||||
return img_path
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Dalle-3 image generation error: HTTP Status {e.http_status}, Error: {e.error}")
|
||||
sys.exit("Exiting due to Dalle-3 image generation error.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate images with Dalle3: {e}")
|
||||
sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -0,0 +1,421 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
from google.generativeai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
logger.warning("Google genai library not available. Install with: pip install google-generativeai")
|
||||
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# If what you see doesn't quite match what you had in mind, try adding more details to the prompt.
|
||||
# The more specific you are, the better Gemini can create images that reflect your vision.
|
||||
|
||||
# Generate images using Gemini
|
||||
# Gemini 2.0 Flash Experimental supports the ability to output text and inline images.
|
||||
# This lets you use Gemini to conversationally edit images or generate outputs with interwoven text (for example, generating a blog post with text and images in a single turn).
|
||||
# Note: Make sure to include responseModalities: ["Text", "Image"] in your generation configuration for text and image output with gemini-2.0-flash-exp-image-generation. Image only is not allowed.
|
||||
|
||||
|
||||
class AIPromptGenerator:
|
||||
"""
|
||||
Generates enhanced AI image prompts based on user keywords,
|
||||
following the guidelines of the Imagen documentation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.photography_styles = ["photo", "photograph"]
|
||||
self.art_styles = ["painting", "sketch", "drawing", "illustration", "digital art", "render"]
|
||||
self.art_techniques = ["technical pencil drawing", "charcoal drawing", "color pencil drawing", "pastel painting", "digital art", "art deco (poster)", "impressionist painting", "renaissance painting", "pop art"]
|
||||
self.camera_proximity = ["close-up", "zoomed out", "taken from far away"]
|
||||
self.camera_position = ["aerial", "from below"]
|
||||
self.lighting = ["natural lighting", "dramatic lighting", "warm lighting", "cold lighting", "studio lighting", "golden hour lighting"]
|
||||
self.camera_settings = ["motion blur", "soft focus", "bokeh", "portrait"]
|
||||
self.lens_types = ["35mm lens", "50mm lens", "fisheye lens", "wide angle lens", "macro lens", "telephoto lens"]
|
||||
self.film_types = ["black and white film", "polaroid"]
|
||||
self.materials = ["made of cheese", "made of paper", "made of neon tubes", "metallic", "glass", "wooden", "stone"]
|
||||
self.shapes = ["in the shape of a bird", "angular", "curved", "geometric"]
|
||||
self.quality_modifiers_general = ["high-quality", "beautiful", "stylized", "detailed", "epic", "grand"]
|
||||
self.quality_modifiers_photo = ["4K", "HDR", "studio photo", "professional photo", "photorealistic"]
|
||||
self.quality_modifiers_art = ["by a professional artist", "intricate details", "masterpiece"]
|
||||
self.aspect_ratios = ["1:1 aspect ratio", "4:3 aspect ratio", "3:4 aspect ratio", "16:9 aspect ratio", "9:16 aspect ratio"]
|
||||
self.photorealistic_modifiers = {
|
||||
"portraits": ["prime lens", "zoom lens", "24-35mm", "black and white film", "film noir", "shallow depth of field", "duotone (mention two colors)"],
|
||||
"objects": ["macro lens", "60-105mm", "high detail", "precise focusing", "controlled lighting"],
|
||||
"motion": ["telephoto zoom lens", "100-400mm", "fast shutter speed", "action shot", "movement tracking"],
|
||||
"wide-angle": ["wide-angle lens", "10-24mm", "long exposure", "sharp focus", "smooth water or clouds", "astro photography"]
|
||||
}
|
||||
|
||||
def generate_prompt(self, keywords):
|
||||
"""
|
||||
Generates an enhanced AI image prompt based on user-provided keywords.
|
||||
|
||||
Args:
|
||||
keywords (list): A list of keywords describing the desired image.
|
||||
|
||||
Returns:
|
||||
str: An enhanced AI image prompt.
|
||||
"""
|
||||
if not keywords:
|
||||
return "A beautiful image."
|
||||
|
||||
prompt_parts = []
|
||||
subject = " ".join(keywords)
|
||||
prompt_parts.append(subject)
|
||||
|
||||
# Add context and background (optional)
|
||||
context_options = ["in a detailed background", "outdoors", "indoors", "in a studio", "with a blurred background"]
|
||||
if random.random() < 0.6: # Add context with a probability
|
||||
prompt_parts.append(random.choice(context_options))
|
||||
|
||||
# Add style (optional)
|
||||
style_options = self.photography_styles + [f"{art} of" for art in self.art_styles]
|
||||
if random.random() < 0.7:
|
||||
prompt_parts.insert(0, random.choice(style_options))
|
||||
if prompt_parts[0].startswith("painting of") or prompt_parts[0].startswith("sketch of") or prompt_parts[0].startswith("drawing of"):
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(f"in the style of {random.choice(self.art_techniques)}")
|
||||
|
||||
# Add photography modifiers (if photography style is chosen)
|
||||
if any(style in prompt_parts[0] for style in self.photography_styles):
|
||||
if random.random() < 0.4:
|
||||
prompt_parts.append(random.choice(self.camera_proximity))
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.camera_position))
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(random.choice(self.lighting))
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.camera_settings))
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.lens_types))
|
||||
if random.random() < 0.1:
|
||||
prompt_parts.append(random.choice(self.film_types))
|
||||
|
||||
# Add shapes and materials (optional)
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.materials))
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.shapes))
|
||||
|
||||
# Add quality modifiers (optional)
|
||||
if random.random() < 0.6:
|
||||
quality_options = self.quality_modifiers_general
|
||||
if any(style in prompt_parts[0] for style in self.photography_styles):
|
||||
quality_options += self.quality_modifiers_photo
|
||||
else:
|
||||
quality_options += self.quality_modifiers_art
|
||||
prompt_parts.append(random.choice(list(set(quality_options)))) # Avoid duplicates
|
||||
|
||||
# Add aspect ratio (optional)
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.aspect_ratios))
|
||||
|
||||
return ", ".join(prompt_parts)
|
||||
|
||||
def generate_photorealistic_prompt(self, keywords, focus=""):
|
||||
"""
|
||||
Generates an enhanced AI image prompt specifically for photorealistic images.
|
||||
|
||||
Args:
|
||||
keywords (list): A list of keywords describing the desired image.
|
||||
focus (str, optional): The focus of the photorealistic image (e.g., "portraits", "objects", "motion", "wide-angle"). Defaults to "".
|
||||
|
||||
Returns:
|
||||
str: An enhanced photorealistic AI image prompt.
|
||||
"""
|
||||
if not keywords:
|
||||
return "A photorealistic image."
|
||||
|
||||
prompt_parts = ["A photo of", "photorealistic"]
|
||||
prompt_parts.append(" ".join(keywords))
|
||||
|
||||
if focus and focus in self.photorealistic_modifiers:
|
||||
modifiers = self.photorealistic_modifiers[focus]
|
||||
if modifiers:
|
||||
num_modifiers = random.randint(1, min(3, len(modifiers)))
|
||||
selected_modifiers = random.sample(modifiers, num_modifiers)
|
||||
prompt_parts.extend(selected_modifiers)
|
||||
|
||||
# Add general quality modifiers
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(random.choice(self.quality_modifiers_photo))
|
||||
|
||||
# Add lighting
|
||||
if random.random() < 0.4:
|
||||
prompt_parts.append(random.choice(self.lighting))
|
||||
|
||||
return ", ".join(prompt_parts)
|
||||
|
||||
|
||||
def generate_gemini_image(prompt, keywords=None, style=None, focus=None, enhance_prompt=True, max_retries=3, initial_retry_delay=2, aspect_ratio="16:9"):
|
||||
"""
|
||||
Generate an image using Gemini's image generation capabilities.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt for image generation
|
||||
keywords (list, optional): Keywords to enhance the prompt
|
||||
style (str, optional): Style of the image (photorealistic, artistic, etc.)
|
||||
focus (str, optional): Focus area for photorealistic images
|
||||
enhance_prompt (bool, optional): Whether to enhance the prompt with AI
|
||||
max_retries (int, optional): Maximum number of retry attempts
|
||||
initial_retry_delay (int, optional): Initial delay between retries
|
||||
aspect_ratio (str, optional): Aspect ratio for the generated image
|
||||
|
||||
Returns:
|
||||
str: The path to the generated image.
|
||||
"""
|
||||
logger.info(f"Generating image with prompt: '{prompt[:100]}...'")
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_msg = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_msg)
|
||||
st.error(f"🔑 {error_msg}")
|
||||
return None
|
||||
|
||||
# Enhance the prompt if requested
|
||||
if enhance_prompt and keywords:
|
||||
prompt_generator = AIPromptGenerator()
|
||||
if style == "photorealistic" and focus:
|
||||
logger.info(f"Generating photorealistic prompt with focus: {focus}")
|
||||
enhanced_prompt = prompt_generator.generate_photorealistic_prompt(keywords, focus)
|
||||
else:
|
||||
logger.info("Generating enhanced prompt")
|
||||
enhanced_prompt = prompt_generator.generate_prompt(keywords)
|
||||
|
||||
# Combine the enhanced prompt with the original prompt
|
||||
prompt = f"{prompt}\n\nEnhanced prompt: {enhanced_prompt}"
|
||||
logger.info(f"Final prompt: '{prompt[:100]}...'")
|
||||
|
||||
# Add aspect ratio to the prompt
|
||||
if aspect_ratio:
|
||||
prompt += f"\n\nPlease generate the image with {aspect_ratio} aspect ratio."
|
||||
|
||||
retry_count = 0
|
||||
retry_delay = initial_retry_delay
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
contents = (prompt)
|
||||
|
||||
logger.info("Sending request to Gemini API")
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['Text', 'Image']
|
||||
)
|
||||
)
|
||||
logger.info("Received response from Gemini API")
|
||||
|
||||
img_name = None
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
logger.info(f"Received text response: '{part.text[:100]}...'")
|
||||
print(part.text)
|
||||
elif part.inline_data is not None:
|
||||
logger.info("Received image data from Gemini")
|
||||
image = Image.open(BytesIO((part.inline_data.data)))
|
||||
|
||||
# Resize image to match aspect ratio if needed
|
||||
if aspect_ratio:
|
||||
current_width, current_height = image.size
|
||||
target_width = current_width
|
||||
target_height = current_height
|
||||
|
||||
# Calculate target dimensions based on aspect ratio
|
||||
if aspect_ratio == "16:9":
|
||||
target_height = int(current_width * 9/16)
|
||||
elif aspect_ratio == "9:16":
|
||||
target_width = int(current_height * 9/16)
|
||||
elif aspect_ratio == "4:3":
|
||||
target_height = int(current_width * 3/4)
|
||||
elif aspect_ratio == "3:4":
|
||||
target_width = int(current_height * 3/4)
|
||||
elif aspect_ratio == "1:1":
|
||||
target_size = min(current_width, current_height)
|
||||
target_width = target_size
|
||||
target_height = target_size
|
||||
|
||||
logger.info(f"Resizing image from {current_width}x{current_height} to {target_width}x{target_height}")
|
||||
|
||||
# Create a new image with the target dimensions
|
||||
resized_image = Image.new('RGB', (target_width, target_height), (255, 255, 255))
|
||||
|
||||
# Calculate position to paste the original image
|
||||
paste_x = (target_width - current_width) // 2
|
||||
paste_y = (target_height - current_height) // 2
|
||||
|
||||
# Paste the original image onto the new canvas
|
||||
resized_image.paste(image, (paste_x, paste_y))
|
||||
image = resized_image
|
||||
|
||||
if part.text is not None:
|
||||
img_name = f'{part.text}-gemini-native-image.png'
|
||||
else:
|
||||
img_name = f'gemini-native-image-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}.png'
|
||||
try:
|
||||
logger.info(f"Saving image to: {img_name}")
|
||||
image.save(img_name)
|
||||
|
||||
# Create a dictionary with the expected format for save_generated_image
|
||||
img_response = {
|
||||
"artifacts": [
|
||||
{
|
||||
"base64": base64.b64encode(open(img_name, "rb").read()).decode('utf-8')
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Call save_generated_image with the correct format
|
||||
save_generated_image(img_response)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to save image: {err}")
|
||||
st.error(f"Failed to save image: {err}")
|
||||
|
||||
logger.info(f"Image generation completed. Image name: {img_name}")
|
||||
return img_name
|
||||
except Exception as err:
|
||||
error_message = str(err)
|
||||
logger.error(f"Error in generate_gemini_image: {err}")
|
||||
|
||||
# Check if this is a 503 UNAVAILABLE error
|
||||
if "503 UNAVAILABLE" in error_message and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
logger.info(f"Model is overloaded. Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
st.warning(f"The image generation service is currently busy. Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
# Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
st.error(f"Error generating image: {err}")
|
||||
return None
|
||||
|
||||
# If we've exhausted all retries
|
||||
st.error("The image generation service is currently unavailable. Please try again later.")
|
||||
return None
|
||||
|
||||
|
||||
def edit_image(image_path, prompt, max_retries=3, initial_retry_delay=2):
|
||||
"""
|
||||
- Image editing (text and image to image)
|
||||
Example prompt: "Edit this image to make it look like a cartoon"
|
||||
Example prompt: [image of a cat] + [image of a pillow] + "Create a cross stitch of my cat on this pillow."
|
||||
|
||||
- Multi-turn image editing (chat)
|
||||
Example prompts: [upload an image of a blue car.] "Turn this car into a convertible." "Now change the color to yellow."
|
||||
|
||||
Image editing with Gemini
|
||||
To perform image editing, add an image as input.
|
||||
The following example demonstrats uploading base64 encoded images.
|
||||
For multiple images and larger payloads, check the image input section.
|
||||
|
||||
Args:
|
||||
image_path (str): The path to the image to edit.
|
||||
prompt (str): The prompt to edit the image with.
|
||||
max_retries (int, optional): Maximum number of retry attempts for handling 503 errors. Defaults to 3.
|
||||
initial_retry_delay (int, optional): Initial delay in seconds before retrying. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
str: The path to the edited image.
|
||||
"""
|
||||
import PIL.Image
|
||||
image = PIL.Image.open(image_path)
|
||||
|
||||
retry_count = 0
|
||||
retry_delay = initial_retry_delay
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
client = genai.Client()
|
||||
text_input = (prompt)
|
||||
|
||||
logger.info("Sending request to Gemini API for image editing")
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=[text_input, image],
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['Text', 'Image']
|
||||
)
|
||||
)
|
||||
logger.info("Received response from Gemini API for image editing")
|
||||
|
||||
edited_img_name = None
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
logger.info(f"Received text response: '{part.text[:100]}...'")
|
||||
st.write(part.text)
|
||||
elif part.inline_data is not None:
|
||||
logger.info("Received edited image data from Gemini")
|
||||
edited_image = Image.open(BytesIO(part.inline_data.data))
|
||||
edited_image.show()
|
||||
|
||||
# Save the edited image
|
||||
edited_img_name = f'edited-{os.path.basename(image_path)}'
|
||||
try:
|
||||
logger.info(f"Saving edited image to: {edited_img_name}")
|
||||
edited_image.save(edited_img_name)
|
||||
|
||||
# Create a dictionary with the expected format for save_generated_image
|
||||
img_response = {
|
||||
"artifacts": [
|
||||
{
|
||||
"base64": base64.b64encode(open(edited_img_name, "rb").read()).decode('utf-8')
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Call save_generated_image with the correct format
|
||||
save_generated_image(img_response)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to save edited image: {err}")
|
||||
st.error(f"Failed to save edited image: {err}")
|
||||
|
||||
logger.info(f"Image editing completed. Edited image name: {edited_img_name}")
|
||||
return edited_img_name
|
||||
except Exception as err:
|
||||
error_message = str(err)
|
||||
logger.error(f"Error in edit_image: {err}")
|
||||
|
||||
# Check if this is a 503 UNAVAILABLE error
|
||||
if "503 UNAVAILABLE" in error_message and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
logger.info(f"Model is overloaded. Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
st.warning(f"The image editing service is currently busy. Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
# Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
st.error(f"Error editing image: {err}")
|
||||
return None
|
||||
|
||||
# If we've exhausted all retries
|
||||
st.error("The image editing service is currently unavailable. Please try again later.")
|
||||
return None
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# Ensure you sign up for an account to obtain an API key:
|
||||
# https://platform.stability.ai/
|
||||
# Your API key can be found here after account creation:
|
||||
# https://platform.stability.ai/account/keys
|
||||
|
||||
import os
|
||||
import requests
|
||||
import base64
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...api_key_manager import APIKeyManager
|
||||
|
||||
def save_generated_image(data):
|
||||
"""Save the generated image to a file."""
|
||||
# Implementation for saving image
|
||||
pass
|
||||
|
||||
def generate_stable_diffusion_image(prompt):
|
||||
engine_id = "stable-diffusion-xl-1024-v1-0"
|
||||
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("stability")
|
||||
|
||||
if api_key is None:
|
||||
st.warning("Missing Stability API key. Please configure it in the onboarding process.")
|
||||
return None
|
||||
|
||||
response = requests.post(
|
||||
f"{api_host}/v1/generation/{engine_id}/text-to-image",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
},
|
||||
json={
|
||||
"text_prompts": [
|
||||
{
|
||||
"text": prompt
|
||||
}
|
||||
],
|
||||
"cfg_scale": 7,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"samples": 1,
|
||||
"steps": 30,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception("Non-200 response: " + str(response.text))
|
||||
|
||||
data = response.json()
|
||||
img_path = save_generated_image(data)
|
||||
|
||||
for i, image in enumerate(data["artifacts"]):
|
||||
# Decode base64 image data
|
||||
img_data = base64.b64decode(image["base64"])
|
||||
# Open image using PIL
|
||||
img = Image.open(BytesIO(img_data))
|
||||
# Display the image
|
||||
img.show()
|
||||
|
||||
return img_path
|
||||
@@ -0,0 +1,51 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
|
||||
def gen_new_from_given_img(img_path, image_dir, num_img=1, img_size="1024x1024", response_format="url"):
|
||||
"""
|
||||
Generates variations of a given image using OpenAI's image variation API.
|
||||
|
||||
This function takes an existing image, processes it, and generates a specified number of new images based on it.
|
||||
These generated images are variations of the original, providing creative flexibility.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the original image file.
|
||||
image_dir (str): Directory where the generated images will be saved.
|
||||
num_img (int, optional): Number of image variations to generate. Defaults to 1.
|
||||
img_size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
response_format (str, optional): Format in which the generated images are returned. Defaults to "url".
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image variation.
|
||||
|
||||
Raises:
|
||||
SystemExit: If a critical error occurs that prevents successful execution.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting image variation generation for: {img_path}")
|
||||
|
||||
# Convert and prepare the image
|
||||
png = Image.open(img_path).convert('RGBA')
|
||||
background = Image.new('RGBA', png.size, (255, 255, 255))
|
||||
alpha_composite = Image.alpha_composite(background, png)
|
||||
alpha_composite.save(img_path, 'PNG', quality=80)
|
||||
logger.info("Image prepared for variation generation.")
|
||||
|
||||
client = OpenAI()
|
||||
variation_response = client.images.create_variation(
|
||||
image=open(img_path, "rb", encoding="utf-8"),
|
||||
n=num_img,
|
||||
size=img_size,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
# Saving the generated image
|
||||
generated_image_path = save_generated_image(variation_response, image_dir)
|
||||
logger.info(f"Image variation generated and saved to: {generated_image_path}")
|
||||
return generated_image_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image variation generation: {e}")
|
||||
sys.exit(f"Exiting due to critical error: {e}")
|
||||
@@ -0,0 +1,163 @@
|
||||
#########################################################
|
||||
#
|
||||
# This module will generate images for the blogs using APIs
|
||||
# from Dall-E and other free resources. Given a prompt, the
|
||||
# images will be stored in local directory.
|
||||
# Required: openai API key.
|
||||
#
|
||||
#########################################################
|
||||
|
||||
# imports
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import streamlit as st
|
||||
|
||||
import openai # OpenAI Python library to make API calls
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
logger.add(sys.stdout,
|
||||
colorize=True,
|
||||
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
|
||||
)
|
||||
|
||||
#from .gen_dali2_images
|
||||
from .gen_dali3_images import generate_dalle3_images
|
||||
from .gen_stabl_diff_img import generate_stable_diffusion_image
|
||||
from ..text_generation.main_text_generation import llm_text_gen
|
||||
from .gen_gemini_images import generate_gemini_image
|
||||
|
||||
def generate_image(user_prompt, title=None, description=None, tags=None, content=None, aspect_ratio="16:9"):
|
||||
"""
|
||||
The generation API endpoint creates an image based on a text prompt.
|
||||
|
||||
Required inputs:
|
||||
prompt (str): A text description of the desired image(s). The maximum length is 1000 characters.
|
||||
|
||||
Optional inputs:
|
||||
--> image_engine: dalle2, dalle3, stable diffusion are supported.
|
||||
--> num_images (int): The number of images to generate. Must be between 1 and 10. Defaults to 1.
|
||||
--> size (str): The size of the generated images. Must be one of "256x256", "512x512", or "1024x1024".
|
||||
Smaller images are faster. Defaults to "1024x1024".
|
||||
-->response_format (str): The format in which the generated images are returned.
|
||||
Must be one of "url" or "b64_json". Defaults to "url".
|
||||
--> user (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||
--> aspect_ratio (str): The aspect ratio for the generated image. Must be one of "16:9", "4:3", or "1:1". Defaults to "16:9".
|
||||
"""
|
||||
# FIXME: Need to remove default value to match sidebar input.
|
||||
image_engine = 'Gemini-AI'
|
||||
image_stored_at = None
|
||||
|
||||
if user_prompt:
|
||||
try:
|
||||
# Use enhanced prompt generator with all available parameters
|
||||
img_prompt = generate_enhanced_img_prompt(user_prompt, title, description, tags, content)
|
||||
|
||||
# Add aspect ratio to the prompt
|
||||
if aspect_ratio:
|
||||
img_prompt += f"\n\nAspect ratio: {aspect_ratio}"
|
||||
|
||||
if 'Dalle3' in image_engine:
|
||||
logger.info(f"Calling Dalle3 text-to-image with prompt: {img_prompt}")
|
||||
image_stored_at = generate_dalle3_images(img_prompt)
|
||||
elif 'Stability-AI' in image_engine:
|
||||
logger.info(f"Calling Stable diffusion text-to-image with prompt: \n{img_prompt}")
|
||||
image_stored_at = generate_stable_diffusion_image(img_prompt)
|
||||
elif 'Gemini-AI' in image_engine:
|
||||
logger.info(f"Calling Gemini text-to-image with prompt: \n{img_prompt}")
|
||||
image_stored_at = generate_gemini_image(img_prompt, aspect_ratio=aspect_ratio)
|
||||
return image_stored_at
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to generate Image: {err}")
|
||||
st.warning(f"Failed to generate Image: {err}")
|
||||
else:
|
||||
logger.error("Skipping Image creation, No prompt provided.")
|
||||
|
||||
|
||||
def generate_img_prompt(user_prompt):
|
||||
"""
|
||||
Given prompt, this functions generated a prompt for image generation.
|
||||
"""
|
||||
prompt = f"""
|
||||
As an expert prompt generator for AI text to image models and artist, I will provide you with 'user text' for creating images.
|
||||
Your task is to create a prompt for a highly relevant image from given 'user text'.
|
||||
\n
|
||||
Choose from various art styles, utilize light & shadow effects etc.
|
||||
Make sure to avoid common image generation mistakes.
|
||||
Reply with only one answer, no descrition and in plaintext.
|
||||
Make sure your prompt is detailed and creative descriptions that will inspire unique and interesting images from the AI.
|
||||
|
||||
\n\nuser text:
|
||||
'''{user_prompt}'''"""
|
||||
|
||||
response = llm_text_gen(prompt)
|
||||
return response
|
||||
|
||||
|
||||
def generate_enhanced_img_prompt(user_prompt, title=None, description=None, tags=None, content=None):
|
||||
"""
|
||||
Given user prompt and additional context (title, description, tags, content),
|
||||
this function generates an enhanced prompt for better image generation.
|
||||
|
||||
Args:
|
||||
user_prompt (str): Base prompt from the user
|
||||
title (str, optional): Blog title or content title
|
||||
description (str, optional): Blog or content description/summary
|
||||
tags (list, optional): List of tags related to the content
|
||||
content (str, optional): Actual content or excerpt
|
||||
|
||||
Returns:
|
||||
str: Enhanced prompt for image generation
|
||||
"""
|
||||
# Start with the base prompt
|
||||
context_parts = [user_prompt]
|
||||
|
||||
# Add relevant context if available
|
||||
if title:
|
||||
context_parts.append(f"Title: {title}")
|
||||
|
||||
if description:
|
||||
context_parts.append(f"Description: {description}")
|
||||
|
||||
if tags and len(tags) > 0:
|
||||
tag_text = ", ".join(tags[:5]) # Limit to 5 tags to avoid too much noise
|
||||
context_parts.append(f"Tags: {tag_text}")
|
||||
|
||||
# Create a combined context
|
||||
combined_context = "\n".join(context_parts)
|
||||
|
||||
# Add some content excerpt if available (limited to avoid token limits)
|
||||
content_excerpt = ""
|
||||
if content:
|
||||
# Just use the first few hundred characters as excerpt
|
||||
content_excerpt = content[:300] + "..." if len(content) > 300 else content
|
||||
|
||||
# Create the prompt for LLM
|
||||
prompt = f"""
|
||||
As an expert prompt engineer for AI image generation models, create a detailed, creative prompt
|
||||
for generating a high-quality, relevant image based on the following context:
|
||||
|
||||
{combined_context}
|
||||
|
||||
Additional content excerpt:
|
||||
{content_excerpt}
|
||||
|
||||
Your task is to:
|
||||
1. Analyze the context and content to understand the main theme and subject
|
||||
2. Create a rich, detailed prompt for image generation (50-75 words)
|
||||
3. Include specific visual details, art style, mood, lighting, composition
|
||||
4. Make sure the prompt is highly relevant to the original context
|
||||
5. Avoid prohibited content or anything that violates image generation guidelines
|
||||
|
||||
Reply with ONLY the final prompt. No explanations or other text.
|
||||
"""
|
||||
|
||||
# Generate the enhanced prompt
|
||||
try:
|
||||
enhanced_prompt = llm_text_gen(prompt)
|
||||
logger.info(f"Generated enhanced image prompt: {enhanced_prompt[:100]}...")
|
||||
return enhanced_prompt
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating enhanced prompt: {e}")
|
||||
# Fall back to the simple prompt generation if enhanced fails
|
||||
return generate_img_prompt(user_prompt)
|
||||
@@ -0,0 +1,39 @@
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import requests
|
||||
from PIL import Image
|
||||
import logging
|
||||
|
||||
def save_generated_image(img_generation_response):
|
||||
"""
|
||||
Save generated images for blog, ensuring unique names for SEO.
|
||||
"""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get image save directory with fallback to a local directory
|
||||
image_save_dir = os.getenv('IMG_SAVE_DIR', 'generated_images')
|
||||
|
||||
# Create the directory if it doesn't exist
|
||||
if not os.path.exists(image_save_dir):
|
||||
logger.info(f"Creating image save directory: {image_save_dir}")
|
||||
os.makedirs(image_save_dir, exist_ok=True)
|
||||
|
||||
generated_image_name = f"generated_image_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}.webp"
|
||||
generated_image_filepath = os.path.join(image_save_dir, generated_image_name)
|
||||
|
||||
try:
|
||||
for i, image in enumerate(img_generation_response["artifacts"]):
|
||||
with open(generated_image_filepath, "wb") as f:
|
||||
f.write(base64.b64decode(image["base64"]))
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to get generated image content: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving image: {e}")
|
||||
return None
|
||||
|
||||
logger.info(f"Saved image at path: {generated_image_filepath}")
|
||||
|
||||
return generated_image_filepath
|
||||
Reference in New Issue
Block a user