Base code

Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions


@@ -0,0 +1,343 @@
# LLM Providers Module
This module provides functions for interacting with multiple LLM providers, specifically Google's Gemini API and Hugging Face Inference Providers. It follows official API documentation and implements best practices for reliable AI interactions.
## Supported Providers
- **Google Gemini**: High-quality text generation with structured JSON output
- **Hugging Face**: Multiple models via Inference Providers with unified interface
## Quick Start
```python
from services.llm_providers.main_text_generation import llm_text_gen
# Generate text (auto-detects available provider)
response = llm_text_gen("Write a blog post about AI trends")
print(response)
```
## Configuration
Set your preferred provider using the `GPT_PROVIDER` environment variable:
```bash
# Use Google Gemini (default)
export GPT_PROVIDER=gemini
# Use Hugging Face
export GPT_PROVIDER=hf_response_api
```
Configure API keys:
```bash
# For Google Gemini
export GEMINI_API_KEY=your_gemini_api_key_here
# For Hugging Face
export HF_TOKEN=your_huggingface_token_here
```
## Key Features
- **Structured JSON Response Generation**: Generate structured outputs with schema validation
- **Text Response Generation**: Simple text generation with retry logic
- **Comprehensive Error Handling**: Robust error handling and logging
- **Automatic API Key Management**: Secure API key handling
- **Support for Multiple Models**: gemini-2.5-flash and gemini-2.5-pro
## Best Practices
### 1. Use Structured Output for Complex Responses
```python
# ✅ Good: Use structured output for multi-field responses
schema = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"description": {"type": "string"}
}
}
}
}
}
result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
```
### 2. Keep Schemas Simple and Flat
```python
# ✅ Good: Simple, flat schema
schema = {
"type": "object",
"properties": {
"monitoringTasks": {
"type": "array",
"items": {"type": "object", "properties": {...}}
}
}
}
# ❌ Avoid: Complex nested schemas with many required fields
schema = {
"type": "object",
"required": ["field1", "field2", "field3"],
"properties": {
"field1": {"type": "object", "required": [...], "properties": {...}},
"field2": {"type": "array", "items": {"type": "object", "required": [...], "properties": {...}}}
}
}
```
### 3. Set Appropriate Token Limits
```python
# ✅ Good: Use 8192 tokens for complex outputs
result = gemini_structured_json_response(prompt, schema, max_tokens=8192)
# ✅ Good: Use 2048 tokens for simple text responses
result = gemini_text_response(prompt, max_tokens=2048)
```
### 4. Use Low Temperature for Structured Output
```python
# ✅ Good: Low temperature for consistent structured output
result = gemini_structured_json_response(prompt, schema, temperature=0.1, max_tokens=8192)
# ✅ Good: Higher temperature for creative text
result = gemini_text_response(prompt, temperature=0.8, max_tokens=2048)
```
### 5. Implement Proper Error Handling
```python
# ✅ Good: Handle errors in calling functions
try:
response = gemini_structured_json_response(prompt, schema)
if isinstance(response, dict) and "error" in response:
raise Exception(f"Gemini error: {response.get('error')}")
# Process successful response
except Exception as e:
logger.error(f"AI service error: {e}")
# Handle error appropriately
```
### 6. Avoid Fallback to Text Parsing
```python
# ✅ Good: Use structured output only, no fallback
response = gemini_structured_json_response(prompt, schema)
if "error" in response:
raise Exception(f"Gemini error: {response.get('error')}")
# ❌ Avoid: Fallback to text parsing for structured responses
# This can lead to inconsistent results and parsing errors
```
## Usage Examples
### Structured JSON Response
```python
from services.llm_providers.gemini_provider import gemini_structured_json_response
# Define schema
monitoring_schema = {
"type": "object",
"properties": {
"monitoringTasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"component": {"type": "string"},
"title": {"type": "string"},
"description": {"type": "string"},
"assignee": {"type": "string"},
"frequency": {"type": "string"},
"metric": {"type": "string"},
"measurementMethod": {"type": "string"},
"successCriteria": {"type": "string"},
"alertThreshold": {"type": "string"},
"actionableInsights": {"type": "string"}
}
}
}
}
}
# Generate structured response
prompt = "Generate a monitoring plan for content strategy..."
result = gemini_structured_json_response(
prompt=prompt,
schema=monitoring_schema,
temperature=0.1,
max_tokens=8192
)
# Handle response
if isinstance(result, dict) and "error" in result:
raise Exception(f"Gemini error: {result.get('error')}")
# Process successful response
monitoring_tasks = result.get("monitoringTasks", [])
```
### Text Response
```python
from services.llm_providers.gemini_provider import gemini_text_response
# Generate text response
prompt = "Write a blog post about AI in content marketing..."
result = gemini_text_response(
prompt=prompt,
temperature=0.8,
max_tokens=2048
)
# Process response
if result:
print(f"Generated text: {result}")
else:
print("No response generated")
```
## Troubleshooting
### Common Issues and Solutions
#### 1. Response.parsed is None
**Symptoms**: `response.parsed` returns `None` even with successful HTTP 200
**Causes**:
- Schema too complex for the model
- Token limit too low
- Temperature too high for structured output
**Solutions**:
- Simplify schema structure
- Increase `max_tokens` to 8192
- Lower temperature to 0.1-0.3
- Test with smaller outputs first
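A hedged sketch of applying these fixes in code, assuming the error-dict convention shown in the examples above (the wrapper name is illustrative):
```python
from services.llm_providers.gemini_provider import gemini_structured_json_response

def structured_call_with_retry(prompt, schema):
    # First attempt with moderate settings.
    result = gemini_structured_json_response(prompt, schema, temperature=0.3, max_tokens=4096)
    if isinstance(result, dict) and "error" not in result:
        return result
    # Retry once following the guidance above: lower temperature, larger token budget.
    # Simplifying the schema (fewer nested objects/required fields) is left to the caller.
    return gemini_structured_json_response(prompt, schema, temperature=0.1, max_tokens=8192)
```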
#### 2. JSON Parsing Fails
**Symptoms**: `JSONDecodeError` or "Unterminated string" errors
**Causes**:
- Response truncated due to token limits
- Schema doesn't match expected output
- Model generates malformed JSON
**Solutions**:
- Reduce output size requested
- Verify schema matches expected structure
- Use structured output instead of text parsing
- Increase token limits
#### 3. Truncation Issues
**Symptoms**: Response cuts off mid-sentence or mid-array
**Causes**:
- Output too large for single response
- Token limits exceeded
**Solutions**:
- Reduce number of items requested
- Increase `max_tokens` to 8192
- Break large requests into smaller chunks
- Use `gemini-2.5-pro` for larger outputs
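One hedged way to break a large request into smaller chunks is to ask for a few items per call and merge the results; the helper below is a sketch and assumes a schema with a single top-level array field (`monitoringTasks`, as in the usage example later in this document):
```python
from services.llm_providers.gemini_provider import gemini_structured_json_response

def generate_in_batches(base_prompt, schema, total_items=20, batch_size=5):
    """Request a few items per call to stay within token limits, then merge."""
    items = []
    for start in range(0, total_items, batch_size):
        prompt = f"{base_prompt}\nGenerate only items {start + 1} through {start + batch_size}."
        result = gemini_structured_json_response(prompt, schema, temperature=0.1, max_tokens=8192)
        if isinstance(result, dict) and "error" in result:
            raise Exception(f"Gemini error: {result.get('error')}")
        items.extend(result.get("monitoringTasks", []))
    return items
```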
#### 4. Rate Limiting
**Symptoms**: `RetryError` or connection timeouts
**Causes**:
- Too many requests in short time
- Network connectivity issues
**Solutions**:
- Exponential backoff already implemented
- Check network connectivity
- Reduce request frequency
- Verify API key validity
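Exponential backoff is already built into the provider via `tenacity` (see Dependencies); if a caller wants its own retry layer, a minimal sketch looks like this (the wrapper name is illustrative):
```python
from tenacity import retry, stop_after_attempt, wait_random_exponential
from services.llm_providers.gemini_provider import gemini_text_response

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def generate_with_backoff(prompt):
    # Retries transient failures (rate limits, timeouts) with exponential backoff.
    return gemini_text_response(prompt, temperature=0.7, max_tokens=2048)
```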
### Debug Logging
The module includes comprehensive debug logging. Enable it with:
```python
import logging
logging.getLogger('services.llm_providers.gemini_provider').setLevel(logging.DEBUG)
```
Key log messages to monitor:
- `Gemini structured call | prompt_len=X | schema_kind=Y | temp=Z`
- `Gemini response | type=X | has_text=Y | has_parsed=Z`
- `Using response.parsed for structured output`
- `Falling back to response.text parsing`
## API Reference
### gemini_structured_json_response()
Generate structured JSON response using Google's Gemini Pro model.
**Parameters**:
- `prompt` (str): Input prompt for the AI model
- `schema` (dict): JSON schema defining expected output structure
- `temperature` (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
- `top_p` (float): Nucleus sampling parameter (0.0-1.0)
- `top_k` (int): Top-k sampling parameter
- `max_tokens` (int): Maximum tokens in response. Use 8192 for complex outputs
- `system_prompt` (str, optional): System instruction for the model
**Returns**:
- `dict`: Parsed JSON response matching the provided schema
**Raises**:
- `Exception`: If API key is missing or API call fails
### gemini_text_response()
Generate text response using Google's Gemini Pro model.
**Parameters**:
- `prompt` (str): Input prompt for the AI model
- `temperature` (float): Controls randomness (0.0-1.0). Higher = more creative
- `top_p` (float): Nucleus sampling parameter (0.0-1.0)
- `n` (int): Number of responses to generate
- `max_tokens` (int): Maximum tokens in response
- `system_prompt` (str, optional): System instruction for the model
**Returns**:
- `str`: Generated text response
**Raises**:
- `Exception`: If API key is missing or API call fails
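An illustrative call using the parameters above (values are examples, not defaults):
```python
text = gemini_text_response(
    prompt="Summarize this week's AI trends for a marketing audience.",
    temperature=0.8,
    top_p=0.9,
    n=1,
    max_tokens=2048,
    system_prompt="You are a concise marketing analyst."
)
```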
## Dependencies
- `google.generativeai` (genai): Official Gemini API client
- `tenacity`: Retry logic with exponential backoff
- `logging`: Debug and error logging
- `json`: Fallback JSON parsing
- `re`: Text extraction utilities
## Version History
- **v2.0** (January 2025): Enhanced structured output support, improved error handling, comprehensive documentation
- **v1.0**: Initial implementation with basic text and structured response support
## Contributing
When contributing to this module:
1. Follow the established patterns for error handling
2. Add comprehensive logging for debugging
3. Test with both simple and complex schemas
4. Update documentation for any new features
5. Ensure backward compatibility
## Support
For issues or questions:
1. Check the troubleshooting section above
2. Review debug logs for specific error messages
3. Test with simplified schemas to isolate issues
4. Verify API key configuration and network connectivity


@@ -0,0 +1,237 @@
# Hugging Face Integration for AI Blog Writer
## Overview
The AI Blog Writer now supports both Google Gemini and Hugging Face as LLM providers, with a clean environment variable-based configuration system. This integration uses the [Hugging Face Responses API](https://huggingface.co/docs/inference-providers/guides/responses-api) which provides a unified interface for model interactions.
## Supported Providers
### 1. Google Gemini (Default)
- **Provider ID**: `google`
- **Environment Variable**: `GEMINI_API_KEY`
- **Models**: `gemini-2.0-flash-001`
- **Features**: Text generation, structured JSON output
### 2. Hugging Face
- **Provider ID**: `huggingface`
- **Environment Variable**: `HF_TOKEN`
- **Models**: Multiple models via Inference Providers
- **Features**: Text generation, structured JSON output, multi-model support
## Configuration
### Environment Variables
Set the `GPT_PROVIDER` environment variable to choose your preferred provider:
```bash
# Use Google Gemini (default)
export GPT_PROVIDER=gemini
# or
export GPT_PROVIDER=google
# Use Hugging Face
export GPT_PROVIDER=hf_response_api
# or
export GPT_PROVIDER=huggingface
# or
export GPT_PROVIDER=hf
```
### API Keys
Configure the appropriate API key for your chosen provider:
```bash
# For Google Gemini
export GEMINI_API_KEY=your_gemini_api_key_here
# For Hugging Face
export HF_TOKEN=your_huggingface_token_here
```
## Usage
### Basic Text Generation
```python
from services.llm_providers.main_text_generation import llm_text_gen
# Generate text (uses configured provider)
response = llm_text_gen("Write a blog post about AI trends")
print(response)
```
### Structured JSON Generation
```python
from services.llm_providers.main_text_generation import llm_text_gen
# Define JSON schema
schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"sections": {
"type": "array",
"items": {
"type": "object",
"properties": {
"heading": {"type": "string"},
"content": {"type": "string"}
}
}
}
}
}
# Generate structured response
response = llm_text_gen(
"Create a blog outline about machine learning",
json_struct=schema
)
print(response)
```
### Direct Provider Usage
```python
# Google Gemini
from services.llm_providers.gemini_provider import gemini_text_response
response = gemini_text_response(
prompt="Write about AI",
temperature=0.7,
max_tokens=1000
)
# Hugging Face
from services.llm_providers.huggingface_provider import huggingface_text_response
response = huggingface_text_response(
prompt="Write about AI",
model="openai/gpt-oss-120b:groq",
temperature=0.7,
max_tokens=1000
)
```
## Available Hugging Face Models
The Hugging Face provider supports multiple models via Inference Providers:
- `openai/gpt-oss-120b:groq` (default)
- `moonshotai/Kimi-K2-Instruct-0905:groq`
- `Qwen/Qwen2.5-VL-7B-Instruct`
- `meta-llama/Llama-3.1-8B-Instruct:groq`
- `microsoft/Phi-3-medium-4k-instruct:groq`
- `mistralai/Mistral-7B-Instruct-v0.3:groq`
## Provider Selection Logic
1. **Environment Variable**: If `GPT_PROVIDER` is set, use the specified provider
2. **Auto-detection**: If no environment variable, check available API keys:
- Prefer Google Gemini if `GEMINI_API_KEY` is available
- Fall back to Hugging Face if `HF_TOKEN` is available
3. **Fallback**: If the specified provider fails, automatically try the other provider
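A minimal sketch of that selection order (the helper name is illustrative; the real logic lives in `main_text_generation`):
```python
import os

def resolve_provider() -> str:
    """Mirror the selection order described above: explicit env var, then available keys."""
    provider = os.getenv("GPT_PROVIDER", "").lower()
    if provider in ("gemini", "google"):
        return "google"
    if provider in ("hf_response_api", "huggingface", "hf"):
        return "huggingface"
    # Auto-detection when GPT_PROVIDER is unset
    if os.getenv("GEMINI_API_KEY"):
        return "google"
    if os.getenv("HF_TOKEN"):
        return "huggingface"
    raise RuntimeError("No LLM API keys configured")
```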
## Error Handling
The system includes comprehensive error handling:
- **Missing API Keys**: Clear error messages with setup instructions
- **Provider Failures**: Automatic fallback to the other provider
- **Invalid Models**: Validation with helpful error messages
- **Network Issues**: Retry logic with exponential backoff
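From the caller's side, a hedged example of surfacing these failures around `llm_text_gen`:
```python
from services.llm_providers.main_text_generation import llm_text_gen

try:
    post = llm_text_gen("Write a blog post about AI trends")
except Exception as exc:
    # Covers missing API keys, provider failures after fallback, and network errors.
    print(f"Text generation failed: {exc}")
    post = None
```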
## Migration from Previous Version
### Removed Providers
The following providers have been removed to simplify the system:
- OpenAI
- Anthropic
- DeepSeek
### Updated Imports
```python
# Old imports (no longer work)
from services.llm_providers.openai_provider import openai_chatgpt
from services.llm_providers.anthropic_provider import anthropic_text_response
from services.llm_providers.deepseek_provider import deepseek_text_response
# New imports
from services.llm_providers.gemini_provider import gemini_text_response, gemini_structured_json_response
from services.llm_providers.huggingface_provider import huggingface_text_response, huggingface_structured_json_response
```
## Testing
Run the integration tests to verify everything works:
```bash
cd backend
python -c "
import sys
sys.path.insert(0, '.')
from services.llm_providers.main_text_generation import check_gpt_provider
print('Google provider supported:', check_gpt_provider('google'))
print('Hugging Face provider supported:', check_gpt_provider('huggingface'))
print('OpenAI provider supported:', check_gpt_provider('openai'))
"
```
## Performance Considerations
### Google Gemini
- Fast response times
- High-quality outputs
- Good for structured content
### Hugging Face
- Multiple model options
- Cost-effective for high-volume usage
- Good for experimentation with different models
## Troubleshooting
### Common Issues
1. **"No LLM API keys configured"**
- Ensure either `GEMINI_API_KEY` or `HF_TOKEN` is set
- Check that the API key is valid
2. **"Unknown LLM provider"**
- Use only `google` or `huggingface` as provider values
- Check the `GPT_PROVIDER` environment variable
3. **"HF_TOKEN appears to be invalid"**
- Ensure your Hugging Face token starts with `hf_`
- Get a new token from [Hugging Face Settings](https://huggingface.co/settings/tokens)
4. **"OpenAI library not available"**
- Install the OpenAI library: `pip install openai`
- This is required for Hugging Face Responses API
### Debug Mode
Enable debug logging to see provider selection:
```python
import logging
logging.basicConfig(level=logging.DEBUG)
```
## Future Enhancements
- Support for additional Hugging Face models
- Model-specific parameter optimization
- Advanced caching strategies
- Performance monitoring and metrics
- A/B testing between providers
## Support
For issues or questions:
1. Check the troubleshooting section above
2. Review the [Hugging Face Responses API documentation](https://huggingface.co/docs/inference-providers/guides/responses-api)
3. Check the Google Gemini API documentation for Gemini-specific issues


@@ -0,0 +1,18 @@
"""LLM Providers Service for ALwrity Backend.
This service handles all LLM (Language Model) provider integrations,
migrated from the legacy lib/gpt_providers functionality.
"""
from services.llm_providers.main_text_generation import llm_text_gen
from services.llm_providers.gemini_provider import gemini_text_response, gemini_structured_json_response
from services.llm_providers.huggingface_provider import huggingface_text_response, huggingface_structured_json_response
__all__ = [
"llm_text_gen",
"gemini_text_response",
"gemini_structured_json_response",
"huggingface_text_response",
"huggingface_structured_json_response"
]


@@ -0,0 +1,310 @@
"""
Gemini Audio Text Generation Module
This module provides a comprehensive interface for working with audio files using Google's Gemini API.
It supports various audio processing capabilities including transcription, summarization, and analysis.
Key Features:
------------
1. Audio Transcription: Convert speech in audio files to text
2. Audio Summarization: Generate concise summaries of audio content
3. Segment Analysis: Analyze specific time segments of audio files
4. Timestamped Transcription: Generate transcriptions with timestamps
5. Token Counting: Count tokens in audio files
6. Format Support: Information about supported audio formats
Supported Audio Formats:
----------------------
- WAV (audio/wav)
- MP3 (audio/mp3)
- AIFF (audio/aiff)
- AAC (audio/aac)
- OGG Vorbis (audio/ogg)
- FLAC (audio/flac)
Technical Details:
----------------
- Each second of audio is represented as 32 tokens
- Maximum supported length of audio data in a single prompt is 9.5 hours
- Audio files are downsampled to 16 Kbps data resolution
- Multi-channel audio is combined into a single channel
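- Worked example: a 10-minute clip is roughly 600 s x 32 tokens/s = 19,200 tokens; one hour is about 115,200 tokens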
Usage:
------
```python
from lib.gpt_providers.audio_to_text_generation.gemini_audio_text import transcribe_audio, summarize_audio
# Basic transcription
transcript = transcribe_audio("path/to/audio.mp3")
print(transcript)
# Summarization
summary = summarize_audio("path/to/audio.mp3")
print(summary)
# Analyze specific segment
segment_analysis = analyze_audio_segment("path/to/audio.mp3", "02:30", "03:29")
print(segment_analysis)
```
Requirements:
------------
- GEMINI_API_KEY environment variable must be set
- google-generativeai Python package
- python-dotenv for environment variable management
- loguru for logging
Dependencies:
------------
- google.generativeai
- dotenv
- loguru
- os, sys, base64, typing
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
import google.generativeai as genai  # the code below uses the google-generativeai API (genai.configure / GenerativeModel / upload_file)
from dotenv import load_dotenv
from loguru import logger
from utils.logger_utils import get_service_logger
from ...onboarding.api_key_manager import APIKeyManager  # same relative import used by the sibling speech_to_text module
# Use service-specific logger to avoid conflicts
logger = get_service_logger("gemini_audio_text")
def load_environment():
"""Loads environment variables from a .env file."""
load_dotenv()
logger.info("Environment variables loaded successfully.")
def configure_google_api():
"""
Configures the Google Gemini API with the API key from environment variables.
Raises:
ValueError: If the GEMINI_API_KEY environment variable is not set.
"""
# Use APIKeyManager instead of direct environment variable access
api_key_manager = APIKeyManager()
api_key = api_key_manager.get_api_key("gemini")
if not api_key:
error_message = "Gemini API key not found. Please configure it in the onboarding process."
logger.error(error_message)
raise ValueError(error_message)
genai.configure(api_key=api_key)
logger.info("Google Gemini API configured successfully.")
def transcribe_audio(audio_file_path: str, prompt: str = "Transcribe the following audio:") -> Optional[str]:
"""
Transcribes audio using Google's Gemini model.
Args:
audio_file_path (str): The path to the audio file to be transcribed.
prompt (str, optional): The prompt to guide the transcription. Defaults to "Transcribe the following audio:".
Returns:
str: The transcribed text from the audio.
Returns None if transcription fails.
Raises:
FileNotFoundError: If the audio file is not found.
"""
try:
# Load environment variables and configure the Google API
load_environment()
configure_google_api()
logger.info(f"Attempting to transcribe audio file: {audio_file_path}")
# Check if file exists
if not os.path.exists(audio_file_path):
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
logger.error(error_message)
raise FileNotFoundError(error_message)
# Initialize a Gemini model appropriate for audio understanding
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
# Upload the audio file
try:
audio_file = genai.upload_file(audio_file_path)
logger.info(f"Audio file uploaded successfully: {audio_file=}")
except FileNotFoundError:
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
logger.error(error_message)
raise FileNotFoundError(error_message)
except Exception as e:
logger.error(f"Error uploading audio file: {e}")
return None
# Generate the transcription
try:
response = model.generate_content([
prompt,
audio_file
])
# Check for valid response and extract text
if response and hasattr(response, 'text'):
transcript = response.text
logger.info(f"Transcription successful:\n{transcript}")
return transcript
else:
logger.warning("Transcription failed: Invalid or empty response from API.")
return None
except Exception as e:
logger.error(f"Error during transcription: {e}")
return None
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
return None
def summarize_audio(audio_file_path: str) -> Optional[str]:
"""
Summarizes the content of an audio file using Google's Gemini model.
Args:
audio_file_path (str): The path to the audio file to be summarized.
Returns:
str: A summary of the audio content.
Returns None if summarization fails.
"""
return transcribe_audio(audio_file_path, prompt="Please summarize the audio content:")
def analyze_audio_segment(audio_file_path: str, start_time: str, end_time: str) -> Optional[str]:
"""
Analyzes a specific segment of an audio file using timestamps.
Args:
audio_file_path (str): The path to the audio file.
start_time (str): Start time in MM:SS format.
end_time (str): End time in MM:SS format.
Returns:
str: Analysis of the specified audio segment.
Returns None if analysis fails.
"""
prompt = f"Analyze the audio content from {start_time} to {end_time}."
return transcribe_audio(audio_file_path, prompt=prompt)
def transcribe_with_timestamps(audio_file_path: str) -> Optional[str]:
"""
Transcribes audio with timestamps for each segment.
Args:
audio_file_path (str): The path to the audio file.
Returns:
str: Transcription with timestamps.
Returns None if transcription fails.
"""
return transcribe_audio(audio_file_path, prompt="Transcribe the audio with timestamps for each segment:")
def count_tokens(audio_file_path: str) -> Optional[int]:
"""
Counts the number of tokens in an audio file.
Args:
audio_file_path (str): The path to the audio file.
Returns:
int: Number of tokens in the audio file.
Returns None if counting fails.
"""
try:
# Load environment variables and configure the Google API
load_environment()
configure_google_api()
logger.info(f"Attempting to count tokens in audio file: {audio_file_path}")
# Check if file exists
if not os.path.exists(audio_file_path):
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
logger.error(error_message)
raise FileNotFoundError(error_message)
# Initialize a Gemini model
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
# Upload the audio file
try:
audio_file = genai.upload_file(audio_file_path)
logger.info(f"Audio file uploaded successfully: {audio_file=}")
except Exception as e:
logger.error(f"Error uploading audio file: {e}")
return None
# Count tokens
try:
response = model.count_tokens([audio_file])
token_count = response.total_tokens
logger.info(f"Token count: {token_count}")
return token_count
except Exception as e:
logger.error(f"Error counting tokens: {e}")
return None
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
return None
def get_supported_formats() -> List[str]:
"""
Returns a list of supported audio formats.
Returns:
List[str]: List of supported MIME types.
"""
return [
"audio/wav",
"audio/mp3",
"audio/aiff",
"audio/aac",
"audio/ogg",
"audio/flac"
]
# Example usage
if __name__ == "__main__":
# Example 1: Basic transcription
audio_path = "path/to/your/audio.mp3"
transcript = transcribe_audio(audio_path)
print(f"Transcript: {transcript}")
# Example 2: Summarization
summary = summarize_audio(audio_path)
print(f"Summary: {summary}")
# Example 3: Analyze specific segment
segment_analysis = analyze_audio_segment(audio_path, "02:30", "03:29")
print(f"Segment Analysis: {segment_analysis}")
# Example 4: Transcription with timestamps
timestamped_transcript = transcribe_with_timestamps(audio_path)
print(f"Timestamped Transcript: {timestamped_transcript}")
# Example 5: Count tokens
token_count = count_tokens(audio_path)
print(f"Token Count: {token_count}")
# Example 6: Get supported formats
formats = get_supported_formats()
print(f"Supported Formats: {formats}")


@@ -0,0 +1,218 @@
import os
import re
import sys
import tempfile
from pytubefix import YouTube
from loguru import logger
from openai import OpenAI
from tqdm import tqdm
import streamlit as st
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
) # for exponential backoff
from .gemini_audio_text import transcribe_audio
# Import APIKeyManager
from ...onboarding.api_key_manager import APIKeyManager
def progress_function(stream, chunk, bytes_remaining):
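"""pytubefix on_progress callback: updates the module-level tqdm progress_bar."""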
# Calculate the percentage completion
current = ((stream.filesize - bytes_remaining) / stream.filesize)
progress_bar.update(current - progress_bar.n) # Update the progress bar
def rename_file_with_underscores(file_path):
"""Rename a file by replacing spaces and special characters with underscores.
Args:
file_path (str): The original file path.
Returns:
str: The new file path with underscores.
"""
# Extract the directory and the filename
dir_name, original_filename = os.path.split(file_path)
# Replace spaces and special characters with underscores in the filename
new_filename = re.sub(r'[^\w\-_\.]', '_', original_filename)
# Create the new file path
new_file_path = os.path.join(dir_name, new_filename)
# Rename the file
os.rename(file_path, new_file_path)
return new_file_path
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def speech_to_text(video_url):
"""
Transcribes speech to text from a YouTube video URL using OpenAI's Whisper model.
Args:
video_url (str): URL of the YouTube video to transcribe.
output_path (str, optional): Directory where the audio file will be saved. Defaults to '.'.
Returns:
str: The transcribed text from the video.
Raises:
SystemExit: If a critical error occurs that prevents successful execution.
"""
output_path = os.getenv("CONTENT_SAVE_DIR")
yt = None
audio_file = None
with st.status("Started Writing..", expanded=False) as status:
try:
if video_url.startswith("https://www.youtube.com/") or video_url.startswith("http://www.youtube.com/"):
logger.info(f"Accessing YouTube URL: {video_url}")
status.update(label=f"Accessing YouTube URL: {video_url}")
try:
vid_id = video_url.split("=")[1]
yt = YouTube(video_url, on_progress_callback=progress_function)
except Exception as err:
logger.error(f"Failed to get pytube stream object: {err}")
st.stop()
logger.info(f"Fetching the highest quality audio stream:{yt.title}")
status.update(label=f"Fetching the highest quality audio stream: {yt.title}")
try:
audio_stream = yt.streams.filter(only_audio=True).first()
except Exception as err:
logger.error(f"Failed to Download Youtube Audio: {err}")
st.stop()
if audio_stream is None:
logger.warning("No audio stream found for this video.")
st.warning("No audio stream found for this video.")
st.stop()
logger.info(f"Downloading audio for: {yt.title}")
status.update(label=f"Downloading audio for: {yt.title}")
global progress_bar
progress_bar = tqdm(total=1.0, unit='iB', unit_scale=True, desc=yt.title)
try:
audio_filename = re.sub(r'[^\w\-_\.]', '_', yt.title) + '.mp4'
audio_file = audio_stream.download(
output_path=os.getenv("CONTENT_SAVE_DIR"),
filename=audio_filename)
#audio_file = rename_file_with_underscores(audio_file)
except Exception as err:
logger.error(f"Failed to download audio file: {audio_file}")
progress_bar.close()
logger.info(f"Audio downloaded: {yt.title} to {audio_file}")
status.update(label=f"Audio downloaded: {yt.title} to {output_path}")
# Audio filepath from local directory.
elif os.path.exists(video_url):
audio_file = video_url
# Checking file size
max_file_size = 24 * 1024 * 1024 # 24MB
file_size = os.path.getsize(audio_file)
# Convert file size to MB for logging
file_size_MB = file_size / (1024 * 1024) # Convert bytes to MB
logger.info(f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
status.update(label=f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
if file_size > max_file_size:
logger.error("File size exceeds 24MB limit.")
# FIXME: We can chunk hour long videos, the code is not tested.
#long_video(audio_file)
st.error("Audio file size limit exceeded. Please file an issue on the ALwrity GitHub.")
sys.exit("File size limit exceeded.")
try:
# Try Gemini transcription first; fall back to OpenAI Whisper below if it returns nothing.
logger.info(f"Transcribing audio with Gemini: {audio_file}")
transcript = transcribe_audio(audio_file)
if transcript:
return transcript, yt.title
status.update(label=f"Initializing OpenAI client for transcription: {audio_file}")
logger.info(f"Initializing OpenAI client for transcription: {audio_file}")
# Use APIKeyManager instead of direct environment variable access
api_key_manager = APIKeyManager()
api_key = api_key_manager.get_api_key("openai")
if not api_key:
raise ValueError("OpenAI API key not found. Please configure it in the onboarding process.")
client = OpenAI(api_key=api_key)
logger.info("Transcribing using OpenAI's Whisper model.")
transcript = client.audio.transcriptions.create(
model="whisper-1",
file=open(audio_file, "rb"),
response_format="text"
)
logger.info(f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
status.update(label=f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
return transcript, yt.title
except Exception as e:
logger.error(f"Failed in Whisper transcription: {e}")
st.warning(f"Failed in Openai Whisper transcription: {e}")
transcript = transcribe_audio(audio_file)
print(f"\n\n\n--- Tracribe: {transcript} ----\n\n\n")
return transcript, yt.title
except Exception as e:
st.error(f"An error occurred during YouTube video processing: {e}")
finally:
try:
if os.path.exists(audio_file):
os.remove(audio_file)
logger.info("Temporary audio file removed.")
except PermissionError:
st.error(f"Permission error: Cannot remove '{audio_file}'. Please make sure of necessary permissions.")
except Exception as e:
st.error(f"An error occurred removing audio file: {e}")
def long_video(temp_file_name):
"""
Transcribes a YouTube video using OpenAI's Whisper API by processing the video in chunks.
This function handles videos longer than the context limit of the Whisper API by dividing the video into
10-minute segments, transcribing each segment individually, and then combining the results.
Key Changes and Notes:
1. Video Splitting: Splits the audio into 10-minute chunks using the moviepy library.
2. Chunk Transcription: Each audio chunk is transcribed separately and the results are concatenated.
3. Temporary Files for Chunks: Uses temporary files for each audio chunk for transcription.
4. Error Handling: Exception handling is included to capture and return any errors during the process.
5. Logging: Process steps are logged for debugging and monitoring.
6. Cleaning Up: Removes temporary files for both the entire video and individual audio chunks after processing.
Args:
temp_file_name (str): Path to the downloaded audio file to be transcribed in chunks.
"""
import moviepy.editor as mp  # moviepy is required only for this (untested) chunking path
# Extract audio and split into chunks
logger.info(f"Processing the YT video: {temp_file_name}")
full_audio = mp.AudioFileClip(temp_file_name)
duration = full_audio.duration
chunk_length = 600  # 10 minutes in seconds
chunks = [full_audio.subclip(start, min(start + chunk_length, duration)) for start in range(0, int(duration), chunk_length)]
combined_transcript = ""
# Reuse the same APIKeyManager-backed OpenAI client pattern as speech_to_text above
client = OpenAI(api_key=APIKeyManager().get_api_key("openai"))
for i, chunk in enumerate(chunks):
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio_chunk_file:
chunk.write_audiofile(audio_chunk_file.name, codec="mp3")
with open(audio_chunk_file.name, "rb") as audio_file:
# Transcribe each chunk using OpenAI's Whisper API
logger.info(f"Transcribing chunk {i+1}/{len(chunks)}")
transcript = client.audio.transcriptions.create(model="whisper-1", file=audio_file, response_format="text")
combined_transcript += transcript + "\n\n"
# Remove the chunk audio file
os.remove(audio_chunk_file.name)
return combined_transcript


@@ -0,0 +1,886 @@
"""
Enhanced Gemini Provider for Grounded Content Generation
This provider uses native Google Search grounding to generate content that is
factually grounded in current web sources, with automatic citation generation.
Based on Google AI's official grounding documentation.
"""
import os
import json
import re
import time
import asyncio
from typing import List, Dict, Any, Optional
from datetime import datetime
from loguru import logger
try:
from google import genai
from google.genai import types
GOOGLE_GENAI_AVAILABLE = True
except ImportError:
GOOGLE_GENAI_AVAILABLE = False
logger.warning("Google GenAI not available. Install with: pip install google-genai")
class GeminiGroundedProvider:
"""
Enhanced Gemini provider for grounded content generation with native Google Search.
This provider uses the official Google Search grounding tool to generate content
that is factually grounded in current web sources, with automatic citation generation.
Based on: https://ai.google.dev/gemini-api/docs/google-search
"""
def __init__(self):
"""Initialize the Gemini Grounded Provider."""
if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("Google GenAI library not available. Install with: pip install google-genai")
self.api_key = os.getenv('GEMINI_API_KEY')
if not self.api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
# Initialize the Gemini client with timeout configuration
self.client = genai.Client(api_key=self.api_key)
self.timeout = 60 # 60 second timeout for API calls (increased for research)
self._cache: Dict[str, Any] = {}
logger.info("✅ Gemini Grounded Provider initialized with native Google Search grounding")
async def generate_grounded_content(
self,
prompt: str,
content_type: str = "linkedin_post",
temperature: float = 0.7,
max_tokens: int = 2048,
urls: Optional[List[str]] = None,
mode: str = "polished",
user_id: Optional[str] = None,
validate_subsequent_operations: bool = False
) -> Dict[str, Any]:
"""
Generate grounded content using native Google Search grounding.
Args:
prompt: The content generation prompt
content_type: Type of content to generate
temperature: Creativity level (0.0-1.0)
max_tokens: Maximum tokens in response
urls: Optional list of URLs for URL Context tool
mode: Content mode ("draft" or "polished")
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
Returns:
Dictionary containing generated content and grounding metadata
"""
try:
# PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
# MUST happen BEFORE any API calls - return immediately if validation fails
if validate_subsequent_operations:
if not user_id:
raise ValueError("user_id is required when validate_subsequent_operations=True")
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_research_operations
from fastapi import HTTPException
import os
db = next(get_db())
try:
pricing_service = PricingService(db)
gpt_provider = os.getenv("GPT_PROVIDER", "google")
# Validate ALL research operations before making ANY API calls
# This prevents wasteful external API calls if subsequent LLM calls would fail
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_research_operations(
pricing_service=pricing_service,
user_id=user_id,
gpt_provider=gpt_provider
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
# Build the grounded prompt
grounded_prompt = self._build_grounded_prompt(prompt, content_type)
# Configure tools: Google Search and optional URL Context
tools: List[Any] = [
types.Tool(google_search=types.GoogleSearch())
]
if urls:
try:
# URL Context tool (ai.google.dev URL Context)
tools.append(types.Tool(url_context=types.UrlContext()))
logger.info(f"Enabled URL Context tool for {len(urls)} URLs")
except Exception as tool_err:
logger.warning(f"URL Context tool not available in SDK version: {tool_err}")
# Apply mode presets (Draft vs Polished)
# Use Gemini 2.0 Flash for better content generation with grounding
model_id = "gemini-2.0-flash"
if mode == "draft":
model_id = "gemini-2.0-flash"
temperature = min(1.0, max(0.0, temperature))
else:
model_id = "gemini-2.0-flash"
# Configure generation settings
config = types.GenerateContentConfig(
tools=tools,
max_output_tokens=max_tokens,
temperature=temperature
)
# Make the request with native grounding and timeout
import asyncio
import concurrent.futures
try:
# Cache first
cache_key = self._make_cache_key(model_id, grounded_prompt, urls)
if cache_key in self._cache:
logger.info("Cache hit for grounded content request")
response = self._cache[cache_key]
else:
# Run the synchronous generate_content in a thread pool to make it awaitable
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
response = await asyncio.wait_for(
loop.run_in_executor(
executor,
lambda: self.client.models.generate_content(
model=model_id,
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
config=config,
)
),
timeout=self.timeout
)
self._cache[cache_key] = response
except asyncio.TimeoutError:
from services.blog_writer.exceptions import APITimeoutException
raise APITimeoutException(
f"Gemini API request timed out after {self.timeout} seconds",
timeout_seconds=self.timeout,
context={"content_type": content_type, "model_id": model_id}
)
except Exception as api_error:
# Handle specific Google API errors with enhanced retry logic
error_str = str(api_error)
# Non-retryable errors
if "401" in error_str or "403" in error_str:
from services.blog_writer.exceptions import ValidationException
raise ValidationException(
"Authentication failed. Please check your API credentials.",
field="api_key",
context={"error": error_str, "content_type": content_type}
)
elif "400" in error_str:
from services.blog_writer.exceptions import ValidationException
raise ValidationException(
"Invalid request. Please check your input parameters.",
field="request",
context={"error": error_str, "content_type": content_type}
)
# Retryable errors - use enhanced retry logic
from services.blog_writer.retry_utils import retry_with_backoff, RESEARCH_RETRY_CONFIG
try:
response = await retry_with_backoff(
lambda: self._make_api_request_with_model(grounded_prompt, config, model_id, urls),
config=RESEARCH_RETRY_CONFIG,
operation_name=f"gemini_grounded_{content_type}",
context={"content_type": content_type, "model_id": model_id}
)
except Exception as retry_error:
# If retry also failed, raise the original error with context
from services.blog_writer.exceptions import ResearchFailedException
raise ResearchFailedException(
f"Google AI service error after retries: {error_str}",
context={"original_error": error_str, "retry_error": str(retry_error), "content_type": content_type}
)
# Process the grounded response
result = self._process_grounded_response(response, content_type)
# Attach URL Context metadata if present
try:
if hasattr(response, 'candidates') and response.candidates:
candidate0 = response.candidates[0]
if hasattr(candidate0, 'url_context_metadata') and candidate0.url_context_metadata:
result['url_context_metadata'] = candidate0.url_context_metadata
logger.info("Attached url_context_metadata to result")
except Exception as meta_err:
logger.warning(f"Unable to attach url_context_metadata: {meta_err}")
logger.info(f"✅ Grounded content generated successfully with {len(result.get('sources', []))} sources")
return result
except Exception as e:
# Log error without causing secondary exceptions
try:
logger.error(f"❌ Error generating grounded content: {str(e)}")
except Exception:
# Fallback to print if logging fails
print(f"Error generating grounded content: {str(e)}")
raise
async def _make_api_request(self, grounded_prompt: str, config: Any):
"""Make the actual API request to Gemini."""
import concurrent.futures
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
return await asyncio.wait_for(
loop.run_in_executor(
executor,
lambda: self.client.models.generate_content(
model="gemini-2.0-flash",
contents=grounded_prompt,
config=config,
)
),
timeout=self.timeout
)
async def _make_api_request_with_model(self, grounded_prompt: str, config: Any, model_id: str, urls: Optional[List[str]] = None):
"""Make the API request with explicit model id and optional URL injection."""
logger.info(f"🔍 DEBUG: Making API request with model: {model_id}")
logger.info(f"🔍 DEBUG: Prompt length: {len(grounded_prompt)} characters")
logger.info(f"🔍 DEBUG: Prompt preview (first 300 chars): {grounded_prompt[:300]}...")
import concurrent.futures
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
resp = await asyncio.wait_for(
loop.run_in_executor(
executor,
lambda: self.client.models.generate_content(
model=model_id,
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
config=config,
)
),
timeout=self.timeout
)
self._cache[self._make_cache_key(model_id, grounded_prompt, urls)] = resp
return resp
def _inject_urls_into_prompt(self, prompt: str, urls: Optional[List[str]]) -> str:
"""Append URLs to the prompt for URL Context tool to pick up (as per docs)."""
if not urls:
return prompt
safe_urls = [u for u in urls if isinstance(u, str) and u.startswith("http")]
if not safe_urls:
return prompt
urls_block = "\n".join(safe_urls[:20])
return f"{prompt}\n\nSOURCE URLS (use url_context to retrieve content):\n{urls_block}"
def _make_cache_key(self, model_id: str, prompt: str, urls: Optional[List[str]]) -> str:
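"""Build a deterministic SHA-256 cache key from the model id, prompt, and first 20 URLs."""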
import hashlib
u = "|".join((urls or [])[:20])
base = f"{model_id}|{prompt}|{u}"
return hashlib.sha256(base.encode("utf-8")).hexdigest()
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
"""Retry a function with exponential backoff."""
for attempt in range(max_retries + 1):
try:
return await func()
except Exception as e:
if attempt == max_retries:
# Last attempt failed, raise the error
raise e
# Calculate delay with exponential backoff
delay = base_delay * (2 ** attempt)
logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay} seconds: {str(e)}")
await asyncio.sleep(delay)
def _build_grounded_prompt(self, prompt: str, content_type: str) -> str:
"""
Build a prompt optimized for grounded content generation.
Args:
prompt: Base prompt
content_type: Type of content being generated
Returns:
Enhanced prompt for grounded generation
"""
content_type_instructions = {
"linkedin_post": "You are an expert LinkedIn content strategist. Generate a highly engaging, professional LinkedIn post that drives meaningful engagement, establishes thought leadership, and includes compelling hooks, actionable insights, and strategic hashtags. Every element should be optimized for maximum engagement and shareability.",
"linkedin_article": "You are a senior content strategist and industry thought leader. Generate a comprehensive, SEO-optimized LinkedIn article with compelling headlines, structured content, data-driven insights, and practical takeaways. Include proper source citations and engagement elements throughout.",
"linkedin_carousel": "You are a visual content strategist specializing in LinkedIn carousels. Generate compelling, story-driven carousel content with clear visual hierarchy, actionable insights per slide, and strategic engagement elements. Each slide should provide immediate value while building anticipation for the next.",
"linkedin_video_script": "You are a video content strategist and LinkedIn engagement expert. Generate a compelling video script optimized for LinkedIn's algorithm with attention-grabbing hooks, strategic timing, and engagement-driven content. Include specific visual and audio recommendations for maximum impact.",
"linkedin_comment_response": "You are a LinkedIn engagement specialist and industry expert. Generate thoughtful, value-adding comment responses that encourage further discussion, demonstrate expertise, and build meaningful professional relationships. Focus on genuine engagement over generic responses."
}
instruction = content_type_instructions.get(content_type, "Generate professional content with factual accuracy.")
grounded_prompt = f"""
{instruction}
CRITICAL REQUIREMENTS FOR LINKEDIN CONTENT:
- Use ONLY current, factual information from reliable sources (2024-2025)
- Cite specific sources for ALL claims, statistics, and recent developments
- Ensure content is optimized for LinkedIn's algorithm and engagement patterns
- Include strategic hashtags and engagement elements throughout
User Request: {prompt}
CONTENT QUALITY STANDARDS:
- All factual claims must be backed by current, authoritative sources
- Use professional yet conversational language that encourages engagement
- Include relevant industry insights, trends, and data points
- Make content highly shareable with clear value proposition
- Optimize for LinkedIn's professional audience and engagement metrics
ENGAGEMENT OPTIMIZATION:
- Include thought-provoking questions and calls-to-action
- Use storytelling elements and real-world examples
- Ensure content provides immediate, actionable value
- Optimize for comments, shares, and professional networking
- Include industry-specific terminology and insights
REMEMBER: This content will be displayed on LinkedIn with full source attribution and grounding data. Every claim must be verifiable, and the content should position the author as a thought leader in their industry.
"""
return grounded_prompt.strip()
def _process_grounded_response(self, response, content_type: str) -> Dict[str, Any]:
"""
Process the Gemini response with grounding metadata.
Args:
response: Gemini API response
content_type: Type of content generated
Returns:
Processed content with sources and citations
"""
try:
# Debug: Log response structure
logger.info(f"🔍 DEBUG: Response type: {type(response)}")
logger.info(f"🔍 DEBUG: Response has 'text': {hasattr(response, 'text')}")
logger.info(f"🔍 DEBUG: Response has 'candidates': {hasattr(response, 'candidates')}")
logger.info(f"🔍 DEBUG: Response has 'grounding_metadata': {hasattr(response, 'grounding_metadata')}")
if hasattr(response, 'grounding_metadata'):
logger.info(f"🔍 DEBUG: Grounding metadata: {response.grounding_metadata}")
if hasattr(response, 'candidates') and response.candidates:
logger.info(f"🔍 DEBUG: Number of candidates: {len(response.candidates)}")
candidate = response.candidates[0]
logger.info(f"🔍 DEBUG: Candidate type: {type(candidate)}")
logger.info(f"🔍 DEBUG: Candidate has 'content': {hasattr(candidate, 'content')}")
if hasattr(candidate, 'content') and candidate.content:
logger.info(f"🔍 DEBUG: Content type: {type(candidate.content)}")
# Check if content is a list or single object
if hasattr(candidate.content, '__iter__') and not isinstance(candidate.content, str):
try:
content_length = len(candidate.content) if candidate.content else 0
logger.info(f"🔍 DEBUG: Content is iterable, length: {content_length}")
except TypeError:
logger.info(f"🔍 DEBUG: Content is iterable but has no len() - treating as single object")
for i, part in enumerate(candidate.content):
logger.info(f"🔍 DEBUG: Part {i} type: {type(part)}")
logger.info(f"🔍 DEBUG: Part {i} has 'text': {hasattr(part, 'text')}")
if hasattr(part, 'text'):
logger.info(f"🔍 DEBUG: Part {i} text length: {len(part.text) if part.text else 0}")
else:
logger.info(f"🔍 DEBUG: Content is single object, has 'text': {hasattr(candidate.content, 'text')}")
if hasattr(candidate.content, 'text'):
logger.info(f"🔍 DEBUG: Content text length: {len(candidate.content.text) if candidate.content.text else 0}")
# Extract the main content - prioritize response.text as it's more reliable
content = ""
if hasattr(response, 'text'):
logger.info(f"🔍 DEBUG: response.text exists, value: '{response.text}', type: {type(response.text)}")
if response.text:
content = response.text
logger.info(f"🔍 DEBUG: Using response.text, length: {len(content)}")
else:
logger.info(f"🔍 DEBUG: response.text is empty or None")
elif hasattr(response, 'candidates') and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, 'content') and candidate.content:
# Handle both single Content object and list of parts
if hasattr(candidate.content, '__iter__') and not isinstance(candidate.content, str):
# Content is a list of parts
text_parts = []
for part in candidate.content:
if hasattr(part, 'text'):
text_parts.append(part.text)
content = " ".join(text_parts)
logger.info(f"🔍 DEBUG: Using candidate.content (list), extracted {len(text_parts)} parts, total length: {len(content)}")
else:
# Content is a single object
if hasattr(candidate.content, 'text'):
content = candidate.content.text
logger.info(f"🔍 DEBUG: Using candidate.content (single), text length: {len(content)}")
else:
logger.warning("🔍 DEBUG: candidate.content has no 'text' attribute")
logger.info(f"Extracted content length: {len(content) if content else 0}")
if not content:
logger.warning("⚠️ No content extracted from Gemini response - using fallback content")
logger.warning("⚠️ This indicates Google Search grounding is not working properly")
content = "Generated content about the requested topic."
# Initialize result structure
result = {
'content': content,
'sources': [],
'citations': [],
'search_queries': [],
'grounding_metadata': {},
'content_type': content_type,
'generation_timestamp': datetime.now().isoformat()
}
# Debug: Log response structure
logger.info(f"Response type: {type(response)}")
logger.info(f"Response attributes: {dir(response)}")
# Extract grounding metadata if available
if hasattr(response, 'candidates') and response.candidates:
candidate = response.candidates[0]
logger.info(f"Candidate attributes: {dir(candidate)}")
if hasattr(candidate, 'grounding_metadata') and candidate.grounding_metadata:
grounding_metadata = candidate.grounding_metadata
result['grounding_metadata'] = grounding_metadata
logger.info(f"Grounding metadata attributes: {dir(grounding_metadata)}")
logger.info(f"Grounding metadata type: {type(grounding_metadata)}")
logger.info(f"Grounding metadata value: {grounding_metadata}")
# Log all available attributes and their values
for attr in dir(grounding_metadata):
if not attr.startswith('_'):
try:
value = getattr(grounding_metadata, attr)
logger.info(f" {attr}: {type(value)} = {value}")
except Exception as e:
logger.warning(f" {attr}: Error accessing - {e}")
# Extract search queries
if hasattr(grounding_metadata, 'web_search_queries'):
result['search_queries'] = grounding_metadata.web_search_queries
logger.info(f"Search queries: {grounding_metadata.web_search_queries}")
# Extract sources from grounding chunks
sources = [] # Initialize sources list
if hasattr(grounding_metadata, 'grounding_chunks') and grounding_metadata.grounding_chunks:
for i, chunk in enumerate(grounding_metadata.grounding_chunks):
logger.info(f"Chunk {i} attributes: {dir(chunk)}")
if hasattr(chunk, 'web'):
source = {
'index': i,
'title': getattr(chunk.web, 'title', f'Source {i+1}'),
'url': getattr(chunk.web, 'uri', ''),
'type': 'web'
}
sources.append(source)
logger.info(f"Extracted {len(sources)} sources from grounding chunks")
else:
logger.warning("⚠️ No grounding chunks found - this is normal for some queries")
logger.info(f"Grounding metadata available fields: {[attr for attr in dir(grounding_metadata) if not attr.startswith('_')]}")
# Check if we have search queries - this means Google Search was triggered
if hasattr(grounding_metadata, 'web_search_queries') and grounding_metadata.web_search_queries:
logger.info(f"✅ Google Search was triggered with {len(grounding_metadata.web_search_queries)} queries")
# Create sources based on search queries
for i, query in enumerate(grounding_metadata.web_search_queries[:5]): # Limit to 5 sources
source = {
'index': i,
'title': f"Search: {query}",
'url': f"https://www.google.com/search?q={query.replace(' ', '+')}",
'type': 'search_query',
'query': query
}
sources.append(source)
logger.info(f"Created {len(sources)} sources from search queries")
else:
logger.warning("⚠️ No search queries found either - grounding may not have been triggered")
result['sources'] = sources
# Extract citations from grounding supports
if hasattr(grounding_metadata, 'grounding_supports') and grounding_metadata.grounding_supports:
citations = []
for support in grounding_metadata.grounding_supports:
if hasattr(support, 'segment') and hasattr(support, 'grounding_chunk_indices'):
citation = {
'type': 'inline',
'start_index': getattr(support.segment, 'start_index', 0),
'end_index': getattr(support.segment, 'end_index', 0),
'text': getattr(support.segment, 'text', ''),
'source_indices': support.grounding_chunk_indices,
'reference': f"Source {support.grounding_chunk_indices[0] + 1}" if support.grounding_chunk_indices else "Unknown"
}
citations.append(citation)
result['citations'] = citations
logger.info(f"Extracted {len(citations)} citations")
else:
logger.warning("⚠️ No grounding supports found - this is normal when no web sources are retrieved")
# Create basic citations from the content if we have sources
if sources:
citations = []
for i, source in enumerate(sources[:3]): # Limit to 3 citations
citation = {
'type': 'reference',
'start_index': 0,
'end_index': 0,
'text': f"Source {i+1}",
'source_indices': [i],
'reference': f"Source {i+1}",
'source': source
}
citations.append(citation)
result['citations'] = citations
logger.info(f"Created {len(citations)} basic citations from sources")
else:
result['citations'] = []
logger.info("No citations created - no sources available")
# Extract search entry point for UI display
if hasattr(grounding_metadata, 'search_entry_point') and grounding_metadata.search_entry_point:
if hasattr(grounding_metadata.search_entry_point, 'rendered_content'):
result['search_widget'] = grounding_metadata.search_entry_point.rendered_content
logger.info("✅ Extracted search widget HTML for UI display")
# Extract search queries for reference
if hasattr(grounding_metadata, 'web_search_queries') and grounding_metadata.web_search_queries:
result['search_queries'] = grounding_metadata.web_search_queries
logger.info(f"✅ Extracted {len(grounding_metadata.web_search_queries)} search queries")
logger.info(f"✅ Successfully extracted {len(result['sources'])} sources and {len(result['citations'])} citations from grounding metadata")
logger.info(f"Sources: {result['sources']}")
logger.info(f"Citations: {result['citations']}")
else:
logger.error("❌ CRITICAL: No grounding metadata found in response")
logger.error(f"Response structure: {dir(response)}")
logger.error(f"First candidate structure: {dir(candidates[0]) if candidates else 'No candidates'}")
raise ValueError("No grounding metadata found - grounding is not working properly")
else:
logger.warning("⚠️ No candidates found in response. Returning content without sources.")
# Add content-specific processing
if content_type == "linkedin_post":
result.update(self._process_post_content(content))
elif content_type == "linkedin_article":
result.update(self._process_article_content(content))
elif content_type == "linkedin_carousel":
result.update(self._process_carousel_content(content))
elif content_type == "linkedin_video_script":
result.update(self._process_video_script_content(content))
return result
except Exception as e:
logger.error(f"❌ CRITICAL: Error processing grounded response: {str(e)}")
logger.error(f"Exception type: {type(e)}")
logger.error(f"Exception details: {e}")
raise ValueError(f"Failed to process grounded response: {str(e)}")
def _process_post_content(self, content: str) -> Dict[str, Any]:
"""Process LinkedIn post content for hashtags and engagement elements."""
try:
# Handle None content
if content is None:
content = ""
logger.warning("Content is None, using empty string")
# Extract hashtags
hashtags = re.findall(r'#\w+', content)
# Generate call-to-action if not present
cta_patterns = [
r'What do you think\?',
r'Share your thoughts',
r'Comment below',
r'What\'s your experience\?',
r'Let me know in the comments'
]
has_cta = any(re.search(pattern, content, re.IGNORECASE) for pattern in cta_patterns)
call_to_action = None
if not has_cta:
call_to_action = "What are your thoughts on this? Share in the comments!"
return {
'hashtags': [{'hashtag': tag, 'category': 'general', 'popularity_score': 0.8} for tag in hashtags],
'call_to_action': call_to_action,
'engagement_prediction': {
'estimated_likes': max(50, len(content) // 10),
'estimated_comments': max(5, len(content) // 100)
}
}
except Exception as e:
logger.error(f"Error processing post content: {str(e)}")
return {}
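# Illustrative sketch (not part of the original file): how the hashtag extraction and
# engagement heuristic above behave on an assumed short post. Values follow the formulas
# in _process_post_content; the post text is made up for demonstration only.
#
# >>> post = "AI is reshaping hiring. #AI #FutureOfWork What do you think?"
# >>> result = self._process_post_content(post)
# >>> [h['hashtag'] for h in result['hashtags']]
# ['#AI', '#FutureOfWork']
# >>> result['call_to_action'] is None      # a CTA pattern is already present
# True
# >>> result['engagement_prediction']['estimated_likes']   # max(50, len(post) // 10)
# 50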
def _process_article_content(self, content: str) -> Dict[str, Any]:
"""Process LinkedIn article content for structure and SEO."""
try:
# Extract title (first line or first sentence)
lines = content.split('\n')
title = lines[0].strip() if lines else "Article Title"
# Estimate word count
word_count = len(content.split())
# Generate sections based on content structure
sections = []
current_section = ""
for line in lines:
if line.strip().startswith('#'):
if current_section:
sections.append({'title': 'Section', 'content': current_section.strip()})
current_section = ""
else:
current_section += line + "\n"
if current_section:
sections.append({'title': 'Content', 'content': current_section.strip()})
return {
'title': title,
'word_count': word_count,
'sections': sections,
'reading_time': max(1, word_count // 200), # 200 words per minute
'seo_metadata': {
'meta_description': content[:160] + "..." if len(content) > 160 else content,
'keywords': self._extract_keywords(content)
}
}
except Exception as e:
logger.error(f"Error processing article content: {str(e)}")
return {}
def _process_carousel_content(self, content: str) -> Dict[str, Any]:
"""Process LinkedIn carousel content for slide structure."""
try:
# Split content into slides (basic implementation)
slides = []
content_parts = content.split('\n\n')
for i, part in enumerate(content_parts[:10]): # Max 10 slides
if part.strip():
slides.append({
'slide_number': i + 1,
'title': f"Slide {i + 1}",
'content': part.strip(),
'visual_elements': [],
'design_notes': None
})
return {
'title': f"Carousel on {content[:50]}...",
'slides': slides,
'design_guidelines': {
'color_scheme': 'professional',
'typography': 'clean',
'layout': 'minimal'
}
}
except Exception as e:
logger.error(f"Error processing carousel content: {str(e)}")
return {}
def _process_video_script_content(self, content: str) -> Dict[str, Any]:
"""Process LinkedIn video script content for structure."""
try:
# Basic video script processing
lines = content.split('\n')
hook = ""
main_content = []
conclusion = ""
# Extract hook (first few lines)
hook_lines = []
for line in lines[:3]:
if line.strip() and not line.strip().startswith('#'):
hook_lines.append(line.strip())
if len(' '.join(hook_lines)) > 100:
break
hook = ' '.join(hook_lines)
# Extract conclusion (last few lines)
conclusion_lines = []
for line in lines[-3:]:
if line.strip() and not line.strip().startswith('#'):
conclusion_lines.insert(0, line.strip())
if len(' '.join(conclusion_lines)) > 100:
break
conclusion = ' '.join(conclusion_lines)
# Main content (everything in between). Note: hook and conclusion were rebuilt from stripped, re-joined lines, so these slice boundaries are approximate.
main_content_text = content[len(hook):len(content)-len(conclusion)].strip()
return {
'hook': hook,
'main_content': [{
'scene_number': 1,
'content': main_content_text,
'duration': 60,
'visual_notes': 'Professional presentation style'
}],
'conclusion': conclusion,
'thumbnail_suggestions': ['Professional thumbnail', 'Industry-focused image'],
'video_description': f"Professional insights on {content[:100]}..."
}
except Exception as e:
logger.error(f"Error processing video script content: {str(e)}")
return {}
def _extract_keywords(self, content: str) -> List[str]:
"""Extract relevant keywords from content."""
try:
# Simple keyword extraction (can be enhanced with NLP)
words = re.findall(r'\b\w+\b', content.lower())
word_freq = {}
# Filter out common words
stop_words = {'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'can', 'this', 'that', 'these', 'those', 'a', 'an'}
for word in words:
if word not in stop_words and len(word) > 3:
word_freq[word] = word_freq.get(word, 0) + 1
# Return top keywords
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
return [word for word, freq in sorted_words[:10]]
except Exception as e:
logger.error(f"Error extracting keywords: {str(e)}")
return []
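# Illustrative sketch: the frequency-based extraction above applied to an assumed sentence.
# Stop words and words of three characters or fewer are dropped; ties keep insertion order.
#
# >>> self._extract_keywords("Machine learning adoption is rising as machine learning tools mature.")
# ['machine', 'learning', 'adoption', 'rising', 'tools', 'mature']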
def add_citations(self, content: str, sources: List[Dict[str, Any]]) -> str:
"""
Add inline citations to content based on grounding metadata.
Args:
content: The content to add citations to
sources: List of sources from grounding metadata
Returns:
Content with inline citations
"""
try:
if not sources:
return content
# Create citation mapping
citation_map = {}
for source in sources:
index = source.get('index', 0)
citation_map[index] = f"[Source {index + 1}]({source.get('url', '')})"
# Add citations at the end of sentences or paragraphs
# This is a simplified approach - in practice, you'd use the groundingSupports data
citation_text = "\n\n**Sources:**\n"
for i, source in enumerate(sources):
citation_text += f"{i+1}. **{source.get('title', f'Source {i+1}')}**\n - URL: [{source.get('url', '')}]({source.get('url', '')})\n\n"
return content + citation_text
except Exception as e:
logger.error(f"Error adding citations: {str(e)}")
return content
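# Illustrative sketch (assumed source dict): the "Sources" block that add_citations appends.
# The 'index', 'title', and 'url' keys mirror the source dicts built from grounding metadata above.
#
# >>> sources = [{'index': 0, 'title': 'AI Report 2025', 'url': 'https://example.com/report'}]
# >>> self.add_citations("AI adoption keeps rising.", sources)
# roughly:
#   AI adoption keeps rising.
#
#   **Sources:**
#   1. **AI Report 2025**
#    - URL: [https://example.com/report](https://example.com/report)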
def extract_citations(self, content: str) -> List[Dict[str, Any]]:
"""
Extract citations from content.
Args:
content: Content to extract citations from
Returns:
List of citation objects
"""
try:
citations = []
# Look for citation patterns
citation_patterns = [
r'\[Source (\d+)\]',
r'\[(\d+)\]',
r'\(Source (\d+)\)'
]
for pattern in citation_patterns:
matches = re.finditer(pattern, content)
for match in matches:
citations.append({
'type': 'inline',
'reference': match.group(0),
'position': match.start(),
'source_index': int(match.group(1)) - 1
})
return citations
except Exception as e:
logger.error(f"Error extracting citations: {str(e)}")
return []
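# Illustrative sketch: each of the three regex patterns above yields an 'inline' citation
# record. The sample string is an assumption used only for demonstration.
#
# >>> self.extract_citations("Growth hit 40% [Source 1] and margins improved [2].")
# [{'type': 'inline', 'reference': '[Source 1]', 'position': 15, 'source_index': 0},
#  {'type': 'inline', 'reference': '[2]', 'position': 47, 'source_index': 1}]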
def assess_content_quality(self, content: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Assess the quality of generated content.
Args:
content: The generated content
sources: List of sources used
Returns:
Quality metrics dictionary
"""
try:
# Basic quality metrics
word_count = len(content.split())
char_count = len(content)
# Source coverage
source_coverage = min(1.0, len(sources) / max(1, word_count / 100))
# Professional tone indicators
professional_indicators = ['research', 'analysis', 'insights', 'trends', 'industry', 'professional']
unprofessional_indicators = ['awesome', 'amazing', 'incredible', 'mind-blowing']
professional_score = sum(1 for indicator in professional_indicators if indicator.lower() in content.lower()) / len(professional_indicators)
unprofessional_score = sum(1 for indicator in unprofessional_indicators if indicator.lower() in content.lower()) / len(unprofessional_indicators)
tone_score = max(0, professional_score - unprofessional_score)
# Overall quality score
overall_score = (source_coverage * 0.4 + tone_score * 0.3 + min(1.0, word_count / 500) * 0.3)
return {
'overall_score': round(overall_score, 2),
'source_coverage': round(source_coverage, 2),
'tone_score': round(tone_score, 2),
'word_count': word_count,
'char_count': char_count,
'sources_count': len(sources),
'quality_level': 'high' if overall_score > 0.8 else 'medium' if overall_score > 0.6 else 'low'
}
except Exception as e:
logger.error(f"Error assessing content quality: {str(e)}")
return {
'overall_score': 0.0,
'error': str(e)
}
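# Worked example (assumed numbers, for illustration only): a 500-word draft grounded in
# 3 sources, containing 4 of the 6 professional indicators and none of the unprofessional ones.
#   source_coverage = min(1.0, 3 / max(1, 500 / 100))            = 0.6
#   tone_score      = max(0, 4/6 - 0/4)                          ≈ 0.67
#   overall_score   = 0.6*0.4 + 0.67*0.3 + min(1.0, 500/500)*0.3 ≈ 0.74  -> quality_level 'medium'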

View File

@@ -0,0 +1,842 @@
"""
Gemini Provider Module for ALwrity
This module provides functions for interacting with Google's Gemini API, specifically designed
for structured JSON output and text generation. It follows the official Gemini API documentation
and implements best practices for reliable AI interactions.
Key Features:
- Structured JSON response generation with schema validation
- Text response generation with retry logic
- Comprehensive error handling and logging
- Automatic API key management
- Support for both gemini-2.5-flash and gemini-2.5-pro models
Best Practices:
1. Use structured output for complex, multi-field responses
2. Keep schemas simple and flat to avoid truncation
3. Set appropriate token limits (8192 for complex outputs)
4. Use low temperature (0.1-0.3) for consistent structured output
5. Implement proper error handling in calling functions
6. Avoid fallback to text parsing for structured responses
Usage Examples:
# Structured JSON response
schema = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {"type": "object", "properties": {...}}
}
}
}
result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
# Text response
result = gemini_text_response(prompt, temperature=0.7, max_tokens=2048)
Troubleshooting:
- If response.parsed is None: Check schema complexity and token limits
- If JSON parsing fails: Verify schema matches expected output structure
- If truncation occurs: Reduce output size or increase max_tokens
- If rate limiting: Implement exponential backoff (already included)
Dependencies:
- google.generativeai (genai)
- tenacity (for retry logic)
- logging (for debugging)
- json (for fallback parsing)
- re (for text extraction)
Author: ALwrity Team
Version: 2.0
Last Updated: January 2025
"""
import os
import sys
from pathlib import Path
import google.genai as genai
from google.genai import types
from dotenv import load_dotenv
# Fix the environment loading path - load from backend directory
current_dir = Path(__file__).parent.parent # services directory
backend_dir = current_dir.parent # backend directory
env_path = backend_dir / '.env'
if env_path.exists():
load_dotenv(env_path)
print(f"Loaded .env from: {env_path}")
else:
# Fallback to current directory
load_dotenv()
print(f"No .env found at {env_path}, using current directory")
from loguru import logger
from utils.logger_utils import get_service_logger
# Use service-specific logger to avoid conflicts
logger = get_service_logger("gemini_provider")
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
import asyncio
import json
import re
from typing import Optional, Dict, Any
# Configure standard logging
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s-%(levelname)s-%(module)s-%(lineno)d]- %(message)s')
logger = logging.getLogger(__name__)
def get_gemini_api_key() -> str:
"""Get Gemini API key with proper error handling."""
api_key = os.getenv('GEMINI_API_KEY')
if not api_key:
error_msg = "GEMINI_API_KEY environment variable is not set. Please set it in your .env file."
logger.error(error_msg)
raise ValueError(error_msg)
# Validate API key format (basic check)
if not api_key.startswith('AIza'):
error_msg = "GEMINI_API_KEY appears to be invalid. It should start with 'AIza'."
logger.error(error_msg)
raise ValueError(error_msg)
return api_key
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_prompt):
"""
Generate a text response using Google's Gemini API.
This function provides simple text generation with retry logic and error handling.
For structured output, use gemini_structured_json_response instead.
Args:
prompt (str): The input prompt for the AI model
temperature (float): Controls randomness (0.0-1.0). Higher = more creative
top_p (float): Nucleus sampling parameter (0.0-1.0)
n (int): Top-k sampling value (passed to the model as top_k)
max_tokens (int): Maximum tokens in response
system_prompt (str, optional): System instruction for the model
Returns:
str: Generated text response
Raises:
Exception: If API key is missing or API call fails
Best Practices:
- Use temperature 0.7-0.9 for creative content
- Use temperature 0.1-0.3 for factual/consistent content
- Set appropriate max_tokens based on expected response length
- Implement proper error handling in calling functions
Example:
result = gemini_text_response(
"Write a blog post about AI",
temperature=0.8,
max_tokens=1024
)
"""
#FIXME: Include : https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/System_instructions_REST.ipynb
try:
api_key = get_gemini_api_key()
client = genai.Client(api_key=api_key)
logger.info("✅ Gemini client initialized successfully")
except Exception as err:
logger.error(f"Failed to configure Gemini: {err}")
raise
logger.info(f"Temp: {temperature}, MaxTokens: {max_tokens}, TopP: {top_p}, N: {n}")
# Set up AI model config
generation_config = {
"temperature": temperature,
"top_p": top_p,
"top_k": n,
"max_output_tokens": max_tokens,
}
# FIXME: Expose model_name in main_config
try:
response = client.models.generate_content(
model='gemini-2.0-flash-lite',
contents=prompt,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=n,
),
)
#logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")
return response.text
except Exception as err:
logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
raise
async def test_gemini_api_key(api_key: str) -> tuple[bool, str]:
"""
Test if the provided Gemini API key is valid.
Args:
api_key (str): The Gemini API key to test
Returns:
tuple[bool, str]: A tuple containing (is_valid, message)
"""
try:
# Validate API key format first
if not api_key:
return False, "API key is empty"
if not api_key.startswith('AIza'):
return False, "API key format appears invalid (should start with 'AIza')"
# Configure Gemini with the provided key
client = genai.Client(api_key=api_key)
# Try to list models as a simple API test
models = client.models.list()
# Check if Gemini Pro is available
model_names = [model.name for model in models]
logger.info(f"Available models: {model_names}")
if any("gemini" in model_name.lower() for model_name in model_names):
return True, "Gemini API key is valid"
else:
return False, "No Gemini models available with this API key"
except Exception as e:
error_msg = f"Error testing Gemini API key: {str(e)}"
logger.error(error_msg)
return False, error_msg
def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048):
"""
Generate text using Google's Gemini Pro model.
Args:
prompt (str): The input text to generate completion for
temperature (float, optional): Controls randomness. Defaults to 0.7
top_p (float, optional): Controls diversity. Defaults to 0.9
top_k (int, optional): Controls vocabulary size. Defaults to 40
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
Returns:
str: The generated text completion
"""
try:
# Get API key with proper error handling
api_key = get_gemini_api_key()
client = genai.Client(api_key=api_key)
# Generate content using the new client
response = client.models.generate_content(
model='gemini-2.5-flash',
contents=prompt,
config=types.GenerateContentConfig(
max_output_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
),
)
# Return the generated text
return response.text
except Exception as e:
logger.error(f"Error in Gemini Pro text generation: {e}")
return str(e)
def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
"""Convert a lightweight dict schema to google.genai.types.Schema."""
if not isinstance(schema, dict):
raise ValueError("response_schema must be a dict compatible with types.Schema")
def _convert(node: Dict[str, Any]) -> types.Schema:
node_type = (node.get("type") or "OBJECT").upper()
if node_type == "OBJECT":
props = node.get("properties") or {}
props_types: Dict[str, types.Schema] = {}
for key, prop in props.items():
if isinstance(prop, dict):
props_types[key] = _convert(prop)
else:
props_types[key] = types.Schema(type=types.Type.STRING)
return types.Schema(type=types.Type.OBJECT, properties=props_types if props_types else None)
elif node_type == "ARRAY":
items_node = node.get("items")
if isinstance(items_node, dict):
item_schema = _convert(items_node)
else:
item_schema = types.Schema(type=types.Type.STRING)
return types.Schema(type=types.Type.ARRAY, items=item_schema)
elif node_type == "NUMBER":
return types.Schema(type=types.Type.NUMBER)
elif node_type == "INTEGER":
return types.Schema(type=types.Type.INTEGER)
elif node_type == "BOOLEAN":
return types.Schema(type=types.Type.BOOLEAN)
else:
return types.Schema(type=types.Type.STRING)
return _convert(schema)
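# Illustrative sketch: a flat dict schema and what the converter builds from it.
# The property names are assumptions used only to demonstrate the mapping.
#
#   {"type": "object",
#    "properties": {"title": {"type": "string"},
#                   "tags": {"type": "array", "items": {"type": "string"}}}}
#
# becomes roughly:
#
#   types.Schema(type=OBJECT, properties={
#       "title": types.Schema(type=STRING),
#       "tags":  types.Schema(type=ARRAY, items=types.Schema(type=STRING)),
#   })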
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
"""
Generate a structured JSON response using Google's Gemini API.
This function follows the official Gemini API documentation for structured output:
https://ai.google.dev/gemini-api/docs/structured-output#python
Args:
prompt (str): The input prompt for the AI model
schema (dict): JSON schema defining the expected output structure
temperature (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
top_p (float): Nucleus sampling parameter (0.0-1.0)
top_k (int): Top-k sampling parameter
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
system_prompt (str, optional): System instruction for the model
Returns:
dict: Parsed JSON response matching the provided schema
Raises:
Exception: If API key is missing or API call fails
Best Practices:
- Keep schemas simple and flat to avoid truncation
- Use low temperature (0.1-0.3) for consistent structured output
- Set max_tokens to 8192 for complex multi-field responses
- Avoid deeply nested schemas with many required fields
- Test with smaller outputs first, then scale up
Example:
schema = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"description": {"type": "string"}
}
}
}
}
}
result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
"""
try:
# Get API key with proper error handling
api_key = get_gemini_api_key()
logger.info(f"🔑 Gemini API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key:
raise Exception("GEMINI_API_KEY not found in environment variables")
client = genai.Client(api_key=api_key)
logger.info("✅ Gemini client initialized for structured JSON response")
# Prepare schema for SDK (dict -> types.Schema). If schema is already a types.Schema or Pydantic type, use as-is
try:
if isinstance(schema, dict):
types_schema = _dict_to_types_schema(schema)
else:
types_schema = schema
except Exception as conv_err:
logger.info(f"Schema conversion warning, defaulting to OBJECT: {conv_err}")
types_schema = types.Schema(type=types.Type.OBJECT)
# Add debugging for API call
logger.info(
"Gemini structured call | prompt_len=%s | schema_kind=%s | temp=%s | top_p=%s | top_k=%s | max_tokens=%s",
len(prompt) if isinstance(prompt, str) else '<non-str>',
type(types_schema).__name__,
temperature,
top_p,
top_k,
max_tokens,
)
# Use the official SDK GenerateContentConfig with response_schema
generation_config = types.GenerateContentConfig(
response_mime_type='application/json',
response_schema=types_schema,
max_output_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
system_instruction=system_prompt,
)
logger.info("🚀 Making Gemini API call...")
# Use enhanced retry logic for structured JSON calls
from services.blog_writer.retry_utils import retry_with_backoff, CONTENT_RETRY_CONFIG
async def make_api_call():
return client.models.generate_content(
model="gemini-2.5-flash",
contents=prompt,
config=generation_config,
)
try:
# Convert sync call to async for retry logic
import asyncio
# Check if there's already an event loop running
try:
loop = asyncio.get_running_loop()
# If we're already in an async context, we need to run this differently
logger.warning("⚠️ Already in async context, using direct sync call")
# For now, let's use a simpler approach without retry logic
response = client.models.generate_content(
model="gemini-2.5-flash",
contents=prompt,
config=generation_config,
)
logger.info("✅ Gemini API call completed successfully (sync mode)")
except RuntimeError:
# No event loop running, we can create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
response = loop.run_until_complete(
retry_with_backoff(
make_api_call,
config=CONTENT_RETRY_CONFIG,
operation_name="gemini_structured_json",
context={"schema_type": type(types_schema).__name__, "max_tokens": max_tokens}
)
)
logger.info("✅ Gemini API call completed successfully")
except Exception as api_error:
logger.error(f"❌ Gemini API call failed: {api_error}")
logger.error(f"❌ API Error type: {type(api_error).__name__}")
# Enhance error with specific exception types
error_str = str(api_error)
if "429" in error_str or "rate limit" in error_str.lower():
from services.blog_writer.exceptions import APIRateLimitException
raise APIRateLimitException(
f"Rate limit exceeded for structured JSON generation: {error_str}",
retry_after=60,
context={"operation": "structured_json", "max_tokens": max_tokens}
)
elif "timeout" in error_str.lower():
from services.blog_writer.exceptions import APITimeoutException
raise APITimeoutException(
f"Structured JSON generation timed out: {error_str}",
timeout_seconds=60,
context={"operation": "structured_json", "max_tokens": max_tokens}
)
elif "401" in error_str or "403" in error_str:
from services.blog_writer.exceptions import ValidationException
raise ValidationException(
"Authentication failed for structured JSON generation. Please check your API credentials.",
field="api_key",
context={"error": error_str, "operation": "structured_json"}
)
else:
from services.blog_writer.exceptions import ContentGenerationException
raise ContentGenerationException(
f"Structured JSON generation failed: {error_str}",
context={"error": error_str, "operation": "structured_json", "max_tokens": max_tokens}
)
# Check for parsed content first (primary method for structured output)
if hasattr(response, 'parsed'):
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
if response.parsed is not None:
logger.info("Using response.parsed for structured output")
return response.parsed
else:
logger.warning("Response.parsed is None, falling back to text parsing")
# Debug: Check if there's any text content
if hasattr(response, 'text') and response.text:
logger.info(f"Text response length: {len(response.text)}")
logger.debug(f"Text response preview: {response.text[:200]}...")
# Check for text content as fallback (only if no parsed content)
if hasattr(response, 'text') and response.text:
logger.info("No parsed content, trying to parse text response")
try:
import json
import re
# Clean the text response to fix common JSON issues
cleaned_text = response.text.strip()
# Remove any markdown code blocks if present
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]
cleaned_text = cleaned_text.strip()
# Try to find JSON content between curly braces
json_match = re.search(r'\{.*\}', cleaned_text, re.DOTALL)
if json_match:
cleaned_text = json_match.group(0)
parsed_text = json.loads(cleaned_text)
logger.info("Successfully parsed text as JSON")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse text as JSON: {e}")
logger.debug(f"Problematic text (first 500 chars): {response.text[:500]}")
# Try to extract and fix JSON manually
try:
import re
# Look for the main JSON object
json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
matches = re.findall(json_pattern, response.text, re.DOTALL)
if matches:
# Try the largest match (likely the main JSON)
largest_match = max(matches, key=len)
# Basic cleanup of common issues
fixed_json = largest_match.replace('\n', ' ').replace('\r', ' ')
# Remove any trailing commas before closing braces
fixed_json = re.sub(r',\s*}', '}', fixed_json)
fixed_json = re.sub(r',\s*]', ']', fixed_json)
parsed_text = json.loads(fixed_json)
logger.info("Successfully parsed cleaned JSON")
return parsed_text
except Exception as fix_error:
logger.error(f"Failed to fix JSON manually: {fix_error}")
# Check candidates for content (fallback for edge cases)
if hasattr(response, 'candidates') and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, 'content') and candidate.content:
if hasattr(candidate.content, 'parts') and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, 'text') and part.text:
try:
import json
parsed_text = json.loads(part.text)
logger.info("Successfully parsed candidate text as JSON")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse candidate text as JSON: {e}")
logger.error("No valid structured response content found")
return {"error": "No valid structured response content found"}
except ValueError as e:
# API key related errors should not be retried
logger.error(f"API key error in Gemini Pro structured JSON generation: {e}")
return {"error": str(e)}
except Exception as e:
# Check if this is a quota/rate limit error
msg = str(e)
if "RESOURCE_EXHAUSTED" in msg or "429" in msg or "quota" in msg.lower():
logger.error(f"Rate limit/quota error in Gemini Pro structured JSON generation: {msg}")
# Return error instead of retrying - quota exhausted means we need to wait or upgrade plan
return {"error": msg}
# For other errors, let tenacity handle retries
logger.error(f"Error in Gemini Pro structured JSON generation: {e}")
raise
# Removed JSON repair functions to avoid false positives
def _removed_repair_json_string(text: str) -> Optional[str]:
"""
Attempt to repair common JSON issues in AI responses.
"""
if not text:
return None
# Remove any non-JSON content before first {
start = text.find('{')
if start == -1:
return None
text = text[start:]
# Remove any content after last }
end = text.rfind('}')
if end == -1:
return None
text = text[:end+1]
# Fix common issues
repaired = text
# 1. Fix unterminated arrays (add missing closing brackets)
# Count opening and closing brackets
open_brackets = repaired.count('[')
close_brackets = repaired.count(']')
if open_brackets > close_brackets:
# Add missing closing brackets
missing_brackets = open_brackets - close_brackets
repaired = repaired + ']' * missing_brackets
# 2. Fix unterminated strings in arrays
# Look for patterns like ["item1", "item2" and add missing quote and bracket
lines = repaired.split('\n')
fixed_lines = []
for i, line in enumerate(lines):
stripped = line.strip()
# Check if line ends with an unquoted string in an array
if stripped.endswith('"') and i < len(lines) - 1:
next_line = lines[i + 1].strip()
if next_line.startswith(']'):
# This is fine
pass
elif not next_line.startswith('"') and not next_line.startswith(']'):
# Add missing quote and comma
line = line + '",'
fixed_lines.append(line)
repaired = '\n'.join(fixed_lines)
# 3. Fix unterminated strings (common issue with AI responses)
try:
# Handle unterminated strings by finding the last incomplete string and closing it
lines = repaired.split('\n')
fixed_lines = []
for i, line in enumerate(lines):
stripped = line.strip()
# Check for unterminated strings (line ends with quote but no closing quote)
if stripped.endswith('"') and i < len(lines) - 1:
next_line = lines[i + 1].strip()
# If next line doesn't start with quote or closing bracket, we might have an unterminated string
if not next_line.startswith('"') and not next_line.startswith(']') and not next_line.startswith('}'):
# Check if this looks like an unterminated string value
if ':' in line and not line.strip().endswith('",'):
line = line + '",'
# Count quotes in the line
quote_count = line.count('"')
if quote_count % 2 == 1: # Odd number of quotes
# Add a quote at the end if it looks like an incomplete string
if ':' in line and line.strip().endswith('"'):
line = line + '"'
elif ':' in line and not line.strip().endswith('"') and not line.strip().endswith(','):
line = line + '",'
fixed_lines.append(line)
repaired = '\n'.join(fixed_lines)
except Exception:
pass
# 4. Remove trailing commas before closing braces/brackets
repaired = re.sub(r',(\s*[}\]])', r'\1', repaired)
# 5. Fix missing commas between object properties
repaired = re.sub(r'"(\s*)"', r'",\1"', repaired)
return repaired
# Removed partial JSON extraction to avoid false positives
def _removed_extract_partial_json(text: str) -> Optional[Dict[str, Any]]:
"""
Extract partial JSON from truncated responses.
Attempts to salvage as much data as possible from incomplete JSON.
"""
if not text:
return None
try:
# Find the start of JSON
start = text.find('{')
if start == -1:
return None
# Extract from start to end, handling common truncation patterns
json_text = text[start:]
# Common truncation patterns and their fixes
truncation_patterns = [
(r'(["\w\s,{}\[\]\-\.:]+)\.\.\.$', r'\1'), # Remove trailing ...
(r'(["\w\s,{}\[\]\-\.:]+)"$', r'\1"'), # Add missing closing quote
(r'(["\w\s,{}\[\]\-\.:]+),$', r'\1'), # Remove trailing comma
(r'(["\w\s,{}\[\]\-\.:]+)\[(["\w\s,{}\[\]\-\.:]*)$', r'\1\2]'), # Close unclosed arrays
(r'(["\w\s,{}\[\]\-\.:]+)\{(["\w\s,{}\[\]\-\.:]*)$', r'\1\2}'), # Close unclosed objects
]
# Apply truncation fixes
import re
for pattern, replacement in truncation_patterns:
json_text = re.sub(pattern, replacement, json_text)
# Try to balance brackets and braces
open_braces = json_text.count('{')
close_braces = json_text.count('}')
open_brackets = json_text.count('[')
close_brackets = json_text.count(']')
# Add missing closing braces/brackets
if open_braces > close_braces:
json_text += '}' * (open_braces - close_braces)
if open_brackets > close_brackets:
json_text += ']' * (open_brackets - close_brackets)
# Try to parse the repaired JSON
try:
result = json.loads(json_text)
logger.info(f"Successfully extracted partial JSON with {len(str(result))} characters")
return result
except json.JSONDecodeError as e:
logger.debug(f"Partial JSON parsing failed: {e}")
# Try to extract individual fields as a last resort
fields = {}
# Extract key-value pairs using regex (more comprehensive patterns)
kv_patterns = [
r'"([^"]+)"\s*:\s*"([^"]*)"', # "key": "value"
r'"([^"]+)"\s*:\s*(\d+)', # "key": 123
r'"([^"]+)"\s*:\s*(true|false)', # "key": true/false
r'"([^"]+)"\s*:\s*null', # "key": null
]
for pattern in kv_patterns:
matches = re.findall(pattern, json_text)
for key, value in matches:
if value == 'true':
fields[key] = True
elif value == 'false':
fields[key] = False
elif value == 'null':
fields[key] = None
elif value.isdigit():
fields[key] = int(value)
else:
fields[key] = value
# Extract array fields (more robust)
array_pattern = r'"([^"]+)"\s*:\s*\[([^\]]*)\]'
array_matches = re.findall(array_pattern, json_text)
for key, array_content in array_matches:
# Parse array items more comprehensively
items = []
# Look for quoted strings, numbers, booleans, null
item_patterns = [
r'"([^"]*)"', # quoted strings
r'(\d+)', # numbers
r'(true|false)', # booleans
r'(null)', # null
]
for pattern in item_patterns:
item_matches = re.findall(pattern, array_content)
for match in item_matches:
if match == 'true':
items.append(True)
elif match == 'false':
items.append(False)
elif match == 'null':
items.append(None)
elif match.isdigit():
items.append(int(match))
else:
items.append(match)
if items:
fields[key] = items
# Extract nested object fields (basic)
object_pattern = r'"([^"]+)"\s*:\s*\{([^}]*)\}'
object_matches = re.findall(object_pattern, json_text)
for key, object_content in object_matches:
# Simple nested object extraction
nested_fields = {}
nested_kv_matches = re.findall(r'"([^"]+)"\s*:\s*"([^"]*)"', object_content)
for nested_key, nested_value in nested_kv_matches:
nested_fields[nested_key] = nested_value
if nested_fields:
fields[key] = nested_fields
if fields:
logger.info(f"Extracted {len(fields)} fields from truncated JSON: {list(fields.keys())}")
# Only return if we have a valid outline structure
if 'outline' in fields and isinstance(fields['outline'], list):
return {'outline': fields['outline']}
else:
logger.error("No valid 'outline' field found in partial JSON")
return None
return None
except Exception as e:
logger.debug(f"Error in partial JSON extraction: {e}")
return None
# Removed key-value extraction to avoid false positives
def _removed_extract_key_value_pairs(text: str) -> Optional[Dict[str, Any]]:
"""
Extract key-value pairs from malformed JSON text as a last resort.
"""
if not text:
return None
result = {}
# Look for patterns like "key": "value" or "key": value
# This regex looks for quoted keys followed by colons and values
pattern = r'"([^"]+)"\s*:\s*(?:"([^"]*)"|([^,}\]]+))'
matches = re.findall(pattern, text)
for key, quoted_value, unquoted_value in matches:
value = quoted_value if quoted_value else unquoted_value.strip()
# Clean up the value - remove any trailing content that looks like the next key
# This handles cases where the regex captured too much
if value and '"' in value:
# Split at the first quote that might be the start of the next key
parts = value.split('"')
if len(parts) > 1:
value = parts[0].strip()
# Try to parse the value appropriately
if value.lower() in ['true', 'false']:
result[key] = value.lower() == 'true'
elif value.lower() == 'null':
result[key] = None
elif value.isdigit():
result[key] = int(value)
elif value.replace('.', '').replace('-', '').isdigit():
try:
result[key] = float(value)
except ValueError:
result[key] = value
else:
result[key] = value
# Also try to extract array values
array_pattern = r'"([^"]+)"\s*:\s*\[([^\]]*)\]'
array_matches = re.findall(array_pattern, text)
for key, array_content in array_matches:
# Extract individual array items
items = []
# Look for quoted strings in the array
item_pattern = r'"([^"]*)"'
item_matches = re.findall(item_pattern, array_content)
for item in item_matches:
if item.strip():
items.append(item.strip())
if items:
result[key] = items
return result if result else None

View File

@@ -0,0 +1,441 @@
"""
Hugging Face Provider Module for ALwrity
This module provides functions for interacting with Hugging Face's Inference Providers API
using the Responses API (beta) which provides a unified interface for model interactions.
Key Features:
- Text response generation with retry logic
- Structured JSON response generation with schema validation
- Comprehensive error handling and logging
- Automatic API key management
- Support for various Hugging Face models via Inference Providers
Best Practices:
1. Use structured output for complex, multi-field responses
2. Keep schemas simple and flat to avoid truncation
3. Set appropriate token limits (8192 for complex outputs)
4. Use low temperature (0.1-0.3) for consistent structured output
5. Implement proper error handling in calling functions
6. Use the Responses API for better compatibility
Usage Examples:
# Text response
result = huggingface_text_response(prompt, temperature=0.7, max_tokens=2048)
# Structured JSON response
schema = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {"type": "object", "properties": {...}}
}
}
}
result = huggingface_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
Dependencies:
- openai (for Hugging Face Responses API)
- tenacity (for retry logic)
- logging (for debugging)
- json (for fallback parsing)
Author: ALwrity Team
Version: 1.0
Last Updated: January 2025
"""
import os
import sys
from pathlib import Path
import json
import re
from typing import Optional, Dict, Any
from dotenv import load_dotenv
# Fix the environment loading path - load from backend directory
current_dir = Path(__file__).parent.parent # services directory
backend_dir = current_dir.parent # backend directory
env_path = backend_dir / '.env'
if env_path.exists():
load_dotenv(env_path)
print(f"Loaded .env from: {env_path}")
else:
# Fallback to current directory
load_dotenv()
print(f"No .env found at {env_path}, using current directory")
from loguru import logger
from utils.logger_utils import get_service_logger
# Use service-specific logger to avoid conflicts
logger = get_service_logger("huggingface_provider")
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
try:
from openai import OpenAI
OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False
logger.warning("OpenAI library not available. Install with: pip install openai")
def get_huggingface_api_key() -> str:
"""Get Hugging Face API key with proper error handling."""
api_key = os.getenv('HF_TOKEN')
if not api_key:
error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file."
logger.error(error_msg)
raise ValueError(error_msg)
# Validate API key format (basic check)
if not api_key.startswith('hf_'):
error_msg = "HF_TOKEN appears to be invalid. It should start with 'hf_'."
logger.error(error_msg)
raise ValueError(error_msg)
return api_key
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def huggingface_text_response(
prompt: str,
model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
system_prompt: Optional[str] = None
) -> str:
"""
Generate text response using Hugging Face Inference Providers API.
This function uses the Hugging Face Responses API which provides a unified interface
for model interactions with built-in retry logic and error handling.
Args:
prompt (str): The input prompt for the AI model
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
temperature (float): Controls randomness (0.0-1.0)
max_tokens (int): Maximum tokens in response
top_p (float): Nucleus sampling parameter (0.0-1.0)
system_prompt (str, optional): System instruction for the model
Returns:
str: Generated text response
Raises:
Exception: If API key is missing or API call fails
Best Practices:
- Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
- Set max_tokens based on expected response length
- Use system_prompt to guide model behavior
- Handle errors gracefully in calling functions
Example:
result = huggingface_text_response(
prompt="Write a blog post about AI",
model="openai/gpt-oss-120b:groq",
temperature=0.7,
max_tokens=2048,
system_prompt="You are a professional content writer."
)
"""
try:
if not OPENAI_AVAILABLE:
raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling
api_key = get_huggingface_api_key()
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key:
raise Exception("HF_TOKEN not found in environment variables")
# Initialize Hugging Face client using Responses API
client = OpenAI(
base_url="https://router.huggingface.co/v1",
api_key=api_key,
)
logger.info("✅ Hugging Face client initialized for text response")
# Prepare input for the API
input_content = []
# Add system prompt if provided
if system_prompt:
input_content.append({
"role": "system",
"content": system_prompt
})
# Add user prompt
input_content.append({
"role": "user",
"content": prompt
})
# Add debugging for API call
logger.info(
"Hugging Face text call | model=%s | prompt_len=%s | temp=%s | top_p=%s | max_tokens=%s",
model,
len(prompt) if isinstance(prompt, str) else '<non-str>',
temperature,
top_p,
max_tokens,
)
logger.info("🚀 Making Hugging Face API call...")
# Add rate limiting to prevent expensive API calls
import time
time.sleep(1) # 1 second delay between API calls
# Make the API call using Responses API
response = client.responses.parse(
model=model,
input=input_content,
temperature=temperature,
top_p=top_p,
)
# Extract text from response
if hasattr(response, 'output_text') and response.output_text:
generated_text = response.output_text
elif hasattr(response, 'output') and response.output:
# Handle case where output is a list
if isinstance(response.output, list) and len(response.output) > 0:
generated_text = response.output[0].get('content', '')
else:
generated_text = str(response.output)
else:
generated_text = str(response)
# Clean up the response
if generated_text:
# Remove any markdown formatting if present
generated_text = re.sub(r'```[a-zA-Z]*\n?', '', generated_text)
generated_text = re.sub(r'```\n?', '', generated_text)
generated_text = generated_text.strip()
logger.info(f"✅ Hugging Face text response generated successfully (length: {len(generated_text)})")
return generated_text
except Exception as e:
logger.error(f"❌ Hugging Face text generation failed: {str(e)}")
raise Exception(f"Hugging Face text generation failed: {str(e)}")
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def huggingface_structured_json_response(
prompt: str,
schema: Dict[str, Any],
model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7,
max_tokens: int = 8192,
system_prompt: Optional[str] = None
) -> Dict[str, Any]:
"""
Generate structured JSON response using Hugging Face Inference Providers API.
This function uses the Hugging Face Responses API with structured output support
to generate JSON responses that match a provided schema.
Args:
prompt (str): The input prompt for the AI model
schema (dict): JSON schema defining the expected output structure
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
temperature (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
system_prompt (str, optional): System instruction for the model
Returns:
dict: Parsed JSON response matching the provided schema
Raises:
Exception: If API key is missing or API call fails
Best Practices:
- Keep schemas simple and flat to avoid truncation
- Use low temperature (0.1-0.3) for consistent structured output
- Set max_tokens to 8192 for complex multi-field responses
- Avoid deeply nested schemas with many required fields
- Test with smaller outputs first, then scale up
Example:
schema = {
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"description": {"type": "string"}
}
}
}
}
}
result = huggingface_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
"""
try:
if not OPENAI_AVAILABLE:
raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling
api_key = get_huggingface_api_key()
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key:
raise Exception("HF_TOKEN not found in environment variables")
# Initialize Hugging Face client using Responses API
client = OpenAI(
base_url="https://router.huggingface.co/v1",
api_key=api_key,
)
logger.info("✅ Hugging Face client initialized for structured JSON response")
# Prepare input for the API
input_content = []
# Add system prompt if provided
if system_prompt:
input_content.append({
"role": "system",
"content": system_prompt
})
# Add user prompt with JSON instruction
json_instruction = "Please respond with valid JSON that matches the provided schema."
input_content.append({
"role": "user",
"content": f"{prompt}\n\n{json_instruction}"
})
# Add debugging for API call
logger.info(
"Hugging Face structured call | model=%s | prompt_len=%s | schema_kind=%s | temp=%s | max_tokens=%s",
model,
len(prompt) if isinstance(prompt, str) else '<non-str>',
type(schema).__name__,
temperature,
max_tokens,
)
logger.info("🚀 Making Hugging Face structured API call...")
# Make the API call using Responses API with structured output
# Use simple text generation and parse JSON manually to avoid API format issues
logger.info("🚀 Making Hugging Face API call (text mode with JSON parsing)...")
# Add JSON instruction to the prompt
json_instruction = "\n\nPlease respond with valid JSON that matches this exact structure:\n" + json.dumps(schema, indent=2)
input_content[-1]["content"] = input_content[-1]["content"] + json_instruction
# Add rate limiting to prevent expensive API calls
import time
time.sleep(1) # 1 second delay between API calls
response = client.responses.parse(
model=model,
input=input_content,
temperature=temperature
)
# Extract structured data from response
if hasattr(response, 'output_parsed') and response.output_parsed:
# The new API returns parsed data directly (Pydantic model case)
logger.info("✅ Hugging Face structured JSON response parsed successfully")
# Convert Pydantic model to dict if needed
if hasattr(response.output_parsed, 'model_dump'):
return response.output_parsed.model_dump()
elif hasattr(response.output_parsed, 'dict'):
return response.output_parsed.dict()
else:
return response.output_parsed
elif hasattr(response, 'output_text') and response.output_text:
# Fallback to text parsing if output_parsed is not available
response_text = response.output_text
# Clean up the response text
response_text = re.sub(r'```json\n?', '', response_text)
response_text = re.sub(r'```\n?', '', response_text)
response_text = response_text.strip()
try:
parsed_json = json.loads(response_text)
logger.info("✅ Hugging Face structured JSON response parsed from text")
return parsed_json
except json.JSONDecodeError as json_err:
logger.error(f"❌ JSON parsing failed: {json_err}")
logger.error(f"Raw response: {response_text}")
# Try to extract JSON from the response using regex
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
try:
extracted_json = json.loads(json_match.group())
logger.info("✅ JSON extracted using regex fallback")
return extracted_json
except json.JSONDecodeError:
pass
# If all else fails, return a structured error response
logger.error("❌ All JSON parsing attempts failed")
return {
"error": "Failed to parse JSON response",
"raw_response": response_text,
"schema_expected": schema
}
else:
logger.error("❌ No valid response data found")
return {
"error": "No valid response data found",
"raw_response": str(response),
"schema_expected": schema
}
except Exception as e:
error_msg = str(e) if str(e) else repr(e)
error_type = type(e).__name__
logger.error(f"❌ Hugging Face structured JSON generation failed: {error_type}: {error_msg}")
logger.error(f"❌ Full exception details: {repr(e)}")
import traceback
logger.error(f"❌ Traceback: {traceback.format_exc()}")
raise Exception(f"Hugging Face structured JSON generation failed: {error_type}: {error_msg}")
def get_available_models() -> list:
"""
Get list of available Hugging Face models for text generation.
Returns:
list: List of available model identifiers
"""
return [
"openai/gpt-oss-120b:groq",
"moonshotai/Kimi-K2-Instruct-0905:groq",
"Qwen/Qwen2.5-VL-7B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct:groq",
"microsoft/Phi-3-medium-4k-instruct:groq",
"mistralai/Mistral-7B-Instruct-v0.3:groq"
]
def validate_model(model: str) -> bool:
"""
Validate if a model identifier is supported.
Args:
model (str): Model identifier to validate
Returns:
bool: True if model is supported, False otherwise
"""
available_models = get_available_models()
return model in available_models
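# Illustrative usage sketch (model IDs come from get_available_models above):
#
# >>> validate_model("openai/gpt-oss-120b:groq")
# True
# >>> validate_model("some-org/unknown-model")
# False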

View File

@@ -0,0 +1,17 @@
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from .hf_provider import HuggingFaceImageProvider
from .gemini_provider import GeminiImageProvider
from .stability_provider import StabilityImageProvider
from .wavespeed_provider import WaveSpeedImageProvider
__all__ = [
"ImageGenerationOptions",
"ImageGenerationResult",
"ImageGenerationProvider",
"HuggingFaceImageProvider",
"GeminiImageProvider",
"StabilityImageProvider",
"WaveSpeedImageProvider",
]

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol
@dataclass
class ImageGenerationOptions:
prompt: str
negative_prompt: Optional[str] = None
width: int = 1024
height: int = 1024
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
model: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
@dataclass
class ImageGenerationResult:
image_bytes: bytes
width: int
height: int
provider: str
model: Optional[str] = None
seed: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
class ImageGenerationProvider(Protocol):
"""Protocol for image generation providers."""
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
...
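# Illustrative sketch (not part of the package): the shape a concrete provider takes.
# "EchoProvider" and its placeholder bytes are assumptions used only to show that any
# class with a matching generate() signature satisfies the Protocol.
#
# class EchoProvider:
#     def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
#         png = b"\x89PNG\r\n\x1a\n..."  # placeholder image bytes, not a real PNG
#         return ImageGenerationResult(
#             image_bytes=png,
#             width=options.width,
#             height=options.height,
#             provider="echo",
#             model=options.model,
#         )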

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import io
import os
from typing import Optional
from PIL import Image
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_generation.gemini")
class GeminiImageProvider(ImageGenerationProvider):
"""Google Gemini/Imagen backed image generation.
NOTE: Implementation should call the actual Gemini Images API used in the codebase.
Here we keep a minimal interface and expect the underlying client to be wired
similarly to other providers and return a PIL image or raw bytes.
"""
def __init__(self) -> None:
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
logger.warning("GOOGLE_API_KEY not set. Gemini image generation may fail at runtime.")
logger.info("GeminiImageProvider initialized")
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
# Placeholder implementation to be replaced by real Gemini/Imagen call.
# For now, generate a 1x1 transparent PNG to maintain interface consistency
img = Image.new("RGBA", (max(1, options.width), max(1, options.height)), (0, 0, 0, 0))
with io.BytesIO() as buf:
img.save(buf, format="PNG")
png = buf.getvalue()
return ImageGenerationResult(
image_bytes=png,
width=img.width,
height=img.height,
provider="gemini",
model=os.getenv("GEMINI_IMAGE_MODEL"),
seed=options.seed,
)

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import io
import os
from typing import Optional, Dict, Any
from PIL import Image
from huggingface_hub import InferenceClient
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_generation.huggingface")
DEFAULT_HF_MODEL = os.getenv(
"HF_IMAGE_MODEL",
"black-forest-labs/FLUX.1-Krea-dev",
)
class HuggingFaceImageProvider(ImageGenerationProvider):
"""Hugging Face Inference Providers (fal-ai) backed image generation.
API doc: https://huggingface.co/docs/inference-providers/en/tasks/text-to-image
"""
def __init__(self, api_key: Optional[str] = None, provider: str = "fal-ai") -> None:
self.api_key = api_key or os.getenv("HF_TOKEN")
if not self.api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image generation")
self.provider = provider
self.client = InferenceClient(provider=self.provider, api_key=self.api_key)
logger.info("HuggingFaceImageProvider initialized (provider=%s)", self.provider)
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
model = options.model or DEFAULT_HF_MODEL
params: Dict[str, Any] = {}
if options.guidance_scale is not None:
params["guidance_scale"] = options.guidance_scale
if options.steps is not None:
params["num_inference_steps"] = options.steps
if options.negative_prompt:
params["negative_prompt"] = options.negative_prompt
if options.seed is not None:
params["seed"] = options.seed
# The HF InferenceClient returns a PIL Image
logger.debug("HF generate: model=%s width=%s height=%s params=%s", model, options.width, options.height, params)
img: Image.Image = self.client.text_to_image(
options.prompt,
model=model,
width=options.width,
height=options.height,
**params,
)
with io.BytesIO() as buf:
img.save(buf, format="PNG")
image_bytes = buf.getvalue()
return ImageGenerationResult(
image_bytes=image_bytes,
width=img.width,
height=img.height,
provider="huggingface",
model=model,
seed=options.seed,
metadata={"provider": self.provider},
)

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
import io
import os
from typing import Optional, Dict, Any
import requests
from PIL import Image
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_generation.stability")
DEFAULT_STABILITY_MODEL = os.getenv("STABILITY_MODEL", "stable-diffusion-xl-1024-v1-0")
class StabilityImageProvider(ImageGenerationProvider):
"""Stability AI Images API provider (simple text-to-image).
This uses the v1 text-to-image endpoint format. Adjust to match your existing
Stability integration if different.
"""
def __init__(self, api_key: Optional[str] = None) -> None:
self.api_key = api_key or os.getenv("STABILITY_API_KEY")
if not self.api_key:
logger.warning("STABILITY_API_KEY not set. Stability generation may fail at runtime.")
logger.info("StabilityImageProvider initialized")
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
payload: Dict[str, Any] = {
"text_prompts": [
{"text": options.prompt, "weight": 1.0},
],
"cfg_scale": options.guidance_scale or 7.0,
"steps": options.steps or 30,
"width": options.width,
"height": options.height,
"seed": options.seed,
}
if options.negative_prompt:
payload["text_prompts"].append({"text": options.negative_prompt, "weight": -1.0})
model = options.model or DEFAULT_STABILITY_MODEL
url = f"https://api.stability.ai/v1/generation/{model}/text-to-image"
logger.debug("Stability generate: model=%s payload_keys=%s", model, list(payload.keys()))
resp = requests.post(url, headers=headers, json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
# Expecting data["artifacts"][0]["base64"]
import base64
artifact = (data.get("artifacts") or [{}])[0]
b64 = artifact.get("base64", "")
image_bytes = base64.b64decode(b64)
# Confirm dimensions by loading once (optional)
img = Image.open(io.BytesIO(image_bytes))
return ImageGenerationResult(
image_bytes=image_bytes,
width=img.width,
height=img.height,
provider="stability",
model=model,
seed=options.seed,
)

View File

@@ -0,0 +1,243 @@
"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
import io
import os
from typing import Optional
from PIL import Image
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
SUPPORTED_MODELS = {
"ideogram-v3-turbo": {
"name": "Ideogram V3 Turbo",
"description": "Photorealistic generation with superior text rendering",
"cost_per_image": 0.10, # Estimated, adjust based on actual pricing
"max_resolution": (1024, 1024),
"default_steps": 20,
},
"qwen-image": {
"name": "Qwen Image",
"description": "Fast, high-quality text-to-image generation",
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
"max_resolution": (1024, 1024),
"default_steps": 15,
}
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize WaveSpeed image provider.
Args:
api_key: WaveSpeed API key (falls back to env var if not provided)
"""
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
if not self.api_key:
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
self.client = WaveSpeedClient(api_key=self.api_key)
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
list(self.SUPPORTED_MODELS.keys()))
def _validate_options(self, options: ImageGenerationOptions) -> None:
"""Validate generation options.
Args:
options: Image generation options
Raises:
ValueError: If options are invalid
"""
model = options.model or "ideogram-v3-turbo"
if model not in self.SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model: {model}. "
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
)
model_info = self.SUPPORTED_MODELS[model]
max_width, max_height = model_info["max_resolution"]
if options.width > max_width or options.height > max_height:
raise ValueError(
f"Resolution {options.width}x{options.height} exceeds maximum "
f"{max_width}x{max_height} for model {model}"
)
if not options.prompt or len(options.prompt.strip()) == 0:
raise ValueError("Prompt cannot be empty")
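# Illustrative sketch (assumed option values): what _validate_options rejects.
#   ImageGenerationOptions(prompt="a cat", model="dall-e-3")           -> ValueError: unsupported model
#   ImageGenerationOptions(prompt="a cat", width=2048, height=2048)    -> ValueError: resolution exceeds 1024x1024
#   ImageGenerationOptions(prompt="", model="ideogram-v3-turbo")       -> ValueError: prompt cannot be empty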
def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
"""Generate image using Ideogram V3 Turbo.
Args:
options: Image generation options
Returns:
Image bytes
"""
logger.info("[Ideogram V3] Starting image generation: %s", options.prompt[:100])
try:
# Prepare parameters for WaveSpeed Ideogram V3 API
# Note: Adjust these based on actual WaveSpeed API documentation
params = {
"model": "ideogram-v3-turbo",
"prompt": options.prompt,
"width": options.width,
"height": options.height,
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["ideogram-v3-turbo"]["default_steps"],
}
# Add optional parameters
if options.negative_prompt:
params["negative_prompt"] = options.negative_prompt
if options.guidance_scale:
params["guidance_scale"] = options.guidance_scale
if options.seed:
params["seed"] = options.seed
# Call WaveSpeed API (using generic image generation method)
# This will need to be adjusted based on actual WaveSpeed client implementation
result = self.client.generate_image(**params)
# Extract image bytes from result
# Adjust based on actual WaveSpeed API response format
if isinstance(result, bytes):
image_bytes = result
elif isinstance(result, dict) and "image" in result:
image_bytes = result["image"]
else:
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
logger.info("[Ideogram V3] ✅ Successfully generated image: %d bytes", len(image_bytes))
return image_bytes
except Exception as e:
logger.error("[Ideogram V3] ❌ Error generating image: %s", str(e), exc_info=True)
raise RuntimeError(f"Ideogram V3 generation failed: {str(e)}")
def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
"""Generate image using Qwen Image.
Args:
options: Image generation options
Returns:
Image bytes
"""
logger.info("[Qwen Image] Starting image generation: %s", options.prompt[:100])
try:
# Prepare parameters for WaveSpeed Qwen Image API
params = {
"model": "qwen-image",
"prompt": options.prompt,
"width": options.width,
"height": options.height,
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["qwen-image"]["default_steps"],
}
# Add optional parameters
if options.negative_prompt:
params["negative_prompt"] = options.negative_prompt
if options.guidance_scale:
params["guidance_scale"] = options.guidance_scale
if options.seed:
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
# Extract image bytes from result
if isinstance(result, bytes):
image_bytes = result
elif isinstance(result, dict) and "image" in result:
image_bytes = result["image"]
else:
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
logger.info("[Qwen Image] ✅ Successfully generated image: %d bytes", len(image_bytes))
return image_bytes
except Exception as e:
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
"""Generate image using WaveSpeed AI models.
Args:
options: Image generation options
Returns:
ImageGenerationResult with generated image
Raises:
ValueError: If options are invalid
RuntimeError: If generation fails
"""
# Validate options
self._validate_options(options)
# Determine model
model = options.model or "ideogram-v3-turbo"
# Generate based on model
if model == "ideogram-v3-turbo":
image_bytes = self._generate_ideogram_v3(options)
elif model == "qwen-image":
image_bytes = self._generate_qwen_image(options)
else:
raise ValueError(f"Unsupported model: {model}")
# Load image to get dimensions
image = Image.open(io.BytesIO(image_bytes))
width, height = image.size
# Calculate estimated cost
model_info = self.SUPPORTED_MODELS[model]
estimated_cost = model_info["cost_per_image"]
# Return result
return ImageGenerationResult(
image_bytes=image_bytes,
width=width,
height=height,
provider="wavespeed",
model=model,
seed=options.seed,
metadata={
"provider": "wavespeed",
"model": model,
"model_name": model_info["name"],
"prompt": options.prompt,
"negative_prompt": options.negative_prompt,
"steps": options.steps or model_info["default_steps"],
"guidance_scale": options.guidance_scale,
"estimated_cost": estimated_cost,
}
)
@classmethod
def get_available_models(cls) -> dict:
"""Get available models and their information.
Returns:
Dictionary of available models
"""
return cls.SUPPORTED_MODELS
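# Example usage (a minimal sketch; assumes WAVESPEED_API_KEY is set and that
# ImageGenerationOptions accepts the fields shown, mirroring how the facade builds it):
#
#   provider = WaveSpeedImageProvider()
#   options = ImageGenerationOptions(
#       prompt="A cozy reading nook with warm lighting",
#       width=1024,
#       height=1024,
#       model="qwen-image",
#   )
#   result = provider.generate(options)
#   print(result.width, result.height, result.metadata["estimated_cost"])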

View File

@@ -0,0 +1,124 @@
"""
Gemini Image Description Module
This module provides functionality to generate text descriptions of images using Google's Gemini API.
"""
import os
import sys
from pathlib import Path
import base64
from typing import Optional, Dict, Any, List, Union
from dotenv import load_dotenv
from PIL import Image
from loguru import logger
from utils.logger_utils import get_service_logger
# Use service-specific logger to avoid conflicts
logger = get_service_logger("gemini_image_describe")
# Import APIKeyManager
from ...onboarding.api_key_manager import APIKeyManager
try:
import google.genai as genai
except ImportError:
genai = None
logger.warning("Google genai library not available. Install with: pip install google-genai")
def describe_image(image_path: str, prompt: str = "Describe this image in detail:") -> Optional[str]:
"""
Describe an image using Google's Gemini API.
Parameters:
image_path (str): Path to the image file.
prompt (str): Prompt for describing the image.
Returns:
Optional[str]: The generated description of the image, or None if an error occurs.
"""
try:
if not genai:
logger.error("Google genai library not available")
return None
# Use APIKeyManager instead of direct environment variable access
api_key_manager = APIKeyManager()
api_key = api_key_manager.get_api_key("gemini")
if not api_key:
error_message = "Gemini API key not found. Please configure it in the onboarding process."
logger.error(error_message)
raise ValueError(error_message)
# Check if image file exists
if not os.path.exists(image_path):
error_message = f"Image file not found: {image_path}"
logger.error(error_message)
raise FileNotFoundError(error_message)
# Initialize the Gemini client
client = genai.Client(api_key=api_key)
# Open and process the image
try:
image = Image.open(image_path)
logger.info(f"Successfully opened image: {image_path}")
except Exception as e:
error_message = f"Failed to open image: {e}"
logger.error(error_message)
return None
# Generate content description
try:
response = client.models.generate_content(
model='gemini-2.0-flash',
contents=[
prompt,
image
]
)
# Extract and return the text
description = response.text
logger.info(f"Successfully generated description for image: {image_path}")
return description
except Exception as e:
error_message = f"Failed to generate content: {e}"
logger.error(error_message)
return None
except Exception as e:
error_message = f"An unexpected error occurred: {e}"
logger.error(error_message)
return None
def analyze_image_with_prompt(image_path: str, prompt: str) -> Optional[str]:
"""
Analyze an image with a custom prompt using Google's Gemini API.
Parameters:
image_path (str): Path to the image file.
prompt (str): Custom prompt for analyzing the image.
Returns:
Optional[str]: The generated analysis of the image, or None if an error occurs.
"""
return describe_image(image_path, prompt)
# Example usage
if __name__ == "__main__":
# Example usage of the function
image_path = "path/to/your/image.jpg"
description = describe_image(image_path)
if description:
print(f"Image description: {description}")
else:
print("Failed to generate image description")

View File

@@ -0,0 +1,79 @@
"""
This module provides functionality to analyze images using OpenAI's Vision API.
It encodes an image to a base64 string and sends a request to the OpenAI API
to interpret the contents of the image, returning a textual description.
"""
import requests
import sys
import re
import base64
def analyze_and_extract_details_from_image(image_path, api_key):
"""
Analyzes an image using OpenAI's Vision API and extracts Alt Text, Description, Title, and Caption.
Args:
image_path (str): Path to the image file.
api_key (str): Your OpenAI API key.
Returns:
dict: Extracted details including Alt Text, Description, Title, and Caption.
"""
def encode_image(path):
""" Encodes an image to a base64 string. """
with open(path, "rb", encoding="utf-8") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
base64_image = encode_image(image_path)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
payload = {
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "The given image is used in blog content. Analyze the given image and suggest alternative(alt) test, description, title, caption."
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}
try:
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
response.raise_for_status()
assistant_message = response.json()['choices'][0]['message']['content']
# Extracting details using regular expressions
alt_text_match = re.search(r'Alt Text: "(.*?)"', assistant_message)
description_match = re.search(r'Description: (.*?)\n\n', assistant_message)
title_match = re.search(r'Title: "(.*?)"', assistant_message)
caption_match = re.search(r'Caption: "(.*?)"', assistant_message)
return {
'alt_text': alt_text_match.group(1) if alt_text_match else None,
'description': description_match.group(1) if description_match else None,
'title': title_match.group(1) if title_match else None,
'caption': caption_match.group(1) if caption_match else None
}
except requests.RequestException as e:
sys.exit(f"Error: Failed to communicate with OpenAI API. Error: {e}")
except Exception as e:
sys.exit(f"Error occurred: {e}")

View File

@@ -0,0 +1,319 @@
"""
Main Audio Generation Service for ALwrity Backend.
This service provides AI-powered text-to-speech functionality using WaveSpeed Minimax Speech 02 HD.
"""
from __future__ import annotations
import sys
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from services.wavespeed.client import WaveSpeedClient
from services.onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("audio_generation")
class AudioGenerationResult:
"""Result of audio generation."""
def __init__(
self,
audio_bytes: bytes,
provider: str,
model: str,
voice_id: str,
text_length: int,
file_size: int,
):
self.audio_bytes = audio_bytes
self.provider = provider
self.model = model
self.voice_id = voice_id
self.text_length = text_length
self.file_size = file_size
def generate_audio(
text: str,
voice_id: str = "Wise_Woman",
speed: float = 1.0,
volume: float = 1.0,
pitch: float = 0.0,
emotion: str = "happy",
user_id: Optional[str] = None,
**kwargs
) -> AudioGenerationResult:
"""
Generate audio using AI text-to-speech with subscription tracking.
Args:
text: Text to convert to speech (max 10000 characters)
voice_id: Voice ID (default: "Wise_Woman")
speed: Speech speed (0.5-2.0, default: 1.0)
volume: Speech volume (0.1-10.0, default: 1.0)
pitch: Speech pitch (-12 to 12, default: 0.0)
emotion: Emotion (default: "happy")
user_id: User ID for subscription checking (required)
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
Returns:
AudioGenerationResult: Generated audio result
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
"""
try:
# VALIDATION: Check inputs before any processing or API calls
if not text or not isinstance(text, str) or len(text.strip()) == 0:
raise ValueError("Text input is required and cannot be empty")
text = text.strip() # Normalize whitespace
if len(text) > 10000:
raise ValueError(f"Text is too long ({len(text)} characters). Maximum is 10,000 characters.")
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
logger.info("[audio_gen] Starting audio generation")
logger.debug(f"[audio_gen] Text length: {len(text)} characters, voice: {voice_id}")
# Calculate cost based on character count (every character is 1 token)
# Pricing: $0.05 per 1,000 characters
character_count = len(text)
cost_per_1000_chars = 0.05
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
try:
from services.database import get_db
from services.subscription import PricingService
from models.subscription_models import UsageSummary, APIProvider
db = next(get_db())
try:
pricing_service = PricingService(db)
# Check limits using sync method from pricing service (strict enforcement)
# Use AUDIO provider for audio generation
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.AUDIO,
tokens_requested=character_count, # Use character count as "tokens" for audio
actual_provider_name="wavespeed" # Actual provider is WaveSpeed
)
if not can_proceed:
logger.warning(f"[audio_gen] Subscription limit exceeded for user {user_id}: {message}")
error_detail = {
'error': message,
'message': message,
'provider': 'wavespeed',
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
# Get current usage for limit checking
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
finally:
db.close()
except HTTPException:
raise
except RuntimeError:
raise
except Exception as sub_error:
logger.error(f"[audio_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Generate audio using WaveSpeed
try:
# Avoid passing duplicate enable_sync_mode; allow override via kwargs
enable_sync_mode = kwargs.pop("enable_sync_mode", True)
# Filter out None values from kwargs to prevent WaveSpeed validation errors
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
logger.info(f"[audio_gen] Filtered kwargs (removed None values): {filtered_kwargs}")
client = WaveSpeedClient()
audio_bytes = client.generate_speech(
text=text,
voice_id=voice_id,
speed=speed,
volume=volume,
pitch=pitch,
emotion=emotion,
enable_sync_mode=enable_sync_mode,
**filtered_kwargs
)
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
except HTTPException:
raise
except Exception as api_error:
logger.error(f"[audio_gen] Audio generation API failed: {api_error}")
raise HTTPException(
status_code=502,
detail={
"error": "Audio generation failed",
"message": str(api_error)
}
)
# TRACK USAGE after successful API call
if audio_bytes:
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get current values before update
current_calls_before = getattr(summary, "audio_calls", 0) or 0
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
# Update audio calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
# Import sqlalchemy.text with alias to avoid shadowing the 'text' parameter
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
# Store the text parameter in a local variable before any imports to prevent shadowing
text_param = text # Capture function parameter before any potential shadowing
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.AUDIO,
endpoint="/audio-generation/wavespeed",
method="POST",
model_used="minimax/speech-02-hd",
tokens_input=character_count,
tokens_output=0,
tokens_total=character_count,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(text_param.encode("utf-8")), # Use captured parameter
response_size=len(audio_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
# Get related stats for unified log
current_image_calls = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[audio_gen] ✅ Successfully tracked usage: user {user_id} -> audio -> {new_calls} calls, ${estimated_cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Audio Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: minimax/speech-02-hd
├─ Voice: {voice_id}
├─ Calls: {current_calls_before}{new_calls} / {audio_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Characters: {character_count}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[audio_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[audio_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return AudioGenerationResult(
audio_bytes=audio_bytes,
provider="wavespeed",
model="minimax/speech-02-hd",
voice_id=voice_id,
text_length=character_count,
file_size=len(audio_bytes),
)
except HTTPException:
raise
except RuntimeError:
raise
except Exception as e:
logger.error(f"[audio_gen] Error generating audio: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Audio generation failed",
"message": str(e)
}
)
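# Example usage (a minimal sketch; the user ID is a placeholder and assumes WaveSpeed
# credentials plus the subscription tables are configured; the audio container depends
# on the format kwarg, mp3 is assumed here):
#
#   result = generate_audio(
#       text="Welcome to ALwrity!",
#       voice_id="Wise_Woman",
#       speed=1.0,
#       emotion="happy",
#       user_id="user_abc123",
#   )
#   with open("welcome.mp3", "wb") as f:
#       f.write(result.audio_bytes)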

View File

@@ -0,0 +1,190 @@
from __future__ import annotations
import os
import io
from typing import Optional, Dict, Any
from PIL import Image
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from utils.logger_utils import get_service_logger
try:
from huggingface_hub import InferenceClient
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
logger = get_service_logger("image_editing.facade")
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
"HF_IMAGE_EDIT_MODEL",
"Qwen/Qwen-Image-Edit",
)
def _select_provider(explicit: Optional[str]) -> str:
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
if explicit:
return explicit
# Default to huggingface for image editing (best support for image-to-image)
return "huggingface"
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
"""Get InferenceClient for the specified provider."""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
if provider_name == "huggingface":
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
# Use fal-ai provider for fast inference
return InferenceClient(provider="fal-ai", api_key=api_key)
raise ValueError(f"Unknown image editing provider: {provider_name}")
def edit_image(
input_image_bytes: bytes,
prompt: str,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None,
mask_bytes: Optional[bytes] = None,
) -> ImageGenerationResult:
"""Edit image with pre-flight validation.
Args:
input_image_bytes: Input image as bytes (PNG/JPEG)
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
options: Image editing options (provider, model, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
mask_bytes: Optional mask image bytes for selective editing (grayscale, white=edit, black=preserve)
Returns:
ImageGenerationResult with edited image bytes and metadata
Best Practices for Prompts:
- Use clear, specific language describing desired changes
- Describe what should change and what should remain
- Examples: "Turn the cat into a tiger", "Change background to forest",
"Make it look like a watercolor painting"
Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
"""
# PRE-FLIGHT VALIDATION: Validate image editing before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_editing_operations
from fastapi import HTTPException
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_editing_operations(
pricing_service=pricing_service,
user_id=user_id
)
logger.info(f"[Image Editing] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image editing")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Image Editing] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# Validate input
if not input_image_bytes:
raise ValueError("input_image_bytes is required")
if not prompt or not prompt.strip():
raise ValueError("prompt is required for image editing")
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")
# Get provider client
client = _get_provider_client(provider_name, opts.get("api_key"))
# Prepare parameters for image-to-image
params: Dict[str, Any] = {}
if opts.get("guidance_scale") is not None:
params["guidance_scale"] = opts.get("guidance_scale")
if opts.get("steps") is not None:
params["num_inference_steps"] = opts.get("steps")
if opts.get("seed") is not None:
params["seed"] = opts.get("seed")
try:
# Convert input image bytes to PIL Image for validation
input_image = Image.open(io.BytesIO(input_image_bytes))
width = input_image.width
height = input_image.height
# Convert mask bytes to PIL Image if provided
mask_image = None
if mask_bytes:
try:
mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Convert to grayscale
# Ensure mask dimensions match input image
if mask_image.size != input_image.size:
logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
except Exception as e:
logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
mask_image = None
# Use image_to_image method from Hugging Face InferenceClient
# This follows the pattern from the Hugging Face documentation
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
# Note: Mask support depends on the model - some models may ignore it
call_params = params.copy()
if mask_image:
call_params["mask_image"] = mask_image
logger.info("[Image Editing] Using mask for selective editing")
edited_image: Image.Image = client.image_to_image(
image=input_image,
prompt=prompt.strip(),
model=model,
**call_params,
)
# Convert edited image back to bytes
with io.BytesIO() as buf:
edited_image.save(buf, format="PNG")
edited_image_bytes = buf.getvalue()
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
return ImageGenerationResult(
image_bytes=edited_image_bytes,
width=edited_image.width,
height=edited_image.height,
provider="huggingface",
model=model,
seed=opts.get("seed"),
metadata={
"provider": "fal-ai",
"operation": "image_editing",
"original_width": width,
"original_height": height,
},
)
except Exception as e:
logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
raise RuntimeError(f"Image editing failed: {str(e)}")

View File

@@ -0,0 +1,478 @@
from __future__ import annotations
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str:
if explicit:
return explicit
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
if gpt_provider.startswith("hf"):
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Fallback to huggingface to enable a path if configured
return "huggingface"
def _get_provider(provider_name: str):
if provider_name == "huggingface":
return HuggingFaceImageProvider()
if provider_name == "gemini":
return GeminiImageProvider()
if provider_name == "stability":
return StabilityImageProvider()
if provider_name == "wavespeed":
return WaveSpeedImageProvider()
raise ValueError(f"Unknown image provider: {provider_name}")
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
"""Generate image with pre-flight validation.
Args:
prompt: Image generation prompt
options: Image generation options (provider, model, width, height, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
image_options = ImageGenerationOptions(
prompt=prompt,
negative_prompt=opts.get("negative_prompt"),
width=int(opts.get("width", 1024)),
height=int(opts.get("height", 1024)),
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
model=opts.get("model"),
extra=opts,
)
# Normalize obvious model/provider mismatches
model_lower = (image_options.model or "").lower()
if provider_name == "stability" and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
logger.info("Remapping provider to huggingface for model=%s", image_options.model)
provider_name = "huggingface"
if provider_name == "huggingface" and not image_options.model:
# Provide a sensible default HF model if none specified
image_options.model = "black-forest-labs/FLUX.1-Krea-dev"
if provider_name == "wavespeed" and not image_options.model:
# Provide a sensible default WaveSpeed model if none specified
image_options.model = "ideogram-v3-turbo"
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
provider = _get_provider(provider_name)
result = provider.generate(image_options)
# TRACK USAGE after successful API call
has_image_bytes = bool(result.image_bytes) if result else False
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
if user_id and result and result.image_bytes:
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get cost from result metadata or calculate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
if result.model and "qwen" in result.model.lower():
estimated_cost = 0.05
else:
estimated_cost = 0.10 # ideogram-v3-turbo default
elif provider_name == "stability":
estimated_cost = 0.04
else:
estimated_cost = 0.05 # Default estimate
# Get current values before update
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
# Update image calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Determine API provider based on actual provider
api_provider = APIProvider.STABILITY # Default for image generation
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=api_provider,
endpoint="/image-generation",
method="POST",
model_used=result.model or "unknown",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(result.image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider_name}
├─ Actual Provider: {provider_name}
├─ Model: {result.model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
else:
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
return result
def generate_character_image(
prompt: str,
reference_image_bytes: bytes,
user_id: Optional[str] = None,
style: str = "Realistic",
aspect_ratio: str = "16:9",
rendering_speed: str = "Quality",
timeout: Optional[int] = None,
) -> bytes:
"""Generate character-consistent image with pre-flight validation and usage tracking.
Uses Ideogram Character API via WaveSpeed to maintain character consistency.
Args:
prompt: Text prompt describing the scene/context for the character
reference_image_bytes: Reference image bytes (base avatar)
user_id: User ID for subscription checking (required)
style: Character style type ("Auto", "Fiction", or "Realistic")
aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4")
rendering_speed: Rendering speed ("Default", "Turbo", "Quality")
timeout: Total timeout in seconds for submission + polling (default: 180)
Returns:
bytes: Generated image bytes with consistent character
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1,
)
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# Generate character image via WaveSpeed
from services.wavespeed.client import WaveSpeedClient
from fastapi import HTTPException
try:
wavespeed_client = WaveSpeedClient()
image_bytes = wavespeed_client.generate_character_image(
prompt=prompt,
reference_image_bytes=reference_image_bytes,
style=style,
aspect_ratio=aspect_ratio,
rendering_speed=rendering_speed,
timeout=timeout,
)
# TRACK USAGE after successful API call
has_image_bytes = bool(image_bytes) if image_bytes else False
image_bytes_len = len(image_bytes) if image_bytes else 0
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
if user_id and image_bytes:
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Character image cost (same as ideogram-v3-turbo)
estimated_cost = 0.10
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
endpoint="/image-generation/character",
method="POST",
model_used="ideogram-character",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
# Get related stats
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
# UNIFIED SUBSCRIPTION LOG
print(f"""
[SUBSCRIPTION] Image Generation (Character)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: ideogram-character
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
except Exception as track_error:
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
else:
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
return image_bytes
except HTTPException:
raise
except Exception as api_error:
logger.error(f"[Character Image Generation] Character image generation API failed: {api_error}")
raise HTTPException(
status_code=502,
detail={
"error": "Character image generation failed",
"message": str(api_error)
}
)
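# Example usage (a minimal sketch; the user ID is a placeholder, and the provider/model
# fall back to the environment-driven defaults above when omitted):
#
#   result = generate_image(
#       "A minimalist poster of a mountain sunrise",
#       options={"provider": "huggingface", "width": 1024, "height": 1024},
#       user_id="user_abc123",
#   )
#   with open("poster.png", "wb") as f:
#       f.write(result.image_bytes)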

View File

@@ -0,0 +1,896 @@
"""Main Text Generation Service for ALwrity Backend.
This service provides the main LLM text generation functionality,
migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.py
"""
import os
import json
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> str:
"""
Generate text using Language Model (LLM) based on the provided prompt.
Args:
prompt (str): The prompt to generate text from.
system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses.
user_id (str): Clerk user ID for subscription checking (required).
Returns:
str: Generated text based on the prompt.
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
"""
try:
logger.info("[llm_text_gen] Starting text generation")
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
# Set default values for LLM parameters
gpt_provider = "google" # Default to Google Gemini
model = "gemini-2.0-flash-001"
temperature = 0.7
max_tokens = 4000
top_p = 0.9
n = 1
fp = 16
frequency_penalty = 0.0
presence_penalty = 0.0
# Check for GPT_PROVIDER environment variable
env_provider = os.getenv('GPT_PROVIDER', '').lower()
if env_provider in ['gemini', 'google']:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
gpt_provider = "huggingface"
model = "openai/gpt-oss-120b:groq"
# Default blog characteristics
blog_tone = "Professional"
blog_demographic = "Professional"
blog_type = "Informational"
blog_language = "English"
blog_output_format = "markdown"
blog_length = 2000
# Check which providers have API keys available using APIKeyManager
api_key_manager = APIKeyManager()
available_providers = []
if api_key_manager.get_api_key("gemini"):
available_providers.append("google")
if api_key_manager.get_api_key("hf_token"):
available_providers.append("huggingface")
# If no environment variable set, auto-detect based on available keys
if not env_provider:
# Prefer Google Gemini if available, otherwise use Hugging Face
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "openai/gpt-oss-120b:groq"
else:
logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
else:
# Environment variable was set, validate it's supported
if gpt_provider not in available_providers:
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "openai/gpt-oss-120b:groq"
else:
raise RuntimeError("No supported providers available.")
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
# Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider
provider_enum = None
# Store actual provider name for logging (e.g., "huggingface", "gemini")
actual_provider_name = None
if gpt_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
elif gpt_provider == "huggingface":
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
actual_provider_name = "huggingface" # Keep actual provider name for logs
if not provider_enum:
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
# SUBSCRIPTION CHECK - Required and strict enforcement
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
try:
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import UsageSummary
db = next(get_db())
try:
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
# Estimate tokens from prompt (input tokens)
# CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse
# This ensures we block requests that would exceed limits even if response is longer than expected
input_tokens = int(len(prompt.split()) * 1.3)
# Worst-case estimate: assume maximum possible output tokens (max_tokens if specified)
# This prevents abuse where actual response tokens exceed the estimate
if max_tokens:
estimated_output_tokens = max_tokens # Use maximum allowed output tokens
else:
# If max_tokens not specified, use conservative estimate (input * 1.5)
estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens
# Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider_enum,
tokens_requested=estimated_total_tokens,
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
)
if not can_proceed:
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
# Raise HTTPException(429) with usage info so frontend can display subscription modal
error_detail = {
'error': message,
'message': message,
'provider': actual_provider_name or provider_enum.value,
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
# Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# No separate log here - we'll create unified log after API call and usage tracking
finally:
db.close()
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
raise
except RuntimeError:
# Re-raise subscription limit errors
raise
except Exception as sub_error:
# STRICT: Fail on subscription check errors
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Construct the system prompt if not provided
if system_prompt is None:
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
Your expertise spans various writing styles and formats.
Writing Style Guidelines:
- Tone: {blog_tone}
- Target Audience: {blog_demographic}
- Content Type: {blog_type}
- Language: {blog_language}
- Output Format: {blog_output_format}
- Target Length: {blog_length} words
Please provide responses that are:
- Well-structured and easy to read
- Engaging and informative
- Tailored to the specified tone and audience
- Professional yet accessible
- Optimized for the target content type
"""
else:
system_instructions = system_prompt
# Generate response based on provider
response_text = None
actual_provider_used = gpt_provider
try:
if gpt_provider == "google":
if json_struct:
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
top_p=top_p,
top_k=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
n=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
elif gpt_provider == "huggingface":
if json_struct:
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = huggingface_text_response(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
# TRACK USAGE after successful API call
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
# CRITICAL: Use a direct SQL UPDATE instead of ORM setattr for dynamically named columns.
# setattr() with attribute names built at runtime has not reliably marked the row as dirty here,
# so a raw SQL UPDATE is used to guarantee the change is persisted.
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens (after any safety capping)
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this call
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
cost_total = 0.0
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
if cost_total > 0:
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
# Keep ORM object in sync for logging/debugging
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
else:
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
# Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
new_total_calls = old_total_calls + 1
new_total_tokens = old_total_tokens + tokens_total
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
fallback_providers = ["google", "huggingface"]
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
if fallback_providers:
fallback_provider = fallback_providers[0] # Only try the first available
try:
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
actual_provider_used = fallback_provider
# Update provider enum for fallback
if fallback_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
fallback_model = "gemini-2.0-flash-lite"
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "openai/gpt-oss-120b:groq"
if fallback_provider == "google":
if json_struct:
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
top_p=top_p,
top_k=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
n=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
elif fallback_provider == "huggingface":
if json_struct:
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model="openai/gpt-oss-120b:groq",
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = huggingface_text_response(
prompt=prompt,
model="openai/gpt-oss-120b:groq",
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
# TRACK USAGE after successful fallback call
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = next(get_db())
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log (from raw SQL query)
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens after any safety capping
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this fallback call
cost_total = 0.0
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=fallback_model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
if cost_total > 0:
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
# Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log (limits already retrieved above)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
return response_text
except Exception as fallback_error:
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
# CIRCUIT BREAKER: Stop immediately to prevent expensive API calls
logger.error("[llm_text_gen] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
raise RuntimeError("All LLM providers failed to generate a response.")
except Exception as e:
logger.error(f"[llm_text_gen] Error during text generation: {str(e)}")
raise
def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported."""
supported_providers = ["google", "huggingface"]
return gpt_provider in supported_providers
def get_api_key(gpt_provider: str) -> Optional[str]:
"""Get API key for the specified provider."""
try:
api_key_manager = APIKeyManager()
provider_mapping = {
"google": "gemini",
"huggingface": "hf_token"
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
return api_key_manager.get_api_key(mapped_provider)
except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None
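# Illustrative sketch (not part of the module's public surface): combine
# check_gpt_provider() and get_api_key() to verify that a provider is both
# supported and configured before routing a request to it. Guarded so it
# never runs on import; intended only for quick local checks.
if __name__ == "__main__":
    for candidate in ("google", "huggingface"):
        if check_gpt_provider(candidate) and get_api_key(candidate):
            print(f"{candidate}: supported and key configured")
        else:
            print(f"{candidate}: unsupported or missing API key")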

View File

@@ -0,0 +1,792 @@
"""
Main Video Generation Service
Provides a unified interface for AI video generation providers.
Supports:
- Text-to-video: Hugging Face Inference Providers, WaveSpeed models
- Image-to-video: WaveSpeed WAN 2.5, Kandinsky 5 Pro
Stubs included for Gemini (Veo 3) and OpenAI (Sora) for future use.
"""
from __future__ import annotations
import os
import base64
import io
import sys
import asyncio
from typing import Any, Dict, Optional, Union, Callable
from fastapi import HTTPException
try:
from huggingface_hub import InferenceClient
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
InferenceClient = None
from ..onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from utils.logger_utils import get_service_logger
logger = get_service_logger("video_generation_service")
class VideoProviderNotImplemented(Exception):
pass
def _get_api_key(provider: str) -> Optional[str]:
try:
manager = APIKeyManager()
mapping = {
"huggingface": "hf_token",
"wavespeed": "wavespeed", # WaveSpeed API key
"gemini": "gemini", # placeholder for Veo 3
"openai": "openai_api_key", # placeholder for Sora
}
return manager.get_api_key(mapping.get(provider, provider))
except Exception as e:
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
return None
def _coerce_video_bytes(output: Any) -> bytes:
"""
Normalizes the different return shapes that huggingface_hub may emit for video tasks.
According to HF docs, text_to_video() should return bytes directly.
"""
logger.debug(f"[video_gen] _coerce_video_bytes received type: {type(output)}")
# Most common case: bytes directly
if isinstance(output, (bytes, bytearray, memoryview)):
logger.debug(f"[video_gen] Output is bytes: {len(output)} bytes")
return bytes(output)
# Handle file-like objects
if hasattr(output, "read"):
logger.debug("[video_gen] Output has read() method, reading...")
data = output.read()
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
raise TypeError(f"File-like object returned non-bytes: {type(data)}")
# Objects with direct attribute access
if hasattr(output, "video"):
logger.debug("[video_gen] Output has 'video' attribute")
data = getattr(output, "video")
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
if hasattr(output, "bytes"):
logger.debug("[video_gen] Output has 'bytes' attribute")
data = getattr(output, "bytes")
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
# Dict handling - but this shouldn't happen with text_to_video()
if isinstance(output, dict):
logger.warning(f"[video_gen] Received dict output (unexpected): keys={list(output.keys())}")
# Try to get video key safely - use .get() to avoid KeyError
data = output.get("video")
if data is not None:
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
# Try other common keys
for key in ["data", "content", "file", "result", "output"]:
data = output.get(key)
if data is not None:
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
raise TypeError(f"Dict output has no recognized video key. Keys: {list(output.keys())}")
# String handling (base64)
if isinstance(output, str):
logger.debug("[video_gen] Output is string, attempting base64 decode")
if output.startswith("data:"):
_, encoded = output.split(",", 1)
return base64.b64decode(encoded)
try:
return base64.b64decode(output)
except Exception as exc:
raise TypeError(f"Unable to decode string video payload: {exc}") from exc
# Fallback: try to use output directly
logger.warning(f"[video_gen] Unexpected output type: {type(output)}, attempting direct conversion")
try:
if hasattr(output, "__bytes__"):
return bytes(output)
except Exception:
pass
raise TypeError(f"Unsupported video payload type: {type(output)}. Output: {str(output)[:200]}")
def _generate_with_huggingface(
prompt: str,
num_frames: int = 24 * 4,
guidance_scale: float = 7.5,
num_inference_steps: int = 30,
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
model: str = "tencent/HunyuanVideo",
) -> bytes:
"""
Generates video bytes using Hugging Face's InferenceClient.
"""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
token = _get_api_key("huggingface")
if not token:
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
client = InferenceClient(
provider="fal-ai",
token=token,
)
logger.info("[video_gen] Using HuggingFace provider 'fal-ai'")
params: Dict[str, Any] = {
"num_frames": num_frames,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
}
if negative_prompt:
params["negative_prompt"] = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
if seed is not None:
params["seed"] = seed
logger.info(
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=text-to-video",
model,
num_frames,
num_inference_steps,
)
try:
logger.info("[video_gen] Calling client.text_to_video()...")
video_output = client.text_to_video(
prompt=prompt,
model=model,
**params,
)
logger.info(f"[video_gen] text_to_video() returned type: {type(video_output)}")
if isinstance(video_output, dict):
logger.info(f"[video_gen] Dict keys: {list(video_output.keys())}")
elif hasattr(video_output, "__dict__"):
logger.info(f"[video_gen] Object attributes: {dir(video_output)}")
video_bytes = _coerce_video_bytes(video_output)
if not isinstance(video_bytes, bytes):
raise TypeError(f"Expected bytes from text_to_video, got {type(video_bytes)}")
if len(video_bytes) == 0:
raise ValueError("Received empty video bytes from Hugging Face API")
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
return video_bytes
except KeyError as e:
error_msg = str(e)
logger.error(f"[video_gen] HF KeyError: {error_msg}", exc_info=True)
logger.error(f"[video_gen] This suggests the API response format is unexpected. Check logs above for response type.")
raise HTTPException(status_code=502, detail={
"error": f"Hugging Face API returned unexpected response format: {error_msg}",
"error_type": "KeyError",
"hint": "The API response may have changed. Check server logs for details."
})
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"[video_gen] HF error ({error_type}): {error_msg}", exc_info=True)
raise HTTPException(status_code=502, detail={
"error": f"Hugging Face video generation failed: {error_msg}",
"error_type": error_type
})
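# Example (hypothetical parameters): a caller with a configured HF token could
# request a short clip as raw bytes with
#   _generate_with_huggingface("a paper boat drifting down a rainy street",
#                              num_frames=48, num_inference_steps=25)
# and write the returned bytes straight to an .mp4 file.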
async def _generate_image_to_video_wavespeed(
image_data: Optional[bytes] = None,
image_base64: Optional[str] = None,
prompt: str = "",
duration: int = 5,
resolution: str = "720p",
model: str = "alibaba/wan-2.5/image-to-video",
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
audio_base64: Optional[str] = None,
enable_prompt_expansion: bool = True,
progress_callback: Optional[Callable[[float, str], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
Generate video from image using WaveSpeed (WAN 2.5 or Kandinsky 5 Pro).
Args:
image_data: Image bytes (required if image_base64 not provided)
image_base64: Image in base64 or data URI format (required if image_data not provided)
prompt: Text prompt describing the video motion
duration: Video duration in seconds (5 or 10)
resolution: Output resolution (480p, 720p, 1080p)
model: Model to use (alibaba/wan-2.5/image-to-video, wavespeed/kandinsky5-pro/image-to-video)
negative_prompt: Optional negative prompt
seed: Optional random seed
audio_base64: Optional audio file for synchronization
enable_prompt_expansion: Enable prompt optimization
Returns:
Dictionary with video_bytes and metadata (cost, duration, resolution, width, height, etc.)
"""
# Import here to avoid circular dependencies
from services.image_studio.wan25_service import WAN25Service
logger.info(f"[video_gen] WaveSpeed image-to-video: model={model}, resolution={resolution}, duration={duration}s")
# Validate inputs
if not image_data and not image_base64:
raise ValueError("Either image_data or image_base64 must be provided for image-to-video")
# Convert image_data to base64 if needed
if image_data and not image_base64:
image_base64 = base64.b64encode(image_data).decode('utf-8')
# Add data URI prefix if not present
if not image_base64.startswith("data:"):
image_base64 = f"data:image/png;base64,{image_base64}"
# Initialize WAN25Service (handles both WAN 2.5 and Kandinsky 5 Pro)
wan25_service = WAN25Service()
try:
# Generate video using WAN25Service (returns full metadata)
result = await wan25_service.generate_video(
image_base64=image_base64,
prompt=prompt,
audio_base64=audio_base64,
resolution=resolution,
duration=duration,
negative_prompt=negative_prompt,
seed=seed,
enable_prompt_expansion=enable_prompt_expansion,
progress_callback=progress_callback,
)
video_bytes = result.get("video_bytes")
if not video_bytes:
raise ValueError("WAN25Service returned no video bytes")
if not isinstance(video_bytes, bytes):
raise TypeError(f"Expected bytes from WAN25Service, got {type(video_bytes)}")
if len(video_bytes) == 0:
raise ValueError("Received empty video bytes from WaveSpeed API")
logger.info(f"[video_gen] Successfully generated image-to-video: {len(video_bytes)} bytes")
# Return video bytes with metadata
return {
"video_bytes": video_bytes,
"prompt": result.get("prompt", prompt),
"duration": result.get("duration", float(duration)),
"model_name": result.get("model_name", model),
"cost": result.get("cost", 0.0),
"provider": result.get("provider", "wavespeed"),
"resolution": result.get("resolution", resolution),
"width": result.get("width", 1280),
"height": result.get("height", 720),
"metadata": result.get("metadata", {}),
"source_video_url": result.get("source_video_url"),
"prediction_id": result.get("prediction_id"),
}
except HTTPException:
# Re-raise HTTPExceptions from WAN25Service
raise
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"[video_gen] WaveSpeed image-to-video error ({error_type}): {error_msg}", exc_info=True)
raise HTTPException(
status_code=502,
detail={
"error": f"WaveSpeed image-to-video generation failed: {error_msg}",
"error_type": error_type
}
)
def _generate_with_gemini(prompt: str, **kwargs) -> bytes:
raise VideoProviderNotImplemented("Gemini Veo 3 integration coming soon.")
def _generate_with_openai(prompt: str, **kwargs) -> bytes:
raise VideoProviderNotImplemented("OpenAI Sora integration coming soon.")
async def _generate_text_to_video_wavespeed(
prompt: str,
duration: int = 5,
resolution: str = "720p",
model: str = "hunyuan-video-1.5",
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
audio_base64: Optional[str] = None,
enable_prompt_expansion: bool = True,
progress_callback: Optional[Callable[[float, str], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
Generate text-to-video using WaveSpeed models.
Args:
prompt: Text prompt describing the video
duration: Video duration in seconds
resolution: Output resolution (480p, 720p)
model: Model identifier (e.g., "hunyuan-video-1.5")
negative_prompt: Optional negative prompt
seed: Optional random seed
audio_base64: Optional audio (not supported by all models)
enable_prompt_expansion: Enable prompt optimization (not supported by all models)
progress_callback: Optional progress callback function
**kwargs: Additional model-specific parameters
Returns:
Dictionary with video_bytes, prompt, duration, model_name, cost, etc.
"""
from .video_generation.wavespeed_provider import get_wavespeed_text_to_video_service
logger.info(f"[video_gen] WaveSpeed text-to-video: model={model}, resolution={resolution}, duration={duration}s")
# Get the appropriate service for the model
try:
service = get_wavespeed_text_to_video_service(model)
except ValueError as e:
logger.error(f"[video_gen] Unsupported WaveSpeed text-to-video model: {model}")
raise HTTPException(
status_code=400,
detail=str(e)
)
# Generate video using the service
try:
result = await service.generate_video(
prompt=prompt,
duration=duration,
resolution=resolution,
negative_prompt=negative_prompt,
seed=seed,
audio_base64=audio_base64,
enable_prompt_expansion=enable_prompt_expansion,
progress_callback=progress_callback,
**kwargs
)
logger.info(f"[video_gen] Successfully generated text-to-video: {len(result.get('video_bytes', b''))} bytes")
return result
except HTTPException:
# Re-raise HTTPExceptions from service
raise
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"[video_gen] WaveSpeed text-to-video error ({error_type}): {error_msg}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": f"WaveSpeed text-to-video generation failed: {error_msg}",
"type": error_type,
}
)
async def ai_video_generate(
prompt: Optional[str] = None,
image_data: Optional[bytes] = None,
image_base64: Optional[str] = None,
operation_type: str = "text-to-video",
provider: str = "huggingface",
user_id: Optional[str] = None,
progress_callback: Optional[Callable[[float, str], None]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Unified video generation entry point for ALL video operations.
Supports:
- text-to-video: prompt required, provider: 'huggingface', 'wavespeed', 'gemini' (stub), 'openai' (stub)
- image-to-video: image_data or image_base64 required, provider: 'wavespeed'
Args:
prompt: Text prompt (required for text-to-video)
image_data: Image bytes (required for image-to-video if image_base64 not provided)
image_base64: Image base64 string (required for image-to-video if image_data not provided)
operation_type: "text-to-video" or "image-to-video" (default: "text-to-video")
provider: Provider name (default: "huggingface" for text-to-video, "wavespeed" for image-to-video)
user_id: Required for subscription/usage tracking
progress_callback: Optional function(progress: float, message: str) -> None
Called at key stages: submission (10%), polling (20-80%), completion (100%)
**kwargs: Model-specific parameters:
- For text-to-video: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
- For image-to-video: duration, resolution, negative_prompt, seed, audio_base64, enable_prompt_expansion, model
Returns:
Dictionary with:
- video_bytes: Raw video bytes (mp4/webm depending on provider)
- prompt: The prompt used (may be enhanced)
- duration: Video duration in seconds
- model_name: Model used for generation
- cost: Cost of generation
- provider: Provider name
- resolution: Video resolution (for image-to-video)
- width: Video width in pixels (for image-to-video)
- height: Video height in pixels (for image-to-video)
- metadata: Additional metadata dict
"""
logger.info(f"[video_gen] operation={operation_type}, provider={provider}")
# Enforce authentication usage like text gen does
if not user_id:
raise RuntimeError("user_id is required for subscription/usage tracking.")
# Validate operation type and required inputs
if operation_type == "text-to-video":
if not prompt:
raise ValueError("prompt is required for text-to-video generation")
# Set default provider if not specified
if provider == "huggingface" and "model" not in kwargs:
kwargs.setdefault("model", "tencent/HunyuanVideo")
elif operation_type == "image-to-video":
if not image_data and not image_base64:
raise ValueError("image_data or image_base64 is required for image-to-video generation")
# Set default provider and model for image-to-video
if provider not in ["wavespeed"]:
logger.warning(f"[video_gen] Provider {provider} not supported for image-to-video, defaulting to wavespeed")
provider = "wavespeed"
if "model" not in kwargs:
kwargs.setdefault("model", "alibaba/wan-2.5/image-to-video")
# Set defaults for image-to-video
kwargs.setdefault("duration", 5)
kwargs.setdefault("resolution", "720p")
else:
raise ValueError(f"Invalid operation_type: {operation_type}. Must be 'text-to-video' or 'image-to-video'")
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription.preflight_validator import validate_video_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_video_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Video Generation] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
# Progress callback: Initial submission
if progress_callback:
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
# Generate video based on operation type
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
try:
if operation_type == "text-to-video":
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
**kwargs,
)
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10, # Default cost, will be calculated in track_video_usage
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280, # Default, actual may vary
"height": 720, # Default, actual may vary
"metadata": {},
}
elif provider == "wavespeed":
# WaveSpeed text-to-video - use unified service
result_dict = await _generate_text_to_video_wavespeed(
prompt=prompt,
progress_callback=progress_callback,
**kwargs,
)
elif provider == "gemini":
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
elif provider == "openai":
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
else:
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
elif operation_type == "image-to-video":
if provider == "wavespeed":
# Progress callback: Starting generation
if progress_callback:
progress_callback(20.0, "Video generation in progress...")
# Run the provider coroutine defensively: if an event loop is already running
# in this context, execute the call in a fresh event loop on a worker thread;
# otherwise run it on the current (or a new) loop.
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context - use ThreadPoolExecutor to run in new event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run,
_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
)
)
result_dict = future.result()
else:
# Event loop exists but not running - use it
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
except RuntimeError:
# No event loop exists, create a new one
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
video_bytes = result_dict["video_bytes"]
model_name = result_dict.get("model_name", model_name)
# Progress callback: Processing result
if progress_callback:
progress_callback(90.0, "Processing video result...")
else:
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
# Track usage (same pattern as text generation)
# Use cost from result_dict if available, otherwise calculate
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=result_dict.get("prompt", prompt or ""),
video_bytes=video_bytes,
cost_override=cost_override,
)
# Progress callback: Complete
if progress_callback:
progress_callback(100.0, "Video generation complete!")
return result_dict
except HTTPException:
# Re-raise HTTPExceptions (e.g., from validation or API errors)
raise
except Exception as e:
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": str(e)})
def _get_default_model(operation_type: str, provider: str) -> str:
"""Get default model for operation type and provider."""
defaults = {
("text-to-video", "huggingface"): "tencent/HunyuanVideo",
("text-to-video", "wavespeed"): "hunyuan-video-1.5",
("image-to-video", "wavespeed"): "alibaba/wan-2.5/image-to-video",
}
return defaults.get((operation_type, provider), "hunyuan-video-1.5")
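# For example, _get_default_model("image-to-video", "wavespeed") resolves to
# "alibaba/wan-2.5/image-to-video", while any unlisted combination falls back
# to "hunyuan-video-1.5".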
def track_video_usage(
*,
user_id: str,
provider: str,
model_name: str,
prompt: str,
video_bytes: bytes,
cost_override: Optional[float] = None,
) -> Dict[str, Any]:
"""
Track subscription usage for any video generation (text-to-video or image-to-video).
"""
from datetime import datetime
from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
from services.database import get_db
db_track = next(get_db())
try:
logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
pricing_service_track = PricingService(db_track)
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[video_gen] Billing period: {current_period}")
usage_summary = (
db_track.query(UsageSummary)
.filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period,
)
.first()
)
if not usage_summary:
logger.debug(f"[video_gen] Creating new UsageSummary for user={user_id}, period={current_period}")
usage_summary = UsageSummary(
user_id=user_id,
billing_period=current_period,
)
db_track.add(usage_summary)
db_track.commit()
db_track.refresh(usage_summary)
else:
logger.debug(f"[video_gen] Found existing UsageSummary: video_calls={getattr(usage_summary, 'video_calls', 0)}")
cost_info = pricing_service_track.get_pricing_for_provider_model(
APIProvider.VIDEO,
model_name,
)
default_cost = 0.10
if cost_info and cost_info.get("cost_per_request") is not None:
default_cost = cost_info["cost_per_request"]
cost_per_video = cost_override if cost_override is not None else default_cost
logger.debug(f"[video_gen] Cost per video: ${cost_per_video} (override={cost_override}, default={default_cost})")
current_video_calls_before = getattr(usage_summary, "video_calls", 0) or 0
current_video_cost = getattr(usage_summary, "video_cost", 0.0) or 0.0
usage_summary.video_calls = current_video_calls_before + 1
usage_summary.video_cost = current_video_cost + cost_per_video
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
# Ensure the object is in the session
db_track.add(usage_summary)
logger.debug(f"[video_gen] Updated usage_summary: video_calls={current_video_calls_before}{usage_summary.video_calls}")
limits = pricing_service_track.get_user_limits(user_id)
plan_name = limits.get("plan_name", "unknown") if limits else "unknown"
tier = limits.get("tier", "unknown") if limits else "unknown"
video_limit = limits["limits"].get("video_calls", 0) if limits else 0
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
image_limit = limits["limits"].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
image_edit_limit = limits["limits"].get("image_edit_calls", 0) if limits else 0
current_audio_calls = getattr(usage_summary, "audio_calls", 0) or 0
audio_limit = limits["limits"].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.VIDEO,
endpoint=f"/video-generation/{provider}",
method="POST",
model_used=model_name,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost_per_video,
response_time=0.0,
status_code=200,
request_size=len((prompt or "").encode("utf-8")),
response_size=len(video_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
logger.debug(f"[video_gen] Flushing changes before commit...")
db_track.flush()
logger.debug(f"[video_gen] Committing usage tracking changes...")
db_track.commit()
db_track.refresh(usage_summary)
logger.debug(f"[video_gen] Commit successful. Final video_calls: {usage_summary.video_calls}, video_cost: {usage_summary.video_cost}")
video_limit_display = video_limit if video_limit > 0 else '∞'
log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before} → {usage_summary.video_calls} / {video_limit_display}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
"""
logger.info(log_message)
return {
"previous_calls": current_video_calls_before,
"current_calls": usage_summary.video_calls,
"video_limit": video_limit,
"video_limit_display": video_limit_display,
"cost_per_video": cost_per_video,
"total_video_cost": usage_summary.video_cost,
}
except Exception as track_error:
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
logger.error(f"[video_gen] Exception type: {type(track_error).__name__}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
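# Illustrative sketch (assumes a configured HF token, an active subscription,
# and a hypothetical user id used only for local testing). It exercises the
# unified ai_video_generate entry point above and is guarded so it never runs
# on import.
if __name__ == "__main__":
    async def _demo() -> None:
        result = await ai_video_generate(
            prompt="A paper boat drifting down a rainy street, cinematic",
            operation_type="text-to-video",
            provider="huggingface",
            user_id="demo-user",  # hypothetical id for local testing
        )
        with open("demo_video.mp4", "wb") as fh:
            fh.write(result["video_bytes"])
        print(f"wrote {len(result['video_bytes'])} bytes with {result['model_name']}")

    asyncio.run(_demo())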

View File

@@ -0,0 +1,10 @@
"""
Video Generation Services
Modular services for text-to-video and image-to-video generation.
Each provider/model has its own service class for separation of concerns.
"""
from typing import Optional, Dict, Any
__all__ = []

View File

@@ -0,0 +1,53 @@
"""
Base classes and interfaces for video generation services.
Provides common interfaces and data structures for video generation providers.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol, Callable
@dataclass
class VideoGenerationOptions:
"""Options for video generation."""
prompt: str
duration: int = 5
resolution: str = "720p"
negative_prompt: Optional[str] = None
seed: Optional[int] = None
audio_base64: Optional[str] = None
enable_prompt_expansion: bool = True
model: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
@dataclass
class VideoGenerationResult:
"""Result from video generation."""
video_bytes: bytes
prompt: str
duration: float
model_name: str
cost: float
provider: str
resolution: str
width: int
height: int
metadata: Dict[str, Any]
source_video_url: Optional[str] = None
prediction_id: Optional[str] = None
class VideoGenerationProvider(Protocol):
"""Protocol for video generation providers."""
async def generate_video(
self,
options: VideoGenerationOptions,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> VideoGenerationResult:
"""Generate video with given options."""
...
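# Minimal sketch of a provider that satisfies VideoGenerationProvider, using
# placeholder bytes and zero cost; a real provider would call an external API.
if __name__ == "__main__":
    import asyncio

    class _NullProvider:
        async def generate_video(
            self,
            options: VideoGenerationOptions,
            progress_callback=None,
        ) -> VideoGenerationResult:
            if progress_callback:
                progress_callback(100.0, "done")
            return VideoGenerationResult(
                video_bytes=b"",
                prompt=options.prompt,
                duration=float(options.duration),
                model_name=options.model or "null-model",
                cost=0.0,
                provider="null",
                resolution=options.resolution,
                width=1280,
                height=720,
                metadata={},
            )

    result = asyncio.run(
        _NullProvider().generate_video(VideoGenerationOptions(prompt="demo"))
    )
    print(result.model_name, len(result.video_bytes))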

File diff suppressed because it is too large