Base code
backend/services/llm_providers/README.md (new file, 343 lines)
@@ -0,0 +1,343 @@
|
||||
# LLM Providers Module
|
||||
|
||||
This module provides functions for interacting with multiple LLM providers, specifically Google's Gemini API and Hugging Face Inference Providers. It follows official API documentation and implements best practices for reliable AI interactions.
|
||||
|
||||
## Supported Providers
|
||||
|
||||
- **Google Gemini**: High-quality text generation with structured JSON output
|
||||
- **Hugging Face**: Multiple models via Inference Providers with unified interface
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
# Generate text (auto-detects available provider)
|
||||
response = llm_text_gen("Write a blog post about AI trends")
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Set your preferred provider using the `GPT_PROVIDER` environment variable:
|
||||
|
||||
```bash
|
||||
# Use Google Gemini (default)
|
||||
export GPT_PROVIDER=gemini
|
||||
|
||||
# Use Hugging Face
|
||||
export GPT_PROVIDER=hf_response_api
|
||||
```
|
||||
|
||||
Configure API keys:
|
||||
|
||||
```bash
|
||||
# For Google Gemini
|
||||
export GEMINI_API_KEY=your_gemini_api_key_here
|
||||
|
||||
# For Hugging Face
|
||||
export HF_TOKEN=your_huggingface_token_here
|
||||
```
|
||||
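If you prefer keeping keys in a local `.env` file, `python-dotenv` (already used elsewhere in the backend) can load them at startup. A minimal sketch, assuming the keys above live in `.env`:

```python
from dotenv import load_dotenv

# Reads GEMINI_API_KEY / HF_TOKEN / GPT_PROVIDER from .env into the process environment
load_dotenv()
```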
|
||||
## Key Features
|
||||
|
||||
- **Structured JSON Response Generation**: Generate structured outputs with schema validation
|
||||
- **Text Response Generation**: Simple text generation with retry logic
|
||||
- **Comprehensive Error Handling**: Robust error handling and logging
|
||||
- **Automatic API Key Management**: Secure API key handling
|
||||
- **Support for Multiple Models**: gemini-2.5-flash and gemini-2.5-pro
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Structured Output for Complex Responses
|
||||
```python
|
||||
# ✅ Good: Use structured output for multi-field responses
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
```
|
||||
|
||||
### 2. Keep Schemas Simple and Flat
|
||||
```python
|
||||
# ✅ Good: Simple, flat schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"monitoringTasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "object", "properties": {...}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# ❌ Avoid: Complex nested schemas with many required fields
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["field1", "field2", "field3"],
|
||||
"properties": {
|
||||
"field1": {"type": "object", "required": [...], "properties": {...}},
|
||||
"field2": {"type": "array", "items": {"type": "object", "required": [...], "properties": {...}}}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Set Appropriate Token Limits
|
||||
```python
|
||||
# ✅ Good: Use 8192 tokens for complex outputs
|
||||
result = gemini_structured_json_response(prompt, schema, max_tokens=8192)
|
||||
|
||||
# ✅ Good: Use 2048 tokens for simple text responses
|
||||
result = gemini_text_response(prompt, max_tokens=2048)
|
||||
```
|
||||
|
||||
### 4. Use Low Temperature for Structured Output
|
||||
```python
|
||||
# ✅ Good: Low temperature for consistent structured output
|
||||
result = gemini_structured_json_response(prompt, schema, temperature=0.1, max_tokens=8192)
|
||||
|
||||
# ✅ Good: Higher temperature for creative text
|
||||
result = gemini_text_response(prompt, temperature=0.8, max_tokens=2048)
|
||||
```
|
||||
|
||||
### 5. Implement Proper Error Handling
|
||||
```python
|
||||
# ✅ Good: Handle errors in calling functions
|
||||
try:
|
||||
response = gemini_structured_json_response(prompt, schema)
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
raise Exception(f"Gemini error: {response.get('error')}")
|
||||
# Process successful response
|
||||
except Exception as e:
|
||||
logger.error(f"AI service error: {e}")
|
||||
# Handle error appropriately
|
||||
```
|
||||
|
||||
### 6. Avoid Fallback to Text Parsing
|
||||
```python
|
||||
# ✅ Good: Use structured output only, no fallback
|
||||
response = gemini_structured_json_response(prompt, schema)
|
||||
if "error" in response:
|
||||
raise Exception(f"Gemini error: {response.get('error')}")
|
||||
|
||||
# ❌ Avoid: Fallback to text parsing for structured responses
|
||||
# This can lead to inconsistent results and parsing errors
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Structured JSON Response
|
||||
```python
|
||||
from services.llm_providers.gemini_provider import gemini_structured_json_response
|
||||
|
||||
# Define schema
|
||||
monitoring_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"monitoringTasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"component": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"assignee": {"type": "string"},
|
||||
"frequency": {"type": "string"},
|
||||
"metric": {"type": "string"},
|
||||
"measurementMethod": {"type": "string"},
|
||||
"successCriteria": {"type": "string"},
|
||||
"alertThreshold": {"type": "string"},
|
||||
"actionableInsights": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Generate structured response
|
||||
prompt = "Generate a monitoring plan for content strategy..."
|
||||
result = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=monitoring_schema,
|
||||
temperature=0.1,
|
||||
max_tokens=8192
|
||||
)
|
||||
|
||||
# Handle response
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
raise Exception(f"Gemini error: {result.get('error')}")
|
||||
|
||||
# Process successful response
|
||||
monitoring_tasks = result.get("monitoringTasks", [])
|
||||
```
|
||||
|
||||
### Text Response
|
||||
```python
|
||||
from services.llm_providers.gemini_provider import gemini_text_response
|
||||
|
||||
# Generate text response
|
||||
prompt = "Write a blog post about AI in content marketing..."
|
||||
result = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=0.8,
|
||||
max_tokens=2048
|
||||
)
|
||||
|
||||
# Process response
|
||||
if result:
|
||||
print(f"Generated text: {result}")
|
||||
else:
|
||||
print("No response generated")
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues and Solutions
|
||||
|
||||
#### 1. Response.parsed is None
|
||||
**Symptoms**: `response.parsed` returns `None` even with successful HTTP 200
|
||||
**Causes**:
|
||||
- Schema too complex for the model
|
||||
- Token limit too low
|
||||
- Temperature too high for structured output
|
||||
|
||||
**Solutions**:
|
||||
- Simplify schema structure
|
||||
- Increase `max_tokens` to 8192
|
||||
- Lower temperature to 0.1-0.3
|
||||
- Test with smaller outputs first
|
||||
|
||||
#### 2. JSON Parsing Fails
|
||||
**Symptoms**: `JSONDecodeError` or "Unterminated string" errors
|
||||
**Causes**:
|
||||
- Response truncated due to token limits
|
||||
- Schema doesn't match expected output
|
||||
- Model generates malformed JSON
|
||||
|
||||
**Solutions**:
|
||||
- Reduce output size requested
|
||||
- Verify schema matches expected structure
|
||||
- Use structured output instead of text parsing
|
||||
- Increase token limits
|
||||
|
||||
#### 3. Truncation Issues
|
||||
**Symptoms**: Response cuts off mid-sentence or mid-array
|
||||
**Causes**:
|
||||
- Output too large for single response
|
||||
- Token limits exceeded
|
||||
|
||||
**Solutions**:
|
||||
- Reduce number of items requested
|
||||
- Increase `max_tokens` to 8192
|
||||
- Break large requests into smaller chunks
|
||||
- Use `gemini-2.5-pro` for larger outputs
|
||||
|
||||
#### 4. Rate Limiting
|
||||
**Symptoms**: `RetryError` or connection timeouts
|
||||
**Causes**:
|
||||
- Too many requests in short time
|
||||
- Network connectivity issues
|
||||
|
||||
**Solutions**:
|
||||
- Exponential backoff already implemented
|
||||
- Check network connectivity
|
||||
- Reduce request frequency
|
||||
- Verify API key validity
|
||||
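The retry behavior referenced above follows the standard `tenacity` pattern (tenacity is a listed dependency). A minimal sketch of that kind of decorator, not the provider's exact configuration:

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def call_with_backoff():
    # Any transient failure raised here is retried with exponential backoff
    return gemini_text_response("Health-check prompt", max_tokens=16)
```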
|
||||
### Debug Logging
|
||||
|
||||
The module includes comprehensive debug logging. Enable debug mode to see:
|
||||
|
||||
```python
|
||||
import logging
|
||||
logging.getLogger('services.llm_providers.gemini_provider').setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
Key log messages to monitor:
|
||||
- `Gemini structured call | prompt_len=X | schema_kind=Y | temp=Z`
|
||||
- `Gemini response | type=X | has_text=Y | has_parsed=Z`
|
||||
- `Using response.parsed for structured output`
|
||||
- `Falling back to response.text parsing`
|
||||
|
||||
## API Reference
|
||||
|
||||
### gemini_structured_json_response()
|
||||
|
||||
Generate a structured JSON response using Google's Gemini models.
|
||||
|
||||
**Parameters**:
|
||||
- `prompt` (str): Input prompt for the AI model
|
||||
- `schema` (dict): JSON schema defining expected output structure
|
||||
- `temperature` (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
|
||||
- `top_p` (float): Nucleus sampling parameter (0.0-1.0)
|
||||
- `top_k` (int): Top-k sampling parameter
|
||||
- `max_tokens` (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
- `system_prompt` (str, optional): System instruction for the model
|
||||
|
||||
**Returns**:
|
||||
- `dict`: Parsed JSON response matching the provided schema
|
||||
|
||||
**Raises**:
|
||||
- `Exception`: If API key is missing or API call fails
|
||||
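A short illustrative call exercising the documented parameters, including the optional `system_prompt` (values are placeholders, not recommended settings):

```python
result = gemini_structured_json_response(
    prompt="Summarize the article below as JSON.",
    schema={"type": "object", "properties": {"summary": {"type": "string"}}},
    temperature=0.2,
    max_tokens=2048,
    system_prompt="Respond only with JSON matching the schema.",
)
```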
|
||||
### gemini_text_response()
|
||||
|
||||
Generate a text response using Google's Gemini models.
|
||||
|
||||
**Parameters**:
|
||||
- `prompt` (str): Input prompt for the AI model
|
||||
- `temperature` (float): Controls randomness (0.0-1.0). Higher = more creative
|
||||
- `top_p` (float): Nucleus sampling parameter (0.0-1.0)
|
||||
- `n` (int): Number of responses to generate
|
||||
- `max_tokens` (int): Maximum tokens in response
|
||||
- `system_prompt` (str, optional): System instruction for the model
|
||||
|
||||
**Returns**:
|
||||
- `str`: Generated text response
|
||||
|
||||
**Raises**:
|
||||
- `Exception`: If API key is missing or API call fails
|
||||
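An illustrative call using the documented parameters (the `n` and `top_p` values here are placeholders):

```python
text = gemini_text_response(
    prompt="Write a two-sentence product update.",
    temperature=0.7,
    top_p=0.9,
    n=1,
    max_tokens=512,
    system_prompt="Keep the tone friendly and concise.",
)
```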
|
||||
## Dependencies
|
||||
|
||||
- `google.generativeai` (genai): Official Gemini API client
|
||||
- `tenacity`: Retry logic with exponential backoff
|
||||
- `logging`: Debug and error logging
|
||||
- `json`: Fallback JSON parsing
|
||||
- `re`: Text extraction utilities
|
||||
|
||||
## Version History
|
||||
|
||||
- **v2.0** (January 2025): Enhanced structured output support, improved error handling, comprehensive documentation
|
||||
- **v1.0**: Initial implementation with basic text and structured response support
|
||||
|
||||
## Contributing
|
||||
|
||||
When contributing to this module:
|
||||
|
||||
1. Follow the established patterns for error handling
|
||||
2. Add comprehensive logging for debugging
|
||||
3. Test with both simple and complex schemas
|
||||
4. Update documentation for any new features
|
||||
5. Ensure backward compatibility
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
1. Check the troubleshooting section above
|
||||
2. Review debug logs for specific error messages
|
||||
3. Test with simplified schemas to isolate issues
|
||||
4. Verify API key configuration and network connectivity
|
||||
backend/services/llm_providers/README_HUGGINGFACE_INTEGRATION.md (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
# Hugging Face Integration for AI Blog Writer
|
||||
|
||||
## Overview
|
||||
|
||||
The AI Blog Writer now supports both Google Gemini and Hugging Face as LLM providers, with a clean environment variable-based configuration system. This integration uses the [Hugging Face Responses API](https://huggingface.co/docs/inference-providers/guides/responses-api) which provides a unified interface for model interactions.
|
||||
|
||||
## Supported Providers
|
||||
|
||||
### 1. Google Gemini (Default)
|
||||
- **Provider ID**: `google`
|
||||
- **Environment Variable**: `GEMINI_API_KEY`
|
||||
- **Models**: `gemini-2.0-flash-001`
|
||||
- **Features**: Text generation, structured JSON output
|
||||
|
||||
### 2. Hugging Face
|
||||
- **Provider ID**: `huggingface`
|
||||
- **Environment Variable**: `HF_TOKEN`
|
||||
- **Models**: Multiple models via Inference Providers
|
||||
- **Features**: Text generation, structured JSON output, multi-model support
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Set the `GPT_PROVIDER` environment variable to choose your preferred provider:
|
||||
|
||||
```bash
|
||||
# Use Google Gemini (default)
|
||||
export GPT_PROVIDER=gemini
|
||||
# or
|
||||
export GPT_PROVIDER=google
|
||||
|
||||
# Use Hugging Face
|
||||
export GPT_PROVIDER=hf_response_api
|
||||
# or
|
||||
export GPT_PROVIDER=huggingface
|
||||
# or
|
||||
export GPT_PROVIDER=hf
|
||||
```
|
||||
|
||||
### API Keys
|
||||
|
||||
Configure the appropriate API key for your chosen provider:
|
||||
|
||||
```bash
|
||||
# For Google Gemini
|
||||
export GEMINI_API_KEY=your_gemini_api_key_here
|
||||
|
||||
# For Hugging Face
|
||||
export HF_TOKEN=your_huggingface_token_here
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Text Generation
|
||||
|
||||
```python
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
# Generate text (uses configured provider)
|
||||
response = llm_text_gen("Write a blog post about AI trends")
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Structured JSON Generation
|
||||
|
||||
```python
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
# Define JSON schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"sections": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"heading": {"type": "string"},
|
||||
"content": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Generate structured response
|
||||
response = llm_text_gen(
|
||||
"Create a blog outline about machine learning",
|
||||
json_struct=schema
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Direct Provider Usage
|
||||
|
||||
```python
|
||||
# Google Gemini
|
||||
from services.llm_providers.gemini_provider import gemini_text_response
|
||||
|
||||
response = gemini_text_response(
|
||||
prompt="Write about AI",
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# Hugging Face
|
||||
from services.llm_providers.huggingface_provider import huggingface_text_response
|
||||
|
||||
response = huggingface_text_response(
|
||||
prompt="Write about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
)
|
||||
```
|
||||
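For structured output via Hugging Face, the module also exports `huggingface_structured_json_response`. A hedged sketch, assuming its keyword arguments mirror the Gemini provider (verify against the module before relying on it):

```python
from services.llm_providers.huggingface_provider import huggingface_structured_json_response

schema = {"type": "object", "properties": {"headline": {"type": "string"}}}

# Parameter names below are assumptions based on the Gemini counterpart
result = huggingface_structured_json_response(
    prompt="Summarize today's AI news as JSON.",
    schema=schema,
    model="openai/gpt-oss-120b:groq",
    temperature=0.2,
    max_tokens=1024,
)
```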
|
||||
## Available Hugging Face Models
|
||||
|
||||
The Hugging Face provider supports multiple models via Inference Providers:
|
||||
|
||||
- `openai/gpt-oss-120b:groq` (default)
|
||||
- `moonshotai/Kimi-K2-Instruct-0905:groq`
|
||||
- `Qwen/Qwen2.5-VL-7B-Instruct`
|
||||
- `meta-llama/Llama-3.1-8B-Instruct:groq`
|
||||
- `microsoft/Phi-3-medium-4k-instruct:groq`
|
||||
- `mistralai/Mistral-7B-Instruct-v0.3:groq`
|
||||
|
||||
## Provider Selection Logic
|
||||
|
||||
1. **Environment Variable**: If `GPT_PROVIDER` is set, use the specified provider
|
||||
2. **Auto-detection**: If no environment variable, check available API keys:
|
||||
- Prefer Google Gemini if `GEMINI_API_KEY` is available
|
||||
- Fall back to Hugging Face if `HF_TOKEN` is available
|
||||
3. **Fallback**: If the specified provider fails, automatically try the other provider (see the sketch below)
|
||||
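A minimal sketch of the selection order described above (the helper name `pick_provider` is illustrative, not the actual implementation):

```python
import os

def pick_provider() -> str:
    provider = os.getenv("GPT_PROVIDER", "").lower()
    if provider in ("gemini", "google"):
        return "google"
    if provider in ("hf_response_api", "huggingface", "hf"):
        return "huggingface"
    # Auto-detection: prefer Gemini, then fall back to Hugging Face
    if os.getenv("GEMINI_API_KEY"):
        return "google"
    if os.getenv("HF_TOKEN"):
        return "huggingface"
    raise RuntimeError("No LLM API keys configured")
```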
|
||||
## Error Handling
|
||||
|
||||
The system includes comprehensive error handling:
|
||||
|
||||
- **Missing API Keys**: Clear error messages with setup instructions
|
||||
- **Provider Failures**: Automatic fallback to the other provider
|
||||
- **Invalid Models**: Validation with helpful error messages
|
||||
- **Network Issues**: Retry logic with exponential backoff
|
||||
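Because retries and provider fallback happen inside `llm_text_gen`, callers usually only need a simple guard. A sketch (the error handling shown is illustrative):

```python
try:
    text = llm_text_gen("Draft a short announcement about our new release")
except Exception as exc:
    # Both providers failed or no API keys are configured
    print(f"LLM generation failed: {exc}")
    text = None
```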
|
||||
## Migration from Previous Version
|
||||
|
||||
### Removed Providers
|
||||
The following providers have been removed to simplify the system:
|
||||
- OpenAI
|
||||
- Anthropic
|
||||
- DeepSeek
|
||||
|
||||
### Updated Imports
|
||||
```python
|
||||
# Old imports (no longer work)
|
||||
from services.llm_providers.openai_provider import openai_chatgpt
|
||||
from services.llm_providers.anthropic_provider import anthropic_text_response
|
||||
from services.llm_providers.deepseek_provider import deepseek_text_response
|
||||
|
||||
# New imports
|
||||
from services.llm_providers.gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from services.llm_providers.huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run the integration tests to verify everything works:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python -c "
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
from services.llm_providers.main_text_generation import check_gpt_provider
|
||||
print('Google provider supported:', check_gpt_provider('google'))
|
||||
print('Hugging Face provider supported:', check_gpt_provider('huggingface'))
|
||||
print('OpenAI provider supported:', check_gpt_provider('openai'))
|
||||
"
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Google Gemini
|
||||
- Fast response times
|
||||
- High-quality outputs
|
||||
- Good for structured content
|
||||
|
||||
### Hugging Face
|
||||
- Multiple model options
|
||||
- Cost-effective for high-volume usage
|
||||
- Good for experimentation with different models
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **"No LLM API keys configured"**
|
||||
- Ensure either `GEMINI_API_KEY` or `HF_TOKEN` is set
|
||||
- Check that the API key is valid
|
||||
|
||||
2. **"Unknown LLM provider"**
|
||||
- Use one of the supported values: `gemini`/`google` or `hf_response_api`/`huggingface`/`hf`
|
||||
- Check the `GPT_PROVIDER` environment variable
|
||||
|
||||
3. **"HF_TOKEN appears to be invalid"**
|
||||
- Ensure your Hugging Face token starts with `hf_`
|
||||
- Get a new token from [Hugging Face Settings](https://huggingface.co/settings/tokens)
|
||||
|
||||
4. **"OpenAI library not available"**
|
||||
- Install the OpenAI library: `pip install openai`
|
||||
- This is required for Hugging Face Responses API
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug logging to see provider selection:
|
||||
|
||||
```python
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
```
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Support for additional Hugging Face models
|
||||
- Model-specific parameter optimization
|
||||
- Advanced caching strategies
|
||||
- Performance monitoring and metrics
|
||||
- A/B testing between providers
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
1. Check the troubleshooting section above
|
||||
2. Review the [Hugging Face Responses API documentation](https://huggingface.co/docs/inference-providers/guides/responses-api)
|
||||
3. Check the Google Gemini API documentation for Gemini-specific issues
|
||||
backend/services/llm_providers/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
|
||||
"""LLM Providers Service for ALwrity Backend.
|
||||
|
||||
This service handles all LLM (Language Model) provider integrations,
|
||||
migrated from the legacy lib/gpt_providers functionality.
|
||||
"""
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.llm_providers.gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from services.llm_providers.huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
|
||||
|
||||
__all__ = [
|
||||
"llm_text_gen",
|
||||
"gemini_text_response",
|
||||
"gemini_structured_json_response",
|
||||
"huggingface_text_response",
|
||||
"huggingface_structured_json_response"
|
||||
]
|
||||
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Gemini Audio Text Generation Module
|
||||
|
||||
This module provides a comprehensive interface for working with audio files using Google's Gemini API.
|
||||
It supports various audio processing capabilities including transcription, summarization, and analysis.
|
||||
|
||||
Key Features:
|
||||
------------
|
||||
1. Audio Transcription: Convert speech in audio files to text
|
||||
2. Audio Summarization: Generate concise summaries of audio content
|
||||
3. Segment Analysis: Analyze specific time segments of audio files
|
||||
4. Timestamped Transcription: Generate transcriptions with timestamps
|
||||
5. Token Counting: Count tokens in audio files
|
||||
6. Format Support: Information about supported audio formats
|
||||
|
||||
Supported Audio Formats:
|
||||
----------------------
|
||||
- WAV (audio/wav)
|
||||
- MP3 (audio/mp3)
|
||||
- AIFF (audio/aiff)
|
||||
- AAC (audio/aac)
|
||||
- OGG Vorbis (audio/ogg)
|
||||
- FLAC (audio/flac)
|
||||
|
||||
Technical Details:
|
||||
----------------
|
||||
- Each second of audio is represented as 32 tokens
|
||||
- Maximum supported length of audio data in a single prompt is 9.5 hours
|
||||
- Audio files are downsampled to 16 Kbps data resolution
|
||||
- Multi-channel audio is combined into a single channel
|
||||
|
||||
Usage:
|
||||
------
|
||||
```python
|
||||
from lib.gpt_providers.audio_to_text_generation.gemini_audio_text import transcribe_audio, summarize_audio, analyze_audio_segment
|
||||
|
||||
# Basic transcription
|
||||
transcript = transcribe_audio("path/to/audio.mp3")
|
||||
print(transcript)
|
||||
|
||||
# Summarization
|
||||
summary = summarize_audio("path/to/audio.mp3")
|
||||
print(summary)
|
||||
|
||||
# Analyze specific segment
|
||||
segment_analysis = analyze_audio_segment("path/to/audio.mp3", "02:30", "03:29")
|
||||
print(segment_analysis)
|
||||
```
|
||||
|
||||
Requirements:
|
||||
------------
|
||||
- GEMINI_API_KEY environment variable must be set
|
||||
- google-generativeai Python package
|
||||
- python-dotenv for environment variable management
|
||||
- loguru for logging
|
||||
|
||||
Dependencies:
|
||||
------------
|
||||
- google.generativeai
|
||||
- dotenv
|
||||
- loguru
|
||||
- os, sys, base64, typing
|
||||
"""
|
||||
|
||||
import os
import sys
from pathlib import Path
from typing import List, Optional

from dotenv import load_dotenv
import google.generativeai as genai

from loguru import logger
from utils.logger_utils import get_service_logger
# APIKeyManager supplies the stored Gemini key (import path assumed from the onboarding service)
from services.onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("gemini_audio_text")
|
||||
|
||||
|
||||
def load_environment():
|
||||
"""Loads environment variables from a .env file."""
|
||||
load_dotenv()
|
||||
logger.info("Environment variables loaded successfully.")
|
||||
|
||||
|
||||
def configure_google_api():
|
||||
"""
|
||||
Configures the Google Gemini API with the API key from environment variables.
|
||||
|
||||
Raises:
|
||||
ValueError: If no Gemini API key has been configured.
|
||||
"""
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_message = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
genai.configure(api_key=api_key)
|
||||
logger.info("Google Gemini API configured successfully.")
|
||||
|
||||
|
||||
def transcribe_audio(audio_file_path: str, prompt: str = "Transcribe the following audio:") -> Optional[str]:
|
||||
"""
|
||||
Transcribes audio using Google's Gemini model.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file to be transcribed.
|
||||
prompt (str, optional): The prompt to guide the transcription. Defaults to "Transcribe the following audio:".
|
||||
|
||||
Returns:
|
||||
str: The transcribed text from the audio.
|
||||
Returns None if transcription fails.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the audio file is not found.
|
||||
"""
|
||||
try:
|
||||
# Load environment variables and configure the Google API
|
||||
load_environment()
|
||||
configure_google_api()
|
||||
|
||||
logger.info(f"Attempting to transcribe audio file: {audio_file_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(audio_file_path):
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize a Gemini model appropriate for audio understanding
|
||||
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
||||
|
||||
# Upload the audio file
|
||||
try:
|
||||
audio_file = genai.upload_file(audio_file_path)
|
||||
logger.info(f"Audio file uploaded successfully: {audio_file=}")
|
||||
except FileNotFoundError:
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio file: {e}")
|
||||
return None
|
||||
|
||||
# Generate the transcription
|
||||
try:
|
||||
response = model.generate_content([
|
||||
prompt,
|
||||
audio_file
|
||||
])
|
||||
|
||||
# Check for valid response and extract text
|
||||
if response and hasattr(response, 'text'):
|
||||
transcript = response.text
|
||||
logger.info(f"Transcription successful:\n{transcript}")
|
||||
return transcript
|
||||
else:
|
||||
logger.warning("Transcription failed: Invalid or empty response from API.")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during transcription: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def summarize_audio(audio_file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Summarizes the content of an audio file using Google's Gemini model.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file to be summarized.
|
||||
|
||||
Returns:
|
||||
str: A summary of the audio content.
|
||||
Returns None if summarization fails.
|
||||
"""
|
||||
return transcribe_audio(audio_file_path, prompt="Please summarize the audio content:")
|
||||
|
||||
|
||||
def analyze_audio_segment(audio_file_path: str, start_time: str, end_time: str) -> Optional[str]:
|
||||
"""
|
||||
Analyzes a specific segment of an audio file using timestamps.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
start_time (str): Start time in MM:SS format.
|
||||
end_time (str): End time in MM:SS format.
|
||||
|
||||
Returns:
|
||||
str: Analysis of the specified audio segment.
|
||||
Returns None if analysis fails.
|
||||
"""
|
||||
prompt = f"Analyze the audio content from {start_time} to {end_time}."
|
||||
return transcribe_audio(audio_file_path, prompt=prompt)
|
||||
|
||||
|
||||
def transcribe_with_timestamps(audio_file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Transcribes audio with timestamps for each segment.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
|
||||
Returns:
|
||||
str: Transcription with timestamps.
|
||||
Returns None if transcription fails.
|
||||
"""
|
||||
return transcribe_audio(audio_file_path, prompt="Transcribe the audio with timestamps for each segment:")
|
||||
|
||||
|
||||
def count_tokens(audio_file_path: str) -> Optional[int]:
|
||||
"""
|
||||
Counts the number of tokens in an audio file.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): The path to the audio file.
|
||||
|
||||
Returns:
|
||||
int: Number of tokens in the audio file.
|
||||
Returns None if counting fails.
|
||||
"""
|
||||
try:
|
||||
# Load environment variables and configure the Google API
|
||||
load_environment()
|
||||
configure_google_api()
|
||||
|
||||
logger.info(f"Attempting to count tokens in audio file: {audio_file_path}")
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(audio_file_path):
|
||||
error_message = f"FileNotFoundError: The audio file at {audio_file_path} does not exist."
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize a Gemini model
|
||||
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
||||
|
||||
# Upload the audio file
|
||||
try:
|
||||
audio_file = genai.upload_file(audio_file_path)
|
||||
logger.info(f"Audio file uploaded successfully: {audio_file=}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio file: {e}")
|
||||
return None
|
||||
|
||||
# Count tokens
|
||||
try:
|
||||
response = model.count_tokens([audio_file])
|
||||
token_count = response.total_tokens
|
||||
logger.info(f"Token count: {token_count}")
|
||||
return token_count
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting tokens: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_formats() -> List[str]:
|
||||
"""
|
||||
Returns a list of supported audio formats.
|
||||
|
||||
Returns:
|
||||
List[str]: List of supported MIME types.
|
||||
"""
|
||||
return [
|
||||
"audio/wav",
|
||||
"audio/mp3",
|
||||
"audio/aiff",
|
||||
"audio/aac",
|
||||
"audio/ogg",
|
||||
"audio/flac"
|
||||
]
|
||||
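# --- Illustrative helper (not part of the original module) ---------------------------------
# A small convenience wrapper around get_supported_formats(); the extension-to-MIME mapping
# below is an assumption added for demonstration purposes.
def is_supported_audio(audio_file_path: str) -> bool:
    """Return True if the file extension maps to one of the supported MIME types."""
    ext_to_mime = {
        ".wav": "audio/wav",
        ".mp3": "audio/mp3",
        ".aiff": "audio/aiff",
        ".aac": "audio/aac",
        ".ogg": "audio/ogg",
        ".flac": "audio/flac",
    }
    mime = ext_to_mime.get(Path(audio_file_path).suffix.lower())
    return mime in get_supported_formats()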
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example 1: Basic transcription
|
||||
audio_path = "path/to/your/audio.mp3"
|
||||
transcript = transcribe_audio(audio_path)
|
||||
print(f"Transcript: {transcript}")
|
||||
|
||||
# Example 2: Summarization
|
||||
summary = summarize_audio(audio_path)
|
||||
print(f"Summary: {summary}")
|
||||
|
||||
# Example 3: Analyze specific segment
|
||||
segment_analysis = analyze_audio_segment(audio_path, "02:30", "03:29")
|
||||
print(f"Segment Analysis: {segment_analysis}")
|
||||
|
||||
# Example 4: Transcription with timestamps
|
||||
timestamped_transcript = transcribe_with_timestamps(audio_path)
|
||||
print(f"Timestamped Transcript: {timestamped_transcript}")
|
||||
|
||||
# Example 5: Count tokens
|
||||
token_count = count_tokens(audio_path)
|
||||
print(f"Token Count: {token_count}")
|
||||
|
||||
# Example 6: Get supported formats
|
||||
formats = get_supported_formats()
|
||||
print(f"Supported Formats: {formats}")
|
||||
@@ -0,0 +1,218 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile

# moviepy is needed by long_video() below for chunking long audio files
import moviepy.editor as mp
|
||||
|
||||
from pytubefix import YouTube
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
from tqdm import tqdm
|
||||
import streamlit as st
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
from .gemini_audio_text import transcribe_audio
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
|
||||
def progress_function(stream, chunk, bytes_remaining):
|
||||
# Calculate the percentage completion
|
||||
current = ((stream.filesize - bytes_remaining) / stream.filesize)
|
||||
progress_bar.update(current - progress_bar.n) # Update the progress bar
|
||||
|
||||
|
||||
def rename_file_with_underscores(file_path):
|
||||
"""Rename a file by replacing spaces and special characters with underscores.
|
||||
|
||||
Args:
|
||||
file_path (str): The original file path.
|
||||
|
||||
Returns:
|
||||
str: The new file path with underscores.
|
||||
"""
|
||||
# Extract the directory and the filename
|
||||
dir_name, original_filename = os.path.split(file_path)
|
||||
|
||||
# Replace spaces and special characters with underscores in the filename
|
||||
new_filename = re.sub(r'[^\w\-_\.]', '_', original_filename)
|
||||
|
||||
# Create the new file path
|
||||
new_file_path = os.path.join(dir_name, new_filename)
|
||||
|
||||
# Rename the file
|
||||
os.rename(file_path, new_file_path)
|
||||
|
||||
return new_file_path
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def speech_to_text(video_url):
|
||||
"""
|
||||
Transcribes speech to text from a YouTube video URL using OpenAI's Whisper model.
|
||||
|
||||
Args:
|
||||
video_url (str): URL of the YouTube video to transcribe.
|
||||
Note: the downloaded audio is saved to the directory given by the CONTENT_SAVE_DIR environment variable.
|
||||
|
||||
Returns:
|
||||
str: The transcribed text from the video.
|
||||
|
||||
Raises:
|
||||
SystemExit: If a critical error occurs that prevents successful execution.
|
||||
"""
|
||||
output_path = os.getenv("CONTENT_SAVE_DIR")
|
||||
yt = None
|
||||
audio_file = None
|
||||
with st.status("Started Writing..", expanded=False) as status:
|
||||
try:
|
||||
if video_url.startswith("https://www.youtube.com/") or video_url.startswith("http://www.youtube.com/"):
|
||||
logger.info(f"Accessing YouTube URL: {video_url}")
|
||||
status.update(label=f"Accessing YouTube URL: {video_url}")
|
||||
try:
|
||||
vid_id = video_url.split("=")[1]
|
||||
yt = YouTube(video_url, on_progress_callback=progress_function)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get pytube stream object: {err}")
|
||||
st.stop()
|
||||
|
||||
logger.info(f"Fetching the highest quality audio stream:{yt.title}")
|
||||
status.update(label=f"Fetching the highest quality audio stream: {yt.title}")
|
||||
try:
|
||||
audio_stream = yt.streams.filter(only_audio=True).first()
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to Download Youtube Audio: {err}")
|
||||
st.stop()
|
||||
|
||||
if audio_stream is None:
|
||||
logger.warning("No audio stream found for this video.")
|
||||
st.warning("No audio stream found for this video.")
|
||||
st.stop()
|
||||
|
||||
logger.info(f"Downloading audio for: {yt.title}")
|
||||
status.update(label=f"Downloading audio for: {yt.title}")
|
||||
global progress_bar
|
||||
progress_bar = tqdm(total=1.0, unit='iB', unit_scale=True, desc=yt.title)
|
||||
try:
|
||||
audio_filename = re.sub(r'[^\w\-_\.]', '_', yt.title) + '.mp4'
|
||||
audio_file = audio_stream.download(
|
||||
output_path=os.getenv("CONTENT_SAVE_DIR"),
|
||||
filename=audio_filename)
|
||||
#audio_file = rename_file_with_underscores(audio_file)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to download audio file: {audio_file}")
|
||||
|
||||
progress_bar.close()
|
||||
logger.info(f"Audio downloaded: {yt.title} to {audio_file}")
|
||||
status.update(label=f"Audio downloaded: {yt.title} to {output_path}")
|
||||
# Audio filepath from local directory.
|
||||
elif os.path.exists(video_url):
|
||||
audio_file = video_url
|
||||
|
||||
# Checking file size
|
||||
max_file_size = 24 * 1024 * 1024 # 24MB
|
||||
file_size = os.path.getsize(audio_file)
|
||||
# Convert file size to MB for logging
|
||||
file_size_MB = file_size / (1024 * 1024) # Convert bytes to MB
|
||||
|
||||
logger.info(f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
|
||||
status.update(label=f"Downloaded Audio Size is: {file_size_MB:.2f} MB")
|
||||
|
||||
if file_size > max_file_size:
    logger.error("File size exceeds 24MB limit.")
    # FIXME: We can chunk hour long videos, the code is not tested.
    #long_video(audio_file)
    st.error("Audio file size limit exceeded. Please file an issue on the ALwrity GitHub.")
    sys.exit("File size limit exceeded.")
|
||||
|
||||
try:
|
||||
print(f"Audio File: {audio_file}")
|
||||
transcript = transcribe_audio(audio_file)
|
||||
print(f"\n\n\n--- Tracribe: {transcript} ----\n\n\n")
|
||||
exit(1)
|
||||
status.update(label=f"Initializing OpenAI client for transcription: {audio_file}")
|
||||
logger.info(f"Initializing OpenAI client for transcription: {audio_file}")
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key not found. Please configure it in the onboarding process.")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
logger.info("Transcribing using OpenAI's Whisper model.")
|
||||
transcript = client.audio.transcriptions.create(
|
||||
model="whisper-1",
|
||||
file=open(audio_file, "rb"),
|
||||
response_format="text"
|
||||
)
|
||||
logger.info(f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
|
||||
status.update(label=f"\nYouTube video transcription:\n{yt.title}\n{transcript}\n")
|
||||
return transcript, yt.title
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed in Whisper transcription: {e}")
|
||||
st.warning(f"Failed in Openai Whisper transcription: {e}")
|
||||
transcript = transcribe_audio(audio_file)
|
||||
print(f"\n\n\n--- Tracribe: {transcript} ----\n\n\n")
|
||||
return transcript, yt.title
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred during YouTube video processing: {e}")
|
||||
|
||||
finally:
|
||||
try:
|
||||
if audio_file and os.path.exists(audio_file):
|
||||
os.remove(audio_file)
|
||||
logger.info("Temporary audio file removed.")
|
||||
except PermissionError:
|
||||
st.error(f"Permission error: Cannot remove '{audio_file}'. Please make sure of necessary permissions.")
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred removing audio file: {e}")
|
||||
|
||||
|
||||
def long_video(temp_file_name):
|
||||
"""
|
||||
Transcribes a YouTube video using OpenAI's Whisper API by processing the video in chunks.
|
||||
|
||||
This function handles videos longer than the context limit of the Whisper API by dividing the video into
|
||||
10-minute segments, transcribing each segment individually, and then combining the results.
|
||||
|
||||
Key Changes and Notes:
|
||||
1. Video Splitting: Splits the audio into 10-minute chunks using the moviepy library.
|
||||
2. Chunk Transcription: Each audio chunk is transcribed separately and the results are concatenated.
|
||||
3. Temporary Files for Chunks: Uses temporary files for each audio chunk for transcription.
|
||||
4. Error Handling: Exception handling is included to capture and return any errors during the process.
|
||||
5. Logging: Process steps are logged for debugging and monitoring.
|
||||
6. Cleaning Up: Removes temporary files for both the entire video and individual audio chunks after processing.
|
||||
|
||||
Args:
|
||||
temp_file_name (str): Path to the downloaded audio file to be transcribed in chunks.
|
||||
"""
|
||||
# Extract audio and split into chunks
|
||||
logger.info(f"Processing the YT video: {temp_file_name}")
|
||||
full_audio = mp.AudioFileClip(temp_file_name)
|
||||
duration = full_audio.duration
|
||||
chunk_length = 600 # 10 minutes in seconds
|
||||
chunks = [full_audio.subclip(start, min(start + chunk_length, duration)) for start in range(0, int(duration), chunk_length)]
|
||||
|
||||
combined_transcript = ""
|
||||
for i, chunk in enumerate(chunks):
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio_chunk_file:
|
||||
chunk.write_audiofile(audio_chunk_file.name, codec="mp3")
|
||||
with open(audio_chunk_file.name, "rb", encoding="utf-8") as audio_file:
|
||||
# Transcribe each chunk using OpenAI's Whisper API
|
||||
app.logger.info(f"Transcribing chunk {i+1}/{len(chunks)}")
|
||||
transcript = openai.Audio.transcribe("whisper-1", audio_file)
|
||||
combined_transcript += transcript['text'] + "\n\n"
|
||||
|
||||
# Remove the chunk audio file
|
||||
os.remove(audio_chunk_file.name)
|
||||
|
||||
backend/services/llm_providers/gemini_grounded_provider.py (new file, 886 lines)
@@ -0,0 +1,886 @@
|
||||
"""
|
||||
Enhanced Gemini Provider for Grounded Content Generation
|
||||
|
||||
This provider uses native Google Search grounding to generate content that is
|
||||
factually grounded in current web sources, with automatic citation generation.
|
||||
Based on Google AI's official grounding documentation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
logger.warn("Google GenAI not available. Install with: pip install google-genai")
|
||||
|
||||
|
||||
class GeminiGroundedProvider:
|
||||
"""
|
||||
Enhanced Gemini provider for grounded content generation with native Google Search.
|
||||
|
||||
This provider uses the official Google Search grounding tool to generate content
|
||||
that is factually grounded in current web sources, with automatic citation generation.
|
||||
|
||||
Based on: https://ai.google.dev/gemini-api/docs/google-search
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Gemini Grounded Provider."""
|
||||
if not GOOGLE_GENAI_AVAILABLE:
|
||||
raise ImportError("Google GenAI library not available. Install with: pip install google-genai")
|
||||
|
||||
self.api_key = os.getenv('GEMINI_API_KEY')
|
||||
if not self.api_key:
|
||||
raise ValueError("GEMINI_API_KEY environment variable is required")
|
||||
|
||||
# Initialize the Gemini client with timeout configuration
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.timeout = 60 # 60 second timeout for API calls (increased for research)
|
||||
self._cache: Dict[str, Any] = {}
|
||||
logger.info("✅ Gemini Grounded Provider initialized with native Google Search grounding")
|
||||
|
||||
async def generate_grounded_content(
|
||||
self,
|
||||
prompt: str,
|
||||
content_type: str = "linkedin_post",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
urls: Optional[List[str]] = None,
|
||||
mode: str = "polished",
|
||||
user_id: Optional[str] = None,
|
||||
validate_subsequent_operations: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate grounded content using native Google Search grounding.
|
||||
|
||||
Args:
|
||||
prompt: The content generation prompt
|
||||
content_type: Type of content to generate
|
||||
temperature: Creativity level (0.0-1.0)
|
||||
max_tokens: Maximum tokens in response
|
||||
urls: Optional list of URLs for URL Context tool
|
||||
mode: Content mode ("draft" or "polished")
|
||||
user_id: User ID for subscription checking (required if validate_subsequent_operations=True)
|
||||
validate_subsequent_operations: If True, validates Google Grounding + 3 LLM calls for research workflow
|
||||
|
||||
Returns:
|
||||
Dictionary containing generated content and grounding metadata
|
||||
"""
|
||||
try:
|
||||
# PRE-FLIGHT VALIDATION: If this is part of a research workflow, validate ALL operations
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if validate_subsequent_operations:
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required when validate_subsequent_operations=True")
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_research_operations
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
|
||||
# Validate ALL research operations before making ANY API calls
|
||||
# This prevents wasteful external API calls if subsequent LLM calls would fail
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_research_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
gpt_provider=gpt_provider
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Gemini Grounded] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Gemini Grounded] ✅ Pre-flight validation passed - proceeding with API call")
|
||||
logger.info(f"[Gemini Grounded] Generating grounded content for {content_type} using native Google Search")
|
||||
|
||||
# Build the grounded prompt
|
||||
grounded_prompt = self._build_grounded_prompt(prompt, content_type)
|
||||
|
||||
# Configure tools: Google Search and optional URL Context
|
||||
tools: List[Any] = [
|
||||
types.Tool(google_search=types.GoogleSearch())
|
||||
]
|
||||
if urls:
|
||||
try:
|
||||
# URL Context tool (ai.google.dev URL Context)
|
||||
tools.append(types.Tool(url_context=types.UrlContext()))
|
||||
logger.info(f"Enabled URL Context tool for {len(urls)} URLs")
|
||||
except Exception as tool_err:
|
||||
logger.warning(f"URL Context tool not available in SDK version: {tool_err}")
|
||||
|
||||
# Apply mode presets (Draft vs Polished)
|
||||
# Use Gemini 2.0 Flash for better content generation with grounding
|
||||
model_id = "gemini-2.0-flash"
|
||||
if mode == "draft":
|
||||
model_id = "gemini-2.0-flash"
|
||||
temperature = min(1.0, max(0.0, temperature))
|
||||
else:
|
||||
model_id = "gemini-2.0-flash"
|
||||
|
||||
# Configure generation settings
|
||||
config = types.GenerateContentConfig(
|
||||
tools=tools,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
# Make the request with native grounding and timeout
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
try:
|
||||
# Cache first
|
||||
cache_key = self._make_cache_key(model_id, grounded_prompt, urls)
|
||||
if cache_key in self._cache:
|
||||
logger.info("Cache hit for grounded content request")
|
||||
response = self._cache[cache_key]
|
||||
else:
|
||||
# Run the synchronous generate_content in a thread pool to make it awaitable
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model=model_id,
|
||||
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self._cache[cache_key] = response
|
||||
except asyncio.TimeoutError:
|
||||
from services.blog_writer.exceptions import APITimeoutException
|
||||
raise APITimeoutException(
|
||||
f"Gemini API request timed out after {self.timeout} seconds",
|
||||
timeout_seconds=self.timeout,
|
||||
context={"content_type": content_type, "model_id": model_id}
|
||||
)
|
||||
except Exception as api_error:
|
||||
# Handle specific Google API errors with enhanced retry logic
|
||||
error_str = str(api_error)
|
||||
|
||||
# Non-retryable errors
|
||||
if "401" in error_str or "403" in error_str:
|
||||
from services.blog_writer.exceptions import ValidationException
|
||||
raise ValidationException(
|
||||
"Authentication failed. Please check your API credentials.",
|
||||
field="api_key",
|
||||
context={"error": error_str, "content_type": content_type}
|
||||
)
|
||||
elif "400" in error_str:
|
||||
from services.blog_writer.exceptions import ValidationException
|
||||
raise ValidationException(
|
||||
"Invalid request. Please check your input parameters.",
|
||||
field="request",
|
||||
context={"error": error_str, "content_type": content_type}
|
||||
)
|
||||
|
||||
# Retryable errors - use enhanced retry logic
|
||||
from services.blog_writer.retry_utils import retry_with_backoff, RESEARCH_RETRY_CONFIG
|
||||
|
||||
try:
|
||||
response = await retry_with_backoff(
|
||||
lambda: self._make_api_request_with_model(grounded_prompt, config, model_id, urls),
|
||||
config=RESEARCH_RETRY_CONFIG,
|
||||
operation_name=f"gemini_grounded_{content_type}",
|
||||
context={"content_type": content_type, "model_id": model_id}
|
||||
)
|
||||
except Exception as retry_error:
|
||||
# If retry also failed, raise the original error with context
|
||||
from services.blog_writer.exceptions import ResearchFailedException
|
||||
raise ResearchFailedException(
|
||||
f"Google AI service error after retries: {error_str}",
|
||||
context={"original_error": error_str, "retry_error": str(retry_error), "content_type": content_type}
|
||||
)
|
||||
|
||||
# Process the grounded response
|
||||
result = self._process_grounded_response(response, content_type)
|
||||
# Attach URL Context metadata if present
|
||||
try:
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
candidate0 = response.candidates[0]
|
||||
if hasattr(candidate0, 'url_context_metadata') and candidate0.url_context_metadata:
|
||||
result['url_context_metadata'] = candidate0.url_context_metadata
|
||||
logger.info("Attached url_context_metadata to result")
|
||||
except Exception as meta_err:
|
||||
logger.warning(f"Unable to attach url_context_metadata: {meta_err}")
|
||||
|
||||
logger.info(f"✅ Grounded content generated successfully with {len(result.get('sources', []))} sources")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Log error without causing secondary exceptions
|
||||
try:
|
||||
logger.error(f"❌ Error generating grounded content: {str(e)}")
|
||||
except Exception:
|
||||
# Fallback to print if logging fails
|
||||
print(f"Error generating grounded content: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _make_api_request(self, grounded_prompt: str, config: Any):
|
||||
"""Make the actual API request to Gemini."""
|
||||
import concurrent.futures
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
async def _make_api_request_with_model(self, grounded_prompt: str, config: Any, model_id: str, urls: Optional[List[str]] = None):
|
||||
"""Make the API request with explicit model id and optional URL injection."""
|
||||
logger.info(f"🔍 DEBUG: Making API request with model: {model_id}")
|
||||
logger.info(f"🔍 DEBUG: Prompt length: {len(grounded_prompt)} characters")
|
||||
logger.info(f"🔍 DEBUG: Prompt preview (first 300 chars): {grounded_prompt[:300]}...")
|
||||
|
||||
import concurrent.futures
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model=model_id,
|
||||
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self._cache[self._make_cache_key(model_id, grounded_prompt, urls)] = resp
|
||||
return resp
|
||||
|
||||
def _inject_urls_into_prompt(self, prompt: str, urls: Optional[List[str]]) -> str:
|
||||
"""Append URLs to the prompt for URL Context tool to pick up (as per docs)."""
|
||||
if not urls:
|
||||
return prompt
|
||||
safe_urls = [u for u in urls if isinstance(u, str) and u.startswith("http")]
|
||||
if not safe_urls:
|
||||
return prompt
|
||||
urls_block = "\n".join(safe_urls[:20])
|
||||
return f"{prompt}\n\nSOURCE URLS (use url_context to retrieve content):\n{urls_block}"
|
||||
|
||||
def _make_cache_key(self, model_id: str, prompt: str, urls: Optional[List[str]]) -> str:
|
||||
import hashlib
|
||||
u = "|".join((urls or [])[:20])
|
||||
base = f"{model_id}|{prompt}|{u}"
|
||||
return hashlib.sha256(base.encode("utf-8")).hexdigest()
|
||||
|
||||
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
|
||||
"""Retry a function with exponential backoff."""
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await func()
|
||||
except Exception as e:
|
||||
if attempt == max_retries:
|
||||
# Last attempt failed, raise the error
|
||||
raise e
|
||||
|
||||
# Calculate delay with exponential backoff
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay} seconds: {str(e)}")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
    def _build_grounded_prompt(self, prompt: str, content_type: str) -> str:
        """
        Build a prompt optimized for grounded content generation.

        Args:
            prompt: Base prompt
            content_type: Type of content being generated

        Returns:
            Enhanced prompt for grounded generation
        """
        content_type_instructions = {
            "linkedin_post": "You are an expert LinkedIn content strategist. Generate a highly engaging, professional LinkedIn post that drives meaningful engagement, establishes thought leadership, and includes compelling hooks, actionable insights, and strategic hashtags. Every element should be optimized for maximum engagement and shareability.",
            "linkedin_article": "You are a senior content strategist and industry thought leader. Generate a comprehensive, SEO-optimized LinkedIn article with compelling headlines, structured content, data-driven insights, and practical takeaways. Include proper source citations and engagement elements throughout.",
            "linkedin_carousel": "You are a visual content strategist specializing in LinkedIn carousels. Generate compelling, story-driven carousel content with clear visual hierarchy, actionable insights per slide, and strategic engagement elements. Each slide should provide immediate value while building anticipation for the next.",
            "linkedin_video_script": "You are a video content strategist and LinkedIn engagement expert. Generate a compelling video script optimized for LinkedIn's algorithm with attention-grabbing hooks, strategic timing, and engagement-driven content. Include specific visual and audio recommendations for maximum impact.",
            "linkedin_comment_response": "You are a LinkedIn engagement specialist and industry expert. Generate thoughtful, value-adding comment responses that encourage further discussion, demonstrate expertise, and build meaningful professional relationships. Focus on genuine engagement over generic responses."
        }

        instruction = content_type_instructions.get(content_type, "Generate professional content with factual accuracy.")

        grounded_prompt = f"""
{instruction}

CRITICAL REQUIREMENTS FOR LINKEDIN CONTENT:
- Use ONLY current, factual information from reliable sources (2024-2025)
- Cite specific sources for ALL claims, statistics, and recent developments
- Ensure content is optimized for LinkedIn's algorithm and engagement patterns
- Include strategic hashtags and engagement elements throughout

User Request: {prompt}

CONTENT QUALITY STANDARDS:
- All factual claims must be backed by current, authoritative sources
- Use professional yet conversational language that encourages engagement
- Include relevant industry insights, trends, and data points
- Make content highly shareable with clear value proposition
- Optimize for LinkedIn's professional audience and engagement metrics

ENGAGEMENT OPTIMIZATION:
- Include thought-provoking questions and calls-to-action
- Use storytelling elements and real-world examples
- Ensure content provides immediate, actionable value
- Optimize for comments, shares, and professional networking
- Include industry-specific terminology and insights

REMEMBER: This content will be displayed on LinkedIn with full source attribution and grounding data. Every claim must be verifiable, and the content should position the author as a thought leader in their industry.
"""

        return grounded_prompt.strip()

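    # Illustrative sketch (not part of the original module; the call site and
    # variable names are assumptions): how the prompt builder is typically
    # consumed by the grounded generation path.
    #
    #   grounded_prompt = self._build_grounded_prompt(
    #       "Summarize this week's AI agent announcements for a developer audience",
    #       content_type="linkedin_post",
    #   )
    #   # grounded_prompt now contains the persona instruction, the grounding
    #   # requirements, and the original user request, ready to be sent to the
    #   # grounded Gemini call together with the Google Search tool config.
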
def _process_grounded_response(self, response, content_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Process the Gemini response with grounding metadata.
|
||||
|
||||
Args:
|
||||
response: Gemini API response
|
||||
content_type: Type of content generated
|
||||
|
||||
Returns:
|
||||
Processed content with sources and citations
|
||||
"""
|
||||
try:
|
||||
# Debug: Log response structure
|
||||
logger.info(f"🔍 DEBUG: Response type: {type(response)}")
|
||||
logger.info(f"🔍 DEBUG: Response has 'text': {hasattr(response, 'text')}")
|
||||
logger.info(f"🔍 DEBUG: Response has 'candidates': {hasattr(response, 'candidates')}")
|
||||
logger.info(f"🔍 DEBUG: Response has 'grounding_metadata': {hasattr(response, 'grounding_metadata')}")
|
||||
if hasattr(response, 'grounding_metadata'):
|
||||
logger.info(f"🔍 DEBUG: Grounding metadata: {response.grounding_metadata}")
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
logger.info(f"🔍 DEBUG: Number of candidates: {len(response.candidates)}")
|
||||
candidate = response.candidates[0]
|
||||
logger.info(f"🔍 DEBUG: Candidate type: {type(candidate)}")
|
||||
logger.info(f"🔍 DEBUG: Candidate has 'content': {hasattr(candidate, 'content')}")
|
||||
if hasattr(candidate, 'content') and candidate.content:
|
||||
logger.info(f"🔍 DEBUG: Content type: {type(candidate.content)}")
|
||||
# Check if content is a list or single object
|
||||
if hasattr(candidate.content, '__iter__') and not isinstance(candidate.content, str):
|
||||
try:
|
||||
content_length = len(candidate.content) if candidate.content else 0
|
||||
logger.info(f"🔍 DEBUG: Content is iterable, length: {content_length}")
|
||||
except TypeError:
|
||||
logger.info(f"🔍 DEBUG: Content is iterable but has no len() - treating as single object")
|
||||
for i, part in enumerate(candidate.content):
|
||||
logger.info(f"🔍 DEBUG: Part {i} type: {type(part)}")
|
||||
logger.info(f"🔍 DEBUG: Part {i} has 'text': {hasattr(part, 'text')}")
|
||||
if hasattr(part, 'text'):
|
||||
logger.info(f"🔍 DEBUG: Part {i} text length: {len(part.text) if part.text else 0}")
|
||||
else:
|
||||
logger.info(f"🔍 DEBUG: Content is single object, has 'text': {hasattr(candidate.content, 'text')}")
|
||||
if hasattr(candidate.content, 'text'):
|
||||
logger.info(f"🔍 DEBUG: Content text length: {len(candidate.content.text) if candidate.content.text else 0}")
|
||||
|
||||
# Extract the main content - prioritize response.text as it's more reliable
|
||||
content = ""
|
||||
if hasattr(response, 'text'):
|
||||
logger.info(f"🔍 DEBUG: response.text exists, value: '{response.text}', type: {type(response.text)}")
|
||||
if response.text:
|
||||
content = response.text
|
||||
logger.info(f"🔍 DEBUG: Using response.text, length: {len(content)}")
|
||||
else:
|
||||
logger.info(f"🔍 DEBUG: response.text is empty or None")
|
||||
elif hasattr(response, 'candidates') and response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
if hasattr(candidate, 'content') and candidate.content:
|
||||
# Handle both single Content object and list of parts
|
||||
if hasattr(candidate.content, '__iter__') and not isinstance(candidate.content, str):
|
||||
# Content is a list of parts
|
||||
text_parts = []
|
||||
for part in candidate.content:
|
||||
if hasattr(part, 'text'):
|
||||
text_parts.append(part.text)
|
||||
content = " ".join(text_parts)
|
||||
logger.info(f"🔍 DEBUG: Using candidate.content (list), extracted {len(text_parts)} parts, total length: {len(content)}")
|
||||
else:
|
||||
# Content is a single object
|
||||
if hasattr(candidate.content, 'text'):
|
||||
content = candidate.content.text
|
||||
logger.info(f"🔍 DEBUG: Using candidate.content (single), text length: {len(content)}")
|
||||
else:
|
||||
logger.warning("🔍 DEBUG: candidate.content has no 'text' attribute")
|
||||
|
||||
logger.info(f"Extracted content length: {len(content) if content else 0}")
|
||||
if not content:
|
||||
logger.warning("⚠️ No content extracted from Gemini response - using fallback content")
|
||||
logger.warning("⚠️ This indicates Google Search grounding is not working properly")
|
||||
content = "Generated content about the requested topic."
|
||||
|
||||
# Initialize result structure
|
||||
result = {
|
||||
'content': content,
|
||||
'sources': [],
|
||||
'citations': [],
|
||||
'search_queries': [],
|
||||
'grounding_metadata': {},
|
||||
'content_type': content_type,
|
||||
'generation_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Debug: Log response structure
|
||||
logger.info(f"Response type: {type(response)}")
|
||||
logger.info(f"Response attributes: {dir(response)}")
|
||||
|
||||
# Extract grounding metadata if available
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
logger.info(f"Candidate attributes: {dir(candidate)}")
|
||||
|
||||
if hasattr(candidate, 'grounding_metadata') and candidate.grounding_metadata:
|
||||
grounding_metadata = candidate.grounding_metadata
|
||||
result['grounding_metadata'] = grounding_metadata
|
||||
logger.info(f"Grounding metadata attributes: {dir(grounding_metadata)}")
|
||||
logger.info(f"Grounding metadata type: {type(grounding_metadata)}")
|
||||
logger.info(f"Grounding metadata value: {grounding_metadata}")
|
||||
|
||||
# Log all available attributes and their values
|
||||
for attr in dir(grounding_metadata):
|
||||
if not attr.startswith('_'):
|
||||
try:
|
||||
value = getattr(grounding_metadata, attr)
|
||||
logger.info(f" {attr}: {type(value)} = {value}")
|
||||
except Exception as e:
|
||||
logger.warning(f" {attr}: Error accessing - {e}")
|
||||
|
||||
# Extract search queries
|
||||
if hasattr(grounding_metadata, 'web_search_queries'):
|
||||
result['search_queries'] = grounding_metadata.web_search_queries
|
||||
logger.info(f"Search queries: {grounding_metadata.web_search_queries}")
|
||||
|
||||
# Extract sources from grounding chunks
|
||||
sources = [] # Initialize sources list
|
||||
if hasattr(grounding_metadata, 'grounding_chunks') and grounding_metadata.grounding_chunks:
|
||||
for i, chunk in enumerate(grounding_metadata.grounding_chunks):
|
||||
logger.info(f"Chunk {i} attributes: {dir(chunk)}")
|
||||
if hasattr(chunk, 'web'):
|
||||
source = {
|
||||
'index': i,
|
||||
'title': getattr(chunk.web, 'title', f'Source {i+1}'),
|
||||
'url': getattr(chunk.web, 'uri', ''),
|
||||
'type': 'web'
|
||||
}
|
||||
sources.append(source)
|
||||
logger.info(f"Extracted {len(sources)} sources from grounding chunks")
|
||||
else:
|
||||
logger.warning("⚠️ No grounding chunks found - this is normal for some queries")
|
||||
logger.info(f"Grounding metadata available fields: {[attr for attr in dir(grounding_metadata) if not attr.startswith('_')]}")
|
||||
|
||||
# Check if we have search queries - this means Google Search was triggered
|
||||
if hasattr(grounding_metadata, 'web_search_queries') and grounding_metadata.web_search_queries:
|
||||
logger.info(f"✅ Google Search was triggered with {len(grounding_metadata.web_search_queries)} queries")
|
||||
# Create sources based on search queries
|
||||
for i, query in enumerate(grounding_metadata.web_search_queries[:5]): # Limit to 5 sources
|
||||
source = {
|
||||
'index': i,
|
||||
'title': f"Search: {query}",
|
||||
'url': f"https://www.google.com/search?q={query.replace(' ', '+')}",
|
||||
'type': 'search_query',
|
||||
'query': query
|
||||
}
|
||||
sources.append(source)
|
||||
logger.info(f"Created {len(sources)} sources from search queries")
|
||||
else:
|
||||
logger.warning("⚠️ No search queries found either - grounding may not have been triggered")
|
||||
|
||||
result['sources'] = sources
|
||||
|
||||
# Extract citations from grounding supports
|
||||
if hasattr(grounding_metadata, 'grounding_supports') and grounding_metadata.grounding_supports:
|
||||
citations = []
|
||||
for support in grounding_metadata.grounding_supports:
|
||||
if hasattr(support, 'segment') and hasattr(support, 'grounding_chunk_indices'):
|
||||
citation = {
|
||||
'type': 'inline',
|
||||
'start_index': getattr(support.segment, 'start_index', 0),
|
||||
'end_index': getattr(support.segment, 'end_index', 0),
|
||||
'text': getattr(support.segment, 'text', ''),
|
||||
'source_indices': support.grounding_chunk_indices,
|
||||
'reference': f"Source {support.grounding_chunk_indices[0] + 1}" if support.grounding_chunk_indices else "Unknown"
|
||||
}
|
||||
citations.append(citation)
|
||||
result['citations'] = citations
|
||||
logger.info(f"Extracted {len(citations)} citations")
|
||||
else:
|
||||
logger.warning("⚠️ No grounding supports found - this is normal when no web sources are retrieved")
|
||||
# Create basic citations from the content if we have sources
|
||||
if sources:
|
||||
citations = []
|
||||
for i, source in enumerate(sources[:3]): # Limit to 3 citations
|
||||
citation = {
|
||||
'type': 'reference',
|
||||
'start_index': 0,
|
||||
'end_index': 0,
|
||||
'text': f"Source {i+1}",
|
||||
'source_indices': [i],
|
||||
'reference': f"Source {i+1}",
|
||||
'source': source
|
||||
}
|
||||
citations.append(citation)
|
||||
result['citations'] = citations
|
||||
logger.info(f"Created {len(citations)} basic citations from sources")
|
||||
else:
|
||||
result['citations'] = []
|
||||
logger.info("No citations created - no sources available")
|
||||
|
||||
# Extract search entry point for UI display
|
||||
if hasattr(grounding_metadata, 'search_entry_point') and grounding_metadata.search_entry_point:
|
||||
if hasattr(grounding_metadata.search_entry_point, 'rendered_content'):
|
||||
result['search_widget'] = grounding_metadata.search_entry_point.rendered_content
|
||||
logger.info("✅ Extracted search widget HTML for UI display")
|
||||
|
||||
# Extract search queries for reference
|
||||
if hasattr(grounding_metadata, 'web_search_queries') and grounding_metadata.web_search_queries:
|
||||
result['search_queries'] = grounding_metadata.web_search_queries
|
||||
logger.info(f"✅ Extracted {len(grounding_metadata.web_search_queries)} search queries")
|
||||
|
||||
logger.info(f"✅ Successfully extracted {len(result['sources'])} sources and {len(result['citations'])} citations from grounding metadata")
|
||||
logger.info(f"Sources: {result['sources']}")
|
||||
logger.info(f"Citations: {result['citations']}")
|
||||
else:
|
||||
logger.error("❌ CRITICAL: No grounding metadata found in response")
|
||||
logger.error(f"Response structure: {dir(response)}")
|
||||
logger.error(f"First candidate structure: {dir(response.candidates[0]) if response.candidates else 'No candidates'}")
|
||||
raise ValueError("No grounding metadata found - grounding is not working properly")
|
||||
else:
|
||||
logger.warning("⚠️ No candidates found in response. Returning content without sources.")
|
||||
|
||||
# Add content-specific processing
|
||||
if content_type == "linkedin_post":
|
||||
result.update(self._process_post_content(content))
|
||||
elif content_type == "linkedin_article":
|
||||
result.update(self._process_article_content(content))
|
||||
elif content_type == "linkedin_carousel":
|
||||
result.update(self._process_carousel_content(content))
|
||||
elif content_type == "linkedin_video_script":
|
||||
result.update(self._process_video_script_content(content))
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ CRITICAL: Error processing grounded response: {str(e)}")
|
||||
logger.error(f"Exception type: {type(e)}")
|
||||
logger.error(f"Exception details: {e}")
|
||||
raise ValueError(f"Failed to process grounded response: {str(e)}")
|
||||
|
||||
    def _process_post_content(self, content: str) -> Dict[str, Any]:
        """Process LinkedIn post content for hashtags and engagement elements."""
        try:
            # Handle None content
            if content is None:
                content = ""
                logger.warning("Content is None, using empty string")

            # Extract hashtags
            hashtags = re.findall(r'#\w+', content)

            # Generate call-to-action if not present
            cta_patterns = [
                r'What do you think\?',
                r'Share your thoughts',
                r'Comment below',
                r'What\'s your experience\?',
                r'Let me know in the comments'
            ]

            has_cta = any(re.search(pattern, content, re.IGNORECASE) for pattern in cta_patterns)
            call_to_action = None
            if not has_cta:
                call_to_action = "What are your thoughts on this? Share in the comments!"

            return {
                'hashtags': [{'hashtag': tag, 'category': 'general', 'popularity_score': 0.8} for tag in hashtags],
                'call_to_action': call_to_action,
                'engagement_prediction': {
                    'estimated_likes': max(50, len(content) // 10),
                    'estimated_comments': max(5, len(content) // 100)
                }
            }
        except Exception as e:
            logger.error(f"Error processing post content: {str(e)}")
            return {}

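    # Illustrative example (assumed input, not from the original source): for a
    # 320-character post containing "#AI" and "#SaaS" and no closing question,
    # _process_post_content returns roughly:
    #
    #   {
    #       'hashtags': [{'hashtag': '#AI', ...}, {'hashtag': '#SaaS', ...}],
    #       'call_to_action': "What are your thoughts on this? Share in the comments!",
    #       'engagement_prediction': {'estimated_likes': 50, 'estimated_comments': 5}
    #   }
    #
    # because max(50, 320 // 10) == 50 and max(5, 320 // 100) == 5.
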
    def _process_article_content(self, content: str) -> Dict[str, Any]:
        """Process LinkedIn article content for structure and SEO."""
        try:
            # Extract title (first line or first sentence)
            lines = content.split('\n')
            title = lines[0].strip() if lines else "Article Title"

            # Estimate word count
            word_count = len(content.split())

            # Generate sections based on content structure
            sections = []
            current_section = ""

            for line in lines:
                if line.strip().startswith('#'):
                    if current_section:
                        sections.append({'title': 'Section', 'content': current_section.strip()})
                    current_section = ""
                else:
                    current_section += line + "\n"

            if current_section:
                sections.append({'title': 'Content', 'content': current_section.strip()})

            return {
                'title': title,
                'word_count': word_count,
                'sections': sections,
                'reading_time': max(1, word_count // 200),  # 200 words per minute
                'seo_metadata': {
                    'meta_description': content[:160] + "..." if len(content) > 160 else content,
                    'keywords': self._extract_keywords(content)
                }
            }
        except Exception as e:
            logger.error(f"Error processing article content: {str(e)}")
            return {}

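    # Worked example (assumed numbers, for illustration only): a 1,450-word article
    # gets reading_time = max(1, 1450 // 200) = 7 minutes, its meta_description is
    # the first 160 characters followed by "...", and any Markdown heading line
    # starting with '#' acts as a section break for the sections list.
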
    def _process_carousel_content(self, content: str) -> Dict[str, Any]:
        """Process LinkedIn carousel content for slide structure."""
        try:
            # Split content into slides (basic implementation)
            slides = []
            content_parts = content.split('\n\n')

            for i, part in enumerate(content_parts[:10]):  # Max 10 slides
                if part.strip():
                    slides.append({
                        'slide_number': i + 1,
                        'title': f"Slide {i + 1}",
                        'content': part.strip(),
                        'visual_elements': [],
                        'design_notes': None
                    })

            return {
                'title': f"Carousel on {content[:50]}...",
                'slides': slides,
                'design_guidelines': {
                    'color_scheme': 'professional',
                    'typography': 'clean',
                    'layout': 'minimal'
                }
            }
        except Exception as e:
            logger.error(f"Error processing carousel content: {str(e)}")
            return {}

    def _process_video_script_content(self, content: str) -> Dict[str, Any]:
        """Process LinkedIn video script content for structure."""
        try:
            # Basic video script processing
            lines = content.split('\n')
            hook = ""
            main_content = []
            conclusion = ""

            # Extract hook (first few lines)
            hook_lines = []
            for line in lines[:3]:
                if line.strip() and not line.strip().startswith('#'):
                    hook_lines.append(line.strip())
                    if len(' '.join(hook_lines)) > 100:
                        break
            hook = ' '.join(hook_lines)

            # Extract conclusion (last few lines)
            conclusion_lines = []
            for line in lines[-3:]:
                if line.strip() and not line.strip().startswith('#'):
                    conclusion_lines.insert(0, line.strip())
                    if len(' '.join(conclusion_lines)) > 100:
                        break
            conclusion = ' '.join(conclusion_lines)

            # Main content (everything in between)
            main_content_text = content[len(hook):len(content) - len(conclusion)].strip()

            return {
                'hook': hook,
                'main_content': [{
                    'scene_number': 1,
                    'content': main_content_text,
                    'duration': 60,
                    'visual_notes': 'Professional presentation style'
                }],
                'conclusion': conclusion,
                'thumbnail_suggestions': ['Professional thumbnail', 'Industry-focused image'],
                'video_description': f"Professional insights on {content[:100]}..."
            }
        except Exception as e:
            logger.error(f"Error processing video script content: {str(e)}")
            return {}

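    # Illustrative note (assumption, not from the original source): for a script
    # whose first three non-heading lines form the hook and whose last three form
    # the closing call-to-action, everything in between becomes a single scene with
    # a default duration of 60 seconds. Scripts shorter than about six lines will
    # produce overlapping hook/conclusion text, so callers may want to guard for that.
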
    def _extract_keywords(self, content: str) -> List[str]:
        """Extract relevant keywords from content."""
        try:
            # Simple keyword extraction (can be enhanced with NLP)
            words = re.findall(r'\b\w+\b', content.lower())
            word_freq = {}

            # Filter out common words
            stop_words = {'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'can', 'this', 'that', 'these', 'those', 'a', 'an'}

            for word in words:
                if word not in stop_words and len(word) > 3:
                    word_freq[word] = word_freq.get(word, 0) + 1

            # Return top keywords
            sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
            return [word for word, freq in sorted_words[:10]]

        except Exception as e:
            logger.error(f"Error extracting keywords: {str(e)}")
            return []

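    # Illustrative sketch (assumed input, not part of the original source): for
    # "AI agents automate research. AI agents summarize research faster." the
    # counter keeps only tokens longer than three characters that are not stop
    # words, so the result is ['agents', 'research', 'automate', 'summarize',
    # 'faster'], ordered by frequency and then first appearance.
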
    def add_citations(self, content: str, sources: List[Dict[str, Any]]) -> str:
        """
        Add inline citations to content based on grounding metadata.

        Args:
            content: The content to add citations to
            sources: List of sources from grounding metadata

        Returns:
            Content with inline citations
        """
        try:
            if not sources:
                return content

            # Create citation mapping
            citation_map = {}
            for source in sources:
                index = source.get('index', 0)
                citation_map[index] = f"[Source {index + 1}]({source.get('url', '')})"

            # Add citations at the end of sentences or paragraphs
            # This is a simplified approach - in practice, you'd use the groundingSupports data
            citation_text = "\n\n**Sources:**\n"
            for i, source in enumerate(sources):
                citation_text += f"{i+1}. **{source.get('title', f'Source {i+1}')}**\n - URL: [{source.get('url', '')}]({source.get('url', '')})\n\n"

            return content + citation_text

        except Exception as e:
            logger.error(f"Error adding citations: {str(e)}")
            return content

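    # Illustrative sketch (assumed data, not from the original source): with
    #   sources = [{'index': 0, 'title': 'Industry AI adoption report', 'url': 'https://example.com/report'}]
    # add_citations(content, sources) returns the original content followed by:
    #
    #   **Sources:**
    #   1. **Industry AI adoption report**
    #    - URL: [https://example.com/report](https://example.com/report)
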
    def extract_citations(self, content: str) -> List[Dict[str, Any]]:
        """
        Extract citations from content.

        Args:
            content: Content to extract citations from

        Returns:
            List of citation objects
        """
        try:
            citations = []
            # Look for citation patterns
            citation_patterns = [
                r'\[Source (\d+)\]',
                r'\[(\d+)\]',
                r'\(Source (\d+)\)'
            ]

            for pattern in citation_patterns:
                matches = re.finditer(pattern, content)
                for match in matches:
                    citations.append({
                        'type': 'inline',
                        'reference': match.group(0),
                        'position': match.start(),
                        'source_index': int(match.group(1)) - 1
                    })

            return citations

        except Exception as e:
            logger.error(f"Error extracting citations: {str(e)}")
            return []

    def assess_content_quality(self, content: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Assess the quality of generated content.

        Args:
            content: The generated content
            sources: List of sources used

        Returns:
            Quality metrics dictionary
        """
        try:
            # Basic quality metrics
            word_count = len(content.split())
            char_count = len(content)

            # Source coverage
            source_coverage = min(1.0, len(sources) / max(1, word_count / 100))

            # Professional tone indicators
            professional_indicators = ['research', 'analysis', 'insights', 'trends', 'industry', 'professional']
            unprofessional_indicators = ['awesome', 'amazing', 'incredible', 'mind-blowing']

            professional_score = sum(1 for indicator in professional_indicators if indicator.lower() in content.lower()) / len(professional_indicators)
            unprofessional_score = sum(1 for indicator in unprofessional_indicators if indicator.lower() in content.lower()) / len(unprofessional_indicators)

            tone_score = max(0, professional_score - unprofessional_score)

            # Overall quality score
            overall_score = (source_coverage * 0.4 + tone_score * 0.3 + min(1.0, word_count / 500) * 0.3)

            return {
                'overall_score': round(overall_score, 2),
                'source_coverage': round(source_coverage, 2),
                'tone_score': round(tone_score, 2),
                'word_count': word_count,
                'char_count': char_count,
                'sources_count': len(sources),
                'quality_level': 'high' if overall_score > 0.8 else 'medium' if overall_score > 0.6 else 'low'
            }

        except Exception as e:
            logger.error(f"Error assessing content quality: {str(e)}")
            return {
                'overall_score': 0.0,
                'error': str(e)
            }

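# Worked example for assess_content_quality (assumed numbers, for illustration only):
# a 600-word post backed by 4 sources that mentions three of the six professional
# indicators and none of the unprofessional ones scores
#   source_coverage = min(1.0, 4 / 6)                  = 0.67 (rounded)
#   tone_score      = 3/6 - 0/4                        = 0.50
#   length term     = min(1.0, 600 / 500)              = 1.00
#   overall_score   = 0.67*0.4 + 0.50*0.3 + 1.00*0.3   = 0.72 (rounded) -> 'medium'
# using the fixed weights 0.4 (sources), 0.3 (tone), 0.3 (length) defined above.
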
842
backend/services/llm_providers/gemini_provider.py
Normal file
@@ -0,0 +1,842 @@
"""
Gemini Provider Module for ALwrity

This module provides functions for interacting with Google's Gemini API, specifically designed
for structured JSON output and text generation. It follows the official Gemini API documentation
and implements best practices for reliable AI interactions.

Key Features:
- Structured JSON response generation with schema validation
- Text response generation with retry logic
- Comprehensive error handling and logging
- Automatic API key management
- Support for both gemini-2.5-flash and gemini-2.5-pro models

Best Practices:
1. Use structured output for complex, multi-field responses
2. Keep schemas simple and flat to avoid truncation
3. Set appropriate token limits (8192 for complex outputs)
4. Use low temperature (0.1-0.3) for consistent structured output
5. Implement proper error handling in calling functions
6. Avoid fallback to text parsing for structured responses

Usage Examples:
    # Structured JSON response
    schema = {
        "type": "object",
        "properties": {
            "tasks": {
                "type": "array",
                "items": {"type": "object", "properties": {...}}
            }
        }
    }
    result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)

    # Text response
    result = gemini_text_response(prompt, temperature=0.7, max_tokens=2048)

Troubleshooting:
- If response.parsed is None: Check schema complexity and token limits
- If JSON parsing fails: Verify schema matches expected output structure
- If truncation occurs: Reduce output size or increase max_tokens
- If rate limiting: Implement exponential backoff (already included)

Dependencies:
- google.genai (genai)
- tenacity (for retry logic)
- logging (for debugging)
- json (for fallback parsing)
- re (for text extraction)

Author: ALwrity Team
Version: 2.0
Last Updated: January 2025
"""

import os
import sys
from pathlib import Path

import google.genai as genai
from google.genai import types

from dotenv import load_dotenv

# Fix the environment loading path - load from backend directory
current_dir = Path(__file__).parent.parent  # services directory
backend_dir = current_dir.parent  # backend directory
env_path = backend_dir / '.env'

if env_path.exists():
    load_dotenv(env_path)
    print(f"Loaded .env from: {env_path}")
else:
    # Fallback to current directory
    load_dotenv()
    print(f"No .env found at {env_path}, using current directory")

from loguru import logger
from utils.logger_utils import get_service_logger

# Use service-specific logger to avoid conflicts
logger = get_service_logger("gemini_provider")
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

import asyncio
import json
import re

from typing import Optional, Dict, Any

# Configure standard logging
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s-%(levelname)s-%(module)s-%(lineno)d]- %(message)s')
logger = logging.getLogger(__name__)

def get_gemini_api_key() -> str:
    """Get Gemini API key with proper error handling."""
    api_key = os.getenv('GEMINI_API_KEY')
    if not api_key:
        error_msg = "GEMINI_API_KEY environment variable is not set. Please set it in your .env file."
        logger.error(error_msg)
        raise ValueError(error_msg)

    # Validate API key format (basic check)
    if not api_key.startswith('AIza'):
        error_msg = "GEMINI_API_KEY appears to be invalid. It should start with 'AIza'."
        logger.error(error_msg)
        raise ValueError(error_msg)

    return api_key

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_prompt):
|
||||
"""
|
||||
Generate text response using Google's Gemini Pro model.
|
||||
|
||||
This function provides simple text generation with retry logic and error handling.
|
||||
For structured output, use gemini_structured_json_response instead.
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
temperature (float): Controls randomness (0.0-1.0). Higher = more creative
|
||||
top_p (float): Nucleus sampling parameter (0.0-1.0)
|
||||
n (int): Number of responses to generate
|
||||
max_tokens (int): Maximum tokens in response
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
str: Generated text response
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Use temperature 0.7-0.9 for creative content
|
||||
- Use temperature 0.1-0.3 for factual/consistent content
|
||||
- Set appropriate max_tokens based on expected response length
|
||||
- Implement proper error handling in calling functions
|
||||
|
||||
Example:
|
||||
result = gemini_text_response(
|
||||
"Write a blog post about AI",
|
||||
temperature=0.8,
|
||||
max_tokens=1024
|
||||
)
|
||||
"""
|
||||
#FIXME: Include : https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/System_instructions_REST.ipynb
|
||||
try:
|
||||
api_key = get_gemini_api_key()
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("✅ Gemini client initialized successfully")
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to configure Gemini: {err}")
|
||||
raise
|
||||
logger.info(f"Temp: {temperature}, MaxTokens: {max_tokens}, TopP: {top_p}, N: {n}")
|
||||
# Set up AI model config
|
||||
generation_config = {
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": n,
|
||||
"max_output_tokens": max_tokens,
|
||||
}
|
||||
# FIXME: Expose model_name in main_config
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.0-flash-lite',
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=n,
|
||||
),
|
||||
)
|
||||
|
||||
#logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")
|
||||
return response.text
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
|
||||
raise
|
||||
|
||||
|
||||
async def test_gemini_api_key(api_key: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Test if the provided Gemini API key is valid.
|
||||
|
||||
Args:
|
||||
api_key (str): The Gemini API key to test
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing (is_valid, message)
|
||||
"""
|
||||
try:
|
||||
# Validate API key format first
|
||||
if not api_key:
|
||||
return False, "API key is empty"
|
||||
|
||||
if not api_key.startswith('AIza'):
|
||||
return False, "API key format appears invalid (should start with 'AIza')"
|
||||
|
||||
# Configure Gemini with the provided key
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Try to list models as a simple API test
|
||||
models = client.models.list()
|
||||
|
||||
# Check if Gemini Pro is available
|
||||
model_names = [model.name for model in models]
|
||||
logger.info(f"Available models: {model_names}")
|
||||
|
||||
if any("gemini" in model_name.lower() for model_name in model_names):
|
||||
return True, "Gemini API key is valid"
|
||||
else:
|
||||
return False, "No Gemini models available with this API key"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error testing Gemini API key: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048):
|
||||
"""
|
||||
Generate text using Google's Gemini Pro model.
|
||||
|
||||
Args:
|
||||
prompt (str): The input text to generate completion for
|
||||
temperature (float, optional): Controls randomness. Defaults to 0.7
|
||||
top_p (float, optional): Controls diversity. Defaults to 0.9
|
||||
top_k (int, optional): Controls vocabulary size. Defaults to 40
|
||||
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
|
||||
|
||||
Returns:
|
||||
str: The generated text completion
|
||||
"""
|
||||
try:
|
||||
# Get API key with proper error handling
|
||||
api_key = get_gemini_api_key()
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Generate content using the new client
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.5-flash',
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
),
|
||||
)
|
||||
|
||||
# Return the generated text
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gemini Pro text generation: {e}")
|
||||
return str(e)
|
||||
|
||||
def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
|
||||
"""Convert a lightweight dict schema to google.genai.types.Schema."""
|
||||
if not isinstance(schema, dict):
|
||||
raise ValueError("response_schema must be a dict compatible with types.Schema")
|
||||
|
||||
def _convert(node: Dict[str, Any]) -> types.Schema:
|
||||
node_type = (node.get("type") or "OBJECT").upper()
|
||||
if node_type == "OBJECT":
|
||||
props = node.get("properties") or {}
|
||||
props_types: Dict[str, types.Schema] = {}
|
||||
for key, prop in props.items():
|
||||
if isinstance(prop, dict):
|
||||
props_types[key] = _convert(prop)
|
||||
else:
|
||||
props_types[key] = types.Schema(type=types.Type.STRING)
|
||||
return types.Schema(type=types.Type.OBJECT, properties=props_types if props_types else None)
|
||||
elif node_type == "ARRAY":
|
||||
items_node = node.get("items")
|
||||
if isinstance(items_node, dict):
|
||||
item_schema = _convert(items_node)
|
||||
else:
|
||||
item_schema = types.Schema(type=types.Type.STRING)
|
||||
return types.Schema(type=types.Type.ARRAY, items=item_schema)
|
||||
elif node_type == "NUMBER":
|
||||
return types.Schema(type=types.Type.NUMBER)
|
||||
elif node_type == "INTEGER":
|
||||
return types.Schema(type=types.Type.NUMBER)
|
||||
elif node_type == "BOOLEAN":
|
||||
return types.Schema(type=types.Type.BOOLEAN)
|
||||
else:
|
||||
return types.Schema(type=types.Type.STRING)
|
||||
|
||||
return _convert(schema)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
|
||||
"""
|
||||
Generate structured JSON response using Google's Gemini Pro model.
|
||||
|
||||
This function follows the official Gemini API documentation for structured output:
|
||||
https://ai.google.dev/gemini-api/docs/structured-output#python
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
schema (dict): JSON schema defining the expected output structure
|
||||
temperature (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
|
||||
top_p (float): Nucleus sampling parameter (0.0-1.0)
|
||||
top_k (int): Top-k sampling parameter
|
||||
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON response matching the provided schema
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Keep schemas simple and flat to avoid truncation
|
||||
- Use low temperature (0.1-0.3) for consistent structured output
|
||||
- Set max_tokens to 8192 for complex multi-field responses
|
||||
- Avoid deeply nested schemas with many required fields
|
||||
- Test with smaller outputs first, then scale up
|
||||
|
||||
Example:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = gemini_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
"""
|
||||
try:
|
||||
# Get API key with proper error handling
|
||||
api_key = get_gemini_api_key()
|
||||
logger.info(f"🔑 Gemini API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("GEMINI_API_KEY not found in environment variables")
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("✅ Gemini client initialized for structured JSON response")
|
||||
|
||||
# Prepare schema for SDK (dict -> types.Schema). If schema is already a types.Schema or Pydantic type, use as-is
|
||||
try:
|
||||
if isinstance(schema, dict):
|
||||
types_schema = _dict_to_types_schema(schema)
|
||||
else:
|
||||
types_schema = schema
|
||||
except Exception as conv_err:
|
||||
logger.info(f"Schema conversion warning, defaulting to OBJECT: {conv_err}")
|
||||
types_schema = types.Schema(type=types.Type.OBJECT)
|
||||
|
||||
# Add debugging for API call
|
||||
logger.info(
|
||||
"Gemini structured call | prompt_len=%s | schema_kind=%s | temp=%s | top_p=%s | top_k=%s | max_tokens=%s",
|
||||
len(prompt) if isinstance(prompt, str) else '<non-str>',
|
||||
type(types_schema).__name__,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
# Use the official SDK GenerateContentConfig with response_schema
|
||||
generation_config = types.GenerateContentConfig(
|
||||
response_mime_type='application/json',
|
||||
response_schema=types_schema,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
system_instruction=system_prompt,
|
||||
)
|
||||
|
||||
logger.info("🚀 Making Gemini API call...")
|
||||
|
||||
# Use enhanced retry logic for structured JSON calls
|
||||
from services.blog_writer.retry_utils import retry_with_backoff, CONTENT_RETRY_CONFIG
|
||||
|
||||
async def make_api_call():
|
||||
return client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=prompt,
|
||||
config=generation_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert sync call to async for retry logic
|
||||
import asyncio
|
||||
|
||||
# Check if there's already an event loop running
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're already in an async context, we need to run this differently
|
||||
logger.warning("⚠️ Already in async context, using direct sync call")
|
||||
# For now, let's use a simpler approach without retry logic
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=prompt,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.info("✅ Gemini API call completed successfully (sync mode)")
|
||||
except RuntimeError:
|
||||
# No event loop running, we can create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
response = loop.run_until_complete(
|
||||
retry_with_backoff(
|
||||
make_api_call,
|
||||
config=CONTENT_RETRY_CONFIG,
|
||||
operation_name="gemini_structured_json",
|
||||
context={"schema_type": type(types_schema).__name__, "max_tokens": max_tokens}
|
||||
)
|
||||
)
|
||||
logger.info("✅ Gemini API call completed successfully")
|
||||
except Exception as api_error:
|
||||
logger.error(f"❌ Gemini API call failed: {api_error}")
|
||||
logger.error(f"❌ API Error type: {type(api_error).__name__}")
|
||||
|
||||
# Enhance error with specific exception types
|
||||
error_str = str(api_error)
|
||||
if "429" in error_str or "rate limit" in error_str.lower():
|
||||
from services.blog_writer.exceptions import APIRateLimitException
|
||||
raise APIRateLimitException(
|
||||
f"Rate limit exceeded for structured JSON generation: {error_str}",
|
||||
retry_after=60,
|
||||
context={"operation": "structured_json", "max_tokens": max_tokens}
|
||||
)
|
||||
elif "timeout" in error_str.lower():
|
||||
from services.blog_writer.exceptions import APITimeoutException
|
||||
raise APITimeoutException(
|
||||
f"Structured JSON generation timed out: {error_str}",
|
||||
timeout_seconds=60,
|
||||
context={"operation": "structured_json", "max_tokens": max_tokens}
|
||||
)
|
||||
elif "401" in error_str or "403" in error_str:
|
||||
from services.blog_writer.exceptions import ValidationException
|
||||
raise ValidationException(
|
||||
"Authentication failed for structured JSON generation. Please check your API credentials.",
|
||||
field="api_key",
|
||||
context={"error": error_str, "operation": "structured_json"}
|
||||
)
|
||||
else:
|
||||
from services.blog_writer.exceptions import ContentGenerationException
|
||||
raise ContentGenerationException(
|
||||
f"Structured JSON generation failed: {error_str}",
|
||||
context={"error": error_str, "operation": "structured_json", "max_tokens": max_tokens}
|
||||
)
|
||||
|
||||
# Check for parsed content first (primary method for structured output)
|
||||
if hasattr(response, 'parsed'):
|
||||
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
|
||||
if response.parsed is not None:
|
||||
logger.info("Using response.parsed for structured output")
|
||||
return response.parsed
|
||||
else:
|
||||
logger.warning("Response.parsed is None, falling back to text parsing")
|
||||
# Debug: Check if there's any text content
|
||||
if hasattr(response, 'text') and response.text:
|
||||
logger.info(f"Text response length: {len(response.text)}")
|
||||
logger.debug(f"Text response preview: {response.text[:200]}...")
|
||||
|
||||
# Check for text content as fallback (only if no parsed content)
|
||||
if hasattr(response, 'text') and response.text:
|
||||
logger.info("No parsed content, trying to parse text response")
|
||||
try:
|
||||
import json
|
||||
import re
|
||||
|
||||
# Clean the text response to fix common JSON issues
|
||||
cleaned_text = response.text.strip()
|
||||
|
||||
# Remove any markdown code blocks if present
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:]
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3]
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
# Try to find JSON content between curly braces
|
||||
json_match = re.search(r'\{.*\}', cleaned_text, re.DOTALL)
|
||||
if json_match:
|
||||
cleaned_text = json_match.group(0)
|
||||
|
||||
parsed_text = json.loads(cleaned_text)
|
||||
logger.info("Successfully parsed text as JSON")
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse text as JSON: {e}")
|
||||
logger.debug(f"Problematic text (first 500 chars): {response.text[:500]}")
|
||||
|
||||
# Try to extract and fix JSON manually
|
||||
try:
|
||||
import re
|
||||
# Look for the main JSON object
|
||||
json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
|
||||
matches = re.findall(json_pattern, response.text, re.DOTALL)
|
||||
if matches:
|
||||
# Try the largest match (likely the main JSON)
|
||||
largest_match = max(matches, key=len)
|
||||
# Basic cleanup of common issues
|
||||
fixed_json = largest_match.replace('\n', ' ').replace('\r', ' ')
|
||||
# Remove any trailing commas before closing braces
|
||||
fixed_json = re.sub(r',\s*}', '}', fixed_json)
|
||||
fixed_json = re.sub(r',\s*]', ']', fixed_json)
|
||||
|
||||
parsed_text = json.loads(fixed_json)
|
||||
logger.info("Successfully parsed cleaned JSON")
|
||||
return parsed_text
|
||||
except Exception as fix_error:
|
||||
logger.error(f"Failed to fix JSON manually: {fix_error}")
|
||||
|
||||
# Check candidates for content (fallback for edge cases)
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
if hasattr(candidate, 'content') and candidate.content:
|
||||
if hasattr(candidate.content, 'parts') and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, 'text') and part.text:
|
||||
try:
|
||||
import json
|
||||
parsed_text = json.loads(part.text)
|
||||
logger.info("Successfully parsed candidate text as JSON")
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse candidate text as JSON: {e}")
|
||||
|
||||
logger.error("No valid structured response content found")
|
||||
return {"error": "No valid structured response content found"}
|
||||
|
||||
except ValueError as e:
|
||||
# API key related errors should not be retried
|
||||
logger.error(f"API key error in Gemini Pro structured JSON generation: {e}")
|
||||
return {"error": str(e)}
|
||||
except Exception as e:
|
||||
# Check if this is a quota/rate limit error
|
||||
msg = str(e)
|
||||
if "RESOURCE_EXHAUSTED" in msg or "429" in msg or "quota" in msg.lower():
|
||||
logger.error(f"Rate limit/quota error in Gemini Pro structured JSON generation: {msg}")
|
||||
# Return error instead of retrying - quota exhausted means we need to wait or upgrade plan
|
||||
return {"error": msg}
|
||||
# For other errors, let tenacity handle retries
|
||||
logger.error(f"Error in Gemini Pro structured JSON generation: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Removed JSON repair functions to avoid false positives
|
||||
def _removed_repair_json_string(text: str) -> Optional[str]:
|
||||
"""
|
||||
Attempt to repair common JSON issues in AI responses.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Remove any non-JSON content before first {
|
||||
start = text.find('{')
|
||||
if start == -1:
|
||||
return None
|
||||
text = text[start:]
|
||||
|
||||
# Remove any content after last }
|
||||
end = text.rfind('}')
|
||||
if end == -1:
|
||||
return None
|
||||
text = text[:end+1]
|
||||
|
||||
# Fix common issues
|
||||
repaired = text
|
||||
|
||||
# 1. Fix unterminated arrays (add missing closing brackets)
|
||||
# Count opening and closing brackets
|
||||
open_brackets = repaired.count('[')
|
||||
close_brackets = repaired.count(']')
|
||||
if open_brackets > close_brackets:
|
||||
# Add missing closing brackets
|
||||
missing_brackets = open_brackets - close_brackets
|
||||
repaired = repaired + ']' * missing_brackets
|
||||
|
||||
# 2. Fix unterminated strings in arrays
|
||||
# Look for patterns like ["item1", "item2" and add missing quote and bracket
|
||||
lines = repaired.split('\n')
|
||||
fixed_lines = []
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
# Check if line ends with an unquoted string in an array
|
||||
if stripped.endswith('"') and i < len(lines) - 1:
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith(']'):
|
||||
# This is fine
|
||||
pass
|
||||
elif not next_line.startswith('"') and not next_line.startswith(']'):
|
||||
# Add missing quote and comma
|
||||
line = line + '",'
|
||||
fixed_lines.append(line)
|
||||
repaired = '\n'.join(fixed_lines)
|
||||
|
||||
# 3. Fix unterminated strings (common issue with AI responses)
|
||||
try:
|
||||
# Handle unterminated strings by finding the last incomplete string and closing it
|
||||
lines = repaired.split('\n')
|
||||
fixed_lines = []
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
# Check for unterminated strings (line ends with quote but no closing quote)
|
||||
if stripped.endswith('"') and i < len(lines) - 1:
|
||||
next_line = lines[i + 1].strip()
|
||||
# If next line doesn't start with quote or closing bracket, we might have an unterminated string
|
||||
if not next_line.startswith('"') and not next_line.startswith(']') and not next_line.startswith('}'):
|
||||
# Check if this looks like an unterminated string value
|
||||
if ':' in line and not line.strip().endswith('",'):
|
||||
line = line + '",'
|
||||
# Count quotes in the line
|
||||
quote_count = line.count('"')
|
||||
if quote_count % 2 == 1: # Odd number of quotes
|
||||
# Add a quote at the end if it looks like an incomplete string
|
||||
if ':' in line and line.strip().endswith('"'):
|
||||
line = line + '"'
|
||||
elif ':' in line and not line.strip().endswith('"') and not line.strip().endswith(','):
|
||||
line = line + '",'
|
||||
fixed_lines.append(line)
|
||||
repaired = '\n'.join(fixed_lines)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 4. Remove trailing commas before closing braces/brackets
|
||||
repaired = re.sub(r',(\s*[}\]])', r'\1', repaired)
|
||||
|
||||
# 5. Fix missing commas between object properties
|
||||
repaired = re.sub(r'"(\s*)"', r'",\1"', repaired)
|
||||
|
||||
return repaired
|
||||
|
||||
|
||||
# Removed partial JSON extraction to avoid false positives
|
||||
def _removed_extract_partial_json(text: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract partial JSON from truncated responses.
|
||||
Attempts to salvage as much data as possible from incomplete JSON.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the start of JSON
|
||||
start = text.find('{')
|
||||
if start == -1:
|
||||
return None
|
||||
|
||||
# Extract from start to end, handling common truncation patterns
|
||||
json_text = text[start:]
|
||||
|
||||
# Common truncation patterns and their fixes
|
||||
truncation_patterns = [
|
||||
(r'(["\w\s,{}\[\]\-\.:]+)\.\.\.$', r'\1'), # Remove trailing ...
|
||||
(r'(["\w\s,{}\[\]\-\.:]+)"$', r'\1"'), # Add missing closing quote
|
||||
(r'(["\w\s,{}\[\]\-\.:]+),$', r'\1'), # Remove trailing comma
|
||||
(r'(["\w\s,{}\[\]\-\.:]+)\[(["\w\s,{}\[\]\-\.:]*)$', r'\1\2]'), # Close unclosed arrays
|
||||
(r'(["\w\s,{}\[\]\-\.:]+)\{(["\w\s,{}\[\]\-\.:]*)$', r'\1\2}'), # Close unclosed objects
|
||||
]
|
||||
|
||||
# Apply truncation fixes
|
||||
import re
|
||||
for pattern, replacement in truncation_patterns:
|
||||
json_text = re.sub(pattern, replacement, json_text)
|
||||
|
||||
# Try to balance brackets and braces
|
||||
open_braces = json_text.count('{')
|
||||
close_braces = json_text.count('}')
|
||||
open_brackets = json_text.count('[')
|
||||
close_brackets = json_text.count(']')
|
||||
|
||||
# Add missing closing braces/brackets
|
||||
if open_braces > close_braces:
|
||||
json_text += '}' * (open_braces - close_braces)
|
||||
if open_brackets > close_brackets:
|
||||
json_text += ']' * (open_brackets - close_brackets)
|
||||
|
||||
# Try to parse the repaired JSON
|
||||
try:
|
||||
result = json.loads(json_text)
|
||||
logger.info(f"Successfully extracted partial JSON with {len(str(result))} characters")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(f"Partial JSON parsing failed: {e}")
|
||||
|
||||
# Try to extract individual fields as a last resort
|
||||
fields = {}
|
||||
|
||||
# Extract key-value pairs using regex (more comprehensive patterns)
|
||||
kv_patterns = [
|
||||
r'"([^"]+)"\s*:\s*"([^"]*)"', # "key": "value"
|
||||
r'"([^"]+)"\s*:\s*(\d+)', # "key": 123
|
||||
r'"([^"]+)"\s*:\s*(true|false)', # "key": true/false
|
||||
r'"([^"]+)"\s*:\s*null', # "key": null
|
||||
]
|
||||
|
||||
for pattern in kv_patterns:
|
||||
matches = re.findall(pattern, json_text)
|
||||
for key, value in matches:
|
||||
if value == 'true':
|
||||
fields[key] = True
|
||||
elif value == 'false':
|
||||
fields[key] = False
|
||||
elif value == 'null':
|
||||
fields[key] = None
|
||||
elif value.isdigit():
|
||||
fields[key] = int(value)
|
||||
else:
|
||||
fields[key] = value
|
||||
|
||||
# Extract array fields (more robust)
|
||||
array_pattern = r'"([^"]+)"\s*:\s*\[([^\]]*)\]'
|
||||
array_matches = re.findall(array_pattern, json_text)
|
||||
for key, array_content in array_matches:
|
||||
# Parse array items more comprehensively
|
||||
items = []
|
||||
# Look for quoted strings, numbers, booleans, null
|
||||
item_patterns = [
|
||||
r'"([^"]*)"', # quoted strings
|
||||
r'(\d+)', # numbers
|
||||
r'(true|false)', # booleans
|
||||
r'(null)', # null
|
||||
]
|
||||
for pattern in item_patterns:
|
||||
item_matches = re.findall(pattern, array_content)
|
||||
for match in item_matches:
|
||||
if match == 'true':
|
||||
items.append(True)
|
||||
elif match == 'false':
|
||||
items.append(False)
|
||||
elif match == 'null':
|
||||
items.append(None)
|
||||
elif match.isdigit():
|
||||
items.append(int(match))
|
||||
else:
|
||||
items.append(match)
|
||||
if items:
|
||||
fields[key] = items
|
||||
|
||||
# Extract nested object fields (basic)
|
||||
object_pattern = r'"([^"]+)"\s*:\s*\{([^}]*)\}'
|
||||
object_matches = re.findall(object_pattern, json_text)
|
||||
for key, object_content in object_matches:
|
||||
# Simple nested object extraction
|
||||
nested_fields = {}
|
||||
nested_kv_matches = re.findall(r'"([^"]+)"\s*:\s*"([^"]*)"', object_content)
|
||||
for nested_key, nested_value in nested_kv_matches:
|
||||
nested_fields[nested_key] = nested_value
|
||||
if nested_fields:
|
||||
fields[key] = nested_fields
|
||||
|
||||
if fields:
|
||||
logger.info(f"Extracted {len(fields)} fields from truncated JSON: {list(fields.keys())}")
|
||||
# Only return if we have a valid outline structure
|
||||
if 'outline' in fields and isinstance(fields['outline'], list):
|
||||
return {'outline': fields['outline']}
|
||||
else:
|
||||
logger.error("No valid 'outline' field found in partial JSON")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in partial JSON extraction: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Removed key-value extraction to avoid false positives
|
||||
def _removed_extract_key_value_pairs(text: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract key-value pairs from malformed JSON text as a last resort.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
|
||||
# Look for patterns like "key": "value" or "key": value
|
||||
# This regex looks for quoted keys followed by colons and values
|
||||
pattern = r'"([^"]+)"\s*:\s*(?:"([^"]*)"|([^,}\]]+))'
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
for key, quoted_value, unquoted_value in matches:
|
||||
value = quoted_value if quoted_value else unquoted_value.strip()
|
||||
|
||||
# Clean up the value - remove any trailing content that looks like the next key
|
||||
# This handles cases where the regex captured too much
|
||||
if value and '"' in value:
|
||||
# Split at the first quote that might be the start of the next key
|
||||
parts = value.split('"')
|
||||
if len(parts) > 1:
|
||||
value = parts[0].strip()
|
||||
|
||||
# Try to parse the value appropriately
|
||||
if value.lower() in ['true', 'false']:
|
||||
result[key] = value.lower() == 'true'
|
||||
elif value.lower() == 'null':
|
||||
result[key] = None
|
||||
elif value.isdigit():
|
||||
result[key] = int(value)
|
||||
elif value.replace('.', '').replace('-', '').isdigit():
|
||||
try:
|
||||
result[key] = float(value)
|
||||
except ValueError:
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
# Also try to extract array values
|
||||
array_pattern = r'"([^"]+)"\s*:\s*\[([^\]]*)\]'
|
||||
array_matches = re.findall(array_pattern, text)
|
||||
|
||||
for key, array_content in array_matches:
|
||||
# Extract individual array items
|
||||
items = []
|
||||
# Look for quoted strings in the array
|
||||
item_pattern = r'"([^"]*)"'
|
||||
item_matches = re.findall(item_pattern, array_content)
|
||||
for item in item_matches:
|
||||
if item.strip():
|
||||
items.append(item.strip())
|
||||
|
||||
if items:
|
||||
result[key] = items
|
||||
|
||||
return result if result else None
|
||||
441
backend/services/llm_providers/huggingface_provider.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Hugging Face Provider Module for ALwrity
|
||||
|
||||
This module provides functions for interacting with Hugging Face's Inference Providers API
|
||||
using the Responses API (beta) which provides a unified interface for model interactions.
|
||||
|
||||
Key Features:
|
||||
- Text response generation with retry logic
|
||||
- Structured JSON response generation with schema validation
|
||||
- Comprehensive error handling and logging
|
||||
- Automatic API key management
|
||||
- Support for various Hugging Face models via Inference Providers
|
||||
|
||||
Best Practices:
|
||||
1. Use structured output for complex, multi-field responses
|
||||
2. Keep schemas simple and flat to avoid truncation
|
||||
3. Set appropriate token limits (8192 for complex outputs)
|
||||
4. Use low temperature (0.1-0.3) for consistent structured output
|
||||
5. Implement proper error handling in calling functions
|
||||
6. Use the Responses API for better compatibility
|
||||
|
||||
Usage Examples:
|
||||
# Text response
|
||||
result = huggingface_text_response(prompt, temperature=0.7, max_tokens=2048)
|
||||
|
||||
# Structured JSON response
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "object", "properties": {...}}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = huggingface_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
|
||||
Dependencies:
|
||||
- openai (for Hugging Face Responses API)
|
||||
- tenacity (for retry logic)
|
||||
- logging (for debugging)
|
||||
- json (for fallback parsing)
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
Last Updated: January 2025
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Fix the environment loading path - load from backend directory
|
||||
current_dir = Path(__file__).parent.parent # services directory
|
||||
backend_dir = current_dir.parent # backend directory
|
||||
env_path = backend_dir / '.env'
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
print(f"Loaded .env from: {env_path}")
|
||||
else:
|
||||
# Fallback to current directory
|
||||
load_dotenv()
|
||||
print(f"No .env found at {env_path}, using current directory")
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("huggingface_provider")
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
logger.warning("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
def get_huggingface_api_key() -> str:
|
||||
"""Get Hugging Face API key with proper error handling."""
|
||||
api_key = os.getenv('HF_TOKEN')
|
||||
if not api_key:
|
||||
error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Validate API key format (basic check)
|
||||
if not api_key.startswith('hf_'):
|
||||
error_msg = "HF_TOKEN appears to be invalid. It should start with 'hf_'."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def huggingface_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate text response using Hugging Face Inference Providers API.
|
||||
|
||||
This function uses the Hugging Face Responses API which provides a unified interface
|
||||
for model interactions with built-in retry logic and error handling.
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
|
||||
temperature (float): Controls randomness (0.0-1.0)
|
||||
max_tokens (int): Maximum tokens in response
|
||||
top_p (float): Nucleus sampling parameter (0.0-1.0)
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
str: Generated text response
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
|
||||
- Set max_tokens based on expected response length
|
||||
- Use system_prompt to guide model behavior
|
||||
- Handle errors gracefully in calling functions
|
||||
|
||||
Example:
|
||||
result = huggingface_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
)
|
||||
"""
|
||||
try:
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
# Get API key with proper error handling
|
||||
api_key = get_huggingface_api_key()
|
||||
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("HF_TOKEN not found in environment variables")
|
||||
|
||||
# Initialize Hugging Face client using Responses API
|
||||
client = OpenAI(
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for text response")
|
||||
|
||||
# Prepare input for the API
|
||||
input_content = []
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
input_content.append({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
})
|
||||
|
||||
# Add user prompt
|
||||
input_content.append({
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
})
|
||||
|
||||
# Add debugging for API call
|
||||
logger.info(
|
||||
"Hugging Face text call | model=%s | prompt_len=%s | temp=%s | top_p=%s | max_tokens=%s",
|
||||
model,
|
||||
len(prompt) if isinstance(prompt, str) else '<non-str>',
|
||||
temperature,
|
||||
top_p,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
logger.info("🚀 Making Hugging Face API call...")
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
# Make the API call using the Responses API.
# Note: forwarding max_tokens as max_output_tokens is an assumption based on the
# Responses API parameter name; previously max_tokens was accepted but never sent.
response = client.responses.parse(
model=model,
input=input_content,
temperature=temperature,
top_p=top_p,
max_output_tokens=max_tokens,
)
|
||||
|
||||
# Extract text from response
|
||||
if hasattr(response, 'output_text') and response.output_text:
|
||||
generated_text = response.output_text
|
||||
elif hasattr(response, 'output') and response.output:
# Handle the case where output is a list of items (dicts from some providers, objects from others)
if isinstance(response.output, list) and len(response.output) > 0:
first_item = response.output[0]
if isinstance(first_item, dict):
generated_text = first_item.get('content', '')
else:
generated_text = getattr(first_item, 'content', None) or str(first_item)
else:
generated_text = str(response.output)
|
||||
else:
|
||||
generated_text = str(response)
|
||||
|
||||
# Clean up the response
|
||||
if generated_text:
|
||||
# Remove any markdown formatting if present
|
||||
generated_text = re.sub(r'```[a-zA-Z]*\n?', '', generated_text)
|
||||
generated_text = re.sub(r'```\n?', '', generated_text)
|
||||
generated_text = generated_text.strip()
|
||||
|
||||
logger.info(f"✅ Hugging Face text response generated successfully (length: {len(generated_text)})")
|
||||
return generated_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Hugging Face text generation failed: {str(e)}")
|
||||
raise Exception(f"Hugging Face text generation failed: {str(e)}")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def huggingface_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 8192,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate structured JSON response using Hugging Face Inference Providers API.
|
||||
|
||||
This function uses the Hugging Face Responses API with structured output support
|
||||
to generate JSON responses that match a provided schema.
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
schema (dict): JSON schema defining the expected output structure
|
||||
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
|
||||
temperature (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
|
||||
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON response matching the provided schema
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Keep schemas simple and flat to avoid truncation
|
||||
- Use low temperature (0.1-0.3) for consistent structured output
|
||||
- Set max_tokens to 8192 for complex multi-field responses
|
||||
- Avoid deeply nested schemas with many required fields
|
||||
- Test with smaller outputs first, then scale up
|
||||
|
||||
Example:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = huggingface_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
"""
|
||||
try:
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
# Get API key with proper error handling
|
||||
api_key = get_huggingface_api_key()
|
||||
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("HF_TOKEN not found in environment variables")
|
||||
|
||||
# Initialize Hugging Face client using Responses API
|
||||
client = OpenAI(
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for structured JSON response")
|
||||
|
||||
# Prepare input for the API
|
||||
input_content = []
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
input_content.append({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
})
|
||||
|
||||
# Add user prompt with JSON instruction
|
||||
json_instruction = "Please respond with valid JSON that matches the provided schema."
|
||||
input_content.append({
|
||||
"role": "user",
|
||||
"content": f"{prompt}\n\n{json_instruction}"
|
||||
})
|
||||
|
||||
# Add debugging for API call
|
||||
logger.info(
|
||||
"Hugging Face structured call | model=%s | prompt_len=%s | schema_kind=%s | temp=%s | max_tokens=%s",
|
||||
model,
|
||||
len(prompt) if isinstance(prompt, str) else '<non-str>',
|
||||
type(schema).__name__,
|
||||
temperature,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
logger.info("🚀 Making Hugging Face structured API call...")
|
||||
|
||||
# Make the API call using Responses API with structured output
|
||||
# Use simple text generation and parse JSON manually to avoid API format issues
|
||||
logger.info("🚀 Making Hugging Face API call (text mode with JSON parsing)...")
|
||||
|
||||
# Add JSON instruction to the prompt
|
||||
json_instruction = "\n\nPlease respond with valid JSON that matches this exact structure:\n" + json.dumps(schema, indent=2)
|
||||
input_content[-1]["content"] = input_content[-1]["content"] + json_instruction
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
# max_output_tokens is assumed to be the Responses API name for the token limit;
# previously max_tokens was documented but never forwarded to the API.
response = client.responses.parse(
model=model,
input=input_content,
temperature=temperature,
max_output_tokens=max_tokens,
)
|
||||
|
||||
# Extract structured data from response
|
||||
if hasattr(response, 'output_parsed') and response.output_parsed:
|
||||
# The new API returns parsed data directly (Pydantic model case)
|
||||
logger.info("✅ Hugging Face structured JSON response parsed successfully")
|
||||
# Convert Pydantic model to dict if needed
|
||||
if hasattr(response.output_parsed, 'model_dump'):
|
||||
return response.output_parsed.model_dump()
|
||||
elif hasattr(response.output_parsed, 'dict'):
|
||||
return response.output_parsed.dict()
|
||||
else:
|
||||
return response.output_parsed
|
||||
elif hasattr(response, 'output_text') and response.output_text:
|
||||
# Fallback to text parsing if output_parsed is not available
|
||||
response_text = response.output_text
|
||||
# Clean up the response text
|
||||
response_text = re.sub(r'```json\n?', '', response_text)
|
||||
response_text = re.sub(r'```\n?', '', response_text)
|
||||
response_text = response_text.strip()
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ Hugging Face structured JSON response parsed from text")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"❌ JSON parsing failed: {json_err}")
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
extracted_json = json.loads(json_match.group())
|
||||
logger.info("✅ JSON extracted using regex fallback")
|
||||
return extracted_json
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If all else fails, return a structured error response
|
||||
logger.error("❌ All JSON parsing attempts failed")
|
||||
return {
|
||||
"error": "Failed to parse JSON response",
|
||||
"raw_response": response_text,
|
||||
"schema_expected": schema
|
||||
}
|
||||
else:
|
||||
logger.error("❌ No valid response data found")
|
||||
return {
|
||||
"error": "No valid response data found",
|
||||
"raw_response": str(response),
|
||||
"schema_expected": schema
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) if str(e) else repr(e)
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"❌ Hugging Face structured JSON generation failed: {error_type}: {error_msg}")
|
||||
logger.error(f"❌ Full exception details: {repr(e)}")
|
||||
import traceback
|
||||
logger.error(f"❌ Traceback: {traceback.format_exc()}")
|
||||
raise Exception(f"Hugging Face structured JSON generation failed: {error_type}: {error_msg}")
|
||||
|
||||
def get_available_models() -> list:
|
||||
"""
|
||||
Get list of available Hugging Face models for text generation.
|
||||
|
||||
Returns:
|
||||
list: List of available model identifiers
|
||||
"""
|
||||
return [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"microsoft/Phi-3-medium-4k-instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
]
|
||||
|
||||
def validate_model(model: str) -> bool:
|
||||
"""
|
||||
Validate if a model identifier is supported.
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to validate
|
||||
|
||||
Returns:
|
||||
bool: True if model is supported, False otherwise
|
||||
"""
|
||||
available_models = get_available_models()
|
||||
return model in available_models
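# Usage sketch (assumes HF_TOKEN is configured; the model id is one of the entries
# returned by get_available_models()):
#
#   if validate_model("openai/gpt-oss-120b:groq"):
#       summary = huggingface_text_response(
#           prompt="Summarize the benefits of structured JSON output.",
#           model="openai/gpt-oss-120b:groq",
#           temperature=0.3,
#           max_tokens=512,
#           system_prompt="You are a concise technical writer.",
#       )
#       print(summary)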
|
||||
17
backend/services/llm_providers/image_generation/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
from .wavespeed_provider import WaveSpeedImageProvider
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
"WaveSpeedImageProvider",
|
||||
]
|
||||
|
||||
|
||||
37
backend/services/llm_providers/image_generation/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationOptions:
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationResult:
|
||||
image_bytes: bytes
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Protocol for image generation providers."""
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
...
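# Minimal sketch of how callers use a provider (any concrete provider in this
# package satisfies the protocol; the variable name some_provider is illustrative):
#
#   options = ImageGenerationOptions(prompt="A lighthouse at dawn", width=768, height=768)
#   result: ImageGenerationResult = some_provider.generate(options)
#   with open("out.png", "wb") as f:
#       f.write(result.image_bytes)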
|
||||
|
||||
|
||||
47
backend/services/llm_providers/image_generation/gemini_provider.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.gemini")
|
||||
|
||||
|
||||
class GeminiImageProvider(ImageGenerationProvider):
|
||||
"""Google Gemini/Imagen backed image generation.
|
||||
|
||||
NOTE: Implementation should call the actual Gemini Images API used in the codebase.
|
||||
Here we keep a minimal interface and expect the underlying client to be wired
|
||||
similarly to other providers and return a PIL image or raw bytes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("GOOGLE_API_KEY not set. Gemini image generation may fail at runtime.")
|
||||
logger.info("GeminiImageProvider initialized")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
# Placeholder implementation to be replaced by real Gemini/Imagen call.
|
||||
# For now, return a fully transparent PNG at the requested size to keep the interface consistent
|
||||
img = Image.new("RGBA", (max(1, options.width), max(1, options.height)), (0, 0, 0, 0))
|
||||
with io.BytesIO() as buf:
|
||||
img.save(buf, format="PNG")
|
||||
png = buf.getvalue()
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=png,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="gemini",
|
||||
model=os.getenv("GEMINI_IMAGE_MODEL"),
|
||||
seed=options.seed,
|
||||
)
|
||||
|
||||
|
||||
73
backend/services/llm_providers/image_generation/hf_provider.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from PIL import Image
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.huggingface")
|
||||
|
||||
|
||||
DEFAULT_HF_MODEL = os.getenv(
|
||||
"HF_IMAGE_MODEL",
|
||||
"black-forest-labs/FLUX.1-Krea-dev",
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceImageProvider(ImageGenerationProvider):
|
||||
"""Hugging Face Inference Providers (fal-ai) backed image generation.
|
||||
|
||||
API doc: https://huggingface.co/docs/inference-providers/en/tasks/text-to-image
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, provider: str = "fal-ai") -> None:
|
||||
self.api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not self.api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image generation")
|
||||
self.provider = provider
|
||||
self.client = InferenceClient(provider=self.provider, api_key=self.api_key)
|
||||
logger.info("HuggingFaceImageProvider initialized (provider=%s)", self.provider)
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
model = options.model or DEFAULT_HF_MODEL
|
||||
params: Dict[str, Any] = {}
|
||||
if options.guidance_scale is not None:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
if options.steps is not None:
|
||||
params["num_inference_steps"] = options.steps
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
if options.seed is not None:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# The HF InferenceClient returns a PIL Image
|
||||
logger.debug("HF generate: model=%s width=%s height=%s params=%s", model, options.width, options.height, params)
|
||||
img: Image.Image = self.client.text_to_image(
|
||||
options.prompt,
|
||||
model=model,
|
||||
width=options.width,
|
||||
height=options.height,
|
||||
**params,
|
||||
)
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
img.save(buf, format="PNG")
|
||||
image_bytes = buf.getvalue()
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={"provider": self.provider},
|
||||
)
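# Usage sketch (assumes HF_TOKEN is set and the fal-ai provider is enabled for that token;
# prompt and step count are illustrative):
#
#   provider = HuggingFaceImageProvider()
#   result = provider.generate(ImageGenerationOptions(prompt="Isometric city at night", steps=30))
#   print(result.model, result.width, result.height, len(result.image_bytes))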
|
||||
|
||||
|
||||
79
backend/services/llm_providers/image_generation/stability_provider.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.stability")
|
||||
|
||||
|
||||
DEFAULT_STABILITY_MODEL = os.getenv("STABILITY_MODEL", "stable-diffusion-xl-1024-v1-0")
|
||||
|
||||
|
||||
class StabilityImageProvider(ImageGenerationProvider):
|
||||
"""Stability AI Images API provider (simple text-to-image).
|
||||
|
||||
This uses the v1 text-to-image endpoint format. Adjust to match your existing
|
||||
Stability integration if different.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None) -> None:
|
||||
self.api_key = api_key or os.getenv("STABILITY_API_KEY")
|
||||
if not self.api_key:
|
||||
logger.warning("STABILITY_API_KEY not set. Stability generation may fail at runtime.")
|
||||
logger.info("StabilityImageProvider initialized")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"text_prompts": [
|
||||
{"text": options.prompt, "weight": 1.0},
|
||||
],
|
||||
"cfg_scale": options.guidance_scale or 7.0,
|
||||
"steps": options.steps or 30,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"seed": options.seed,
|
||||
}
|
||||
if options.negative_prompt:
|
||||
payload["text_prompts"].append({"text": options.negative_prompt, "weight": -1.0})
|
||||
|
||||
model = options.model or DEFAULT_STABILITY_MODEL
|
||||
url = f"https://api.stability.ai/v1/generation/{model}/text-to-image"
|
||||
|
||||
logger.debug("Stability generate: model=%s payload_keys=%s", model, list(payload.keys()))
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Expecting data["artifacts"][0]["base64"]
|
||||
import base64
|
||||
|
||||
artifact = (data.get("artifacts") or [{}])[0]
|
||||
b64 = artifact.get("base64", "")
|
||||
image_bytes = base64.b64decode(b64)
|
||||
|
||||
# Confirm dimensions by loading once (optional)
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="stability",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
)
|
||||
|
||||
|
||||
243
backend/services/llm_providers/image_generation/wavespeed_provider.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("wavespeed.image_provider")
|
||||
|
||||
|
||||
class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"ideogram-v3-turbo": {
|
||||
"name": "Ideogram V3 Turbo",
|
||||
"description": "Photorealistic generation with superior text rendering",
|
||||
"cost_per_image": 0.10, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 20,
|
||||
},
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"description": "Fast, high-quality text-to-image generation",
|
||||
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 15,
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize WaveSpeed image provider.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key (falls back to env var if not provided)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
||||
|
||||
self.client = WaveSpeedClient(api_key=self.api_key)
|
||||
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
|
||||
list(self.SUPPORTED_MODELS.keys()))
|
||||
|
||||
def _validate_options(self, options: ImageGenerationOptions) -> None:
|
||||
"""Validate generation options.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
model = options.model or "ideogram-v3-turbo"
|
||||
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
max_width, max_height = model_info["max_resolution"]
|
||||
|
||||
if options.width > max_width or options.height > max_height:
|
||||
raise ValueError(
|
||||
f"Resolution {options.width}x{options.height} exceeds maximum "
|
||||
f"{max_width}x{max_height} for model {model}"
|
||||
)
|
||||
|
||||
if not options.prompt or len(options.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using Ideogram V3 Turbo.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[Ideogram V3] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed Ideogram V3 API
|
||||
# Note: Adjust these based on actual WaveSpeed API documentation
|
||||
params = {
|
||||
"model": "ideogram-v3-turbo",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["ideogram-v3-turbo"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API (using generic image generation method)
|
||||
# This will need to be adjusted based on actual WaveSpeed client implementation
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
# Adjust based on actual WaveSpeed API response format
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[Ideogram V3] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Ideogram V3] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Ideogram V3 generation failed: {str(e)}")
|
||||
|
||||
def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using Qwen Image.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[Qwen Image] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed Qwen Image API
|
||||
params = {
|
||||
"model": "qwen-image",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["qwen-image"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[Qwen Image] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
"""Generate image using WaveSpeed AI models.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with generated image
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
# Validate options
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model = options.model or "ideogram-v3-turbo"
|
||||
|
||||
# Generate based on model
|
||||
if model == "ideogram-v3-turbo":
|
||||
image_bytes = self._generate_ideogram_v3(options)
|
||||
elif model == "qwen-image":
|
||||
image_bytes = self._generate_qwen_image(options)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
# Load image to get dimensions
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
# Calculate estimated cost
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
estimated_cost = model_info["cost_per_image"]
|
||||
|
||||
# Return result
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"model_name": model_info["name"],
|
||||
"prompt": options.prompt,
|
||||
"negative_prompt": options.negative_prompt,
|
||||
"steps": options.steps or model_info["default_steps"],
|
||||
"guidance_scale": options.guidance_scale,
|
||||
"estimated_cost": estimated_cost,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available models and their information.
|
||||
|
||||
Returns:
|
||||
Dictionary of available models
|
||||
"""
|
||||
return cls.SUPPORTED_MODELS
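# Usage sketch (assumes WAVESPEED_API_KEY is set; the cost figures above are estimates):
#
#   provider = WaveSpeedImageProvider()
#   opts = ImageGenerationOptions(prompt="Poster with the text 'ALwrity'", model="ideogram-v3-turbo")
#   result = provider.generate(opts)
#   print(result.metadata["estimated_cost"], result.width, result.height)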
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Gemini Image Description Module
|
||||
|
||||
This module provides functionality to generate text descriptions of images using Google's Gemini API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from dotenv import load_dotenv
|
||||
# Import the google-genai SDK lazily so this module still loads when the SDK is missing.
try:
import google.genai as genai
from google.genai import types
except ImportError:
genai = None
types = None
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("gemini_image_describe")
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
if genai is None:
logger.warning("google-genai library not available. Install with: pip install google-genai")
|
||||
|
||||
|
||||
def describe_image(image_path: str, prompt: str = "Describe this image in detail:") -> Optional[str]:
|
||||
"""
|
||||
Describe an image using Google's Gemini API.
|
||||
|
||||
Parameters:
|
||||
image_path (str): Path to the image file.
|
||||
prompt (str): Prompt for describing the image.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The generated description of the image, or None if an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not genai:
|
||||
logger.error("Google genai library not available")
|
||||
return None
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_message = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Check if image file exists
|
||||
if not os.path.exists(image_path):
|
||||
error_message = f"Image file not found: {image_path}"
|
||||
logger.error(error_message)
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
# Initialize the Gemini client
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Open and process the image
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
logger.info(f"Successfully opened image: {image_path}")
|
||||
except Exception as e:
|
||||
error_message = f"Failed to open image: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
# Generate content description
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.0-flash',
|
||||
contents=[
|
||||
prompt,
|
||||
image
|
||||
]
|
||||
)
|
||||
|
||||
# Extract and return the text
|
||||
description = response.text
|
||||
logger.info(f"Successfully generated description for image: {image_path}")
|
||||
return description
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Failed to generate content: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"An unexpected error occurred: {e}"
|
||||
logger.error(error_message)
|
||||
return None
|
||||
|
||||
|
||||
def analyze_image_with_prompt(image_path: str, prompt: str) -> Optional[str]:
|
||||
"""
|
||||
Analyze an image with a custom prompt using Google's Gemini API.
|
||||
|
||||
Parameters:
|
||||
image_path (str): Path to the image file.
|
||||
prompt (str): Custom prompt for analyzing the image.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The generated analysis of the image, or None if an error occurs.
|
||||
"""
|
||||
return describe_image(image_path, prompt)
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example usage of the function
|
||||
image_path = "path/to/your/image.jpg"
|
||||
description = describe_image(image_path)
|
||||
if description:
|
||||
print(f"Image description: {description}")
|
||||
else:
|
||||
print("Failed to generate image description")
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module provides functionality to analyze images using OpenAI's Vision API.
|
||||
It encodes an image to a base64 string and sends a request to the OpenAI API
|
||||
to interpret the contents of the image, returning a textual description.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import sys
|
||||
import re
|
||||
import base64
|
||||
|
||||
def analyze_and_extract_details_from_image(image_path, api_key):
|
||||
"""
|
||||
Analyzes an image using OpenAI's Vision API and extracts Alt Text, Description, Title, and Caption.
|
||||
|
||||
Args:
|
||||
image_path (str): Path to the image file.
|
||||
api_key (str): Your OpenAI API key.
|
||||
|
||||
Returns:
|
||||
dict: Extracted details including Alt Text, Description, Title, and Caption.
|
||||
"""
|
||||
def encode_image(path):
|
||||
""" Encodes an image to a base64 string. """
|
||||
with open(path, "rb", encoding="utf-8") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
base64_image = encode_image(image_path)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4-vision-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The given image is used in blog content. Analyze the given image and suggest alternative(alt) test, description, title, caption."
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
assistant_message = response.json()['choices'][0]['message']['content']
|
||||
|
||||
# Extracting details using regular expressions
|
||||
alt_text_match = re.search(r'Alt Text: "(.*?)"', assistant_message)
|
||||
description_match = re.search(r'Description: (.*?)\n\n', assistant_message)
|
||||
title_match = re.search(r'Title: "(.*?)"', assistant_message)
|
||||
caption_match = re.search(r'Caption: "(.*?)"', assistant_message)
|
||||
|
||||
return {
|
||||
'alt_text': alt_text_match.group(1) if alt_text_match else None,
|
||||
'description': description_match.group(1) if description_match else None,
|
||||
'title': title_match.group(1) if title_match else None,
|
||||
'caption': caption_match.group(1) if caption_match else None
|
||||
}
|
||||
|
||||
except requests.RequestException as e:
|
||||
sys.exit(f"Error: Failed to communicate with OpenAI API. Error: {e}")
|
||||
except Exception as e:
|
||||
sys.exit(f"Error occurred: {e}")
|
||||
319
backend/services/llm_providers/main_audio_generation.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Main Audio Generation Service for ALwrity Backend.
|
||||
|
||||
This service provides AI-powered text-to-speech functionality using WaveSpeed Minimax Speech 02 HD.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("audio_generation")
|
||||
|
||||
|
||||
class AudioGenerationResult:
|
||||
"""Result of audio generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
provider: str,
|
||||
model: str,
|
||||
voice_id: str,
|
||||
text_length: int,
|
||||
file_size: int,
|
||||
):
|
||||
self.audio_bytes = audio_bytes
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.voice_id = voice_id
|
||||
self.text_length = text_length
|
||||
self.file_size = file_size
|
||||
|
||||
|
||||
def generate_audio(
|
||||
text: str,
|
||||
voice_id: str = "Wise_Woman",
|
||||
speed: float = 1.0,
|
||||
volume: float = 1.0,
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
user_id: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> AudioGenerationResult:
|
||||
"""
|
||||
Generate audio using AI text-to-speech with subscription tracking.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech (max 10000 characters)
|
||||
voice_id: Voice ID (default: "Wise_Woman")
|
||||
speed: Speech speed (0.5-2.0, default: 1.0)
|
||||
volume: Speech volume (0.1-10.0, default: 1.0)
|
||||
pitch: Speech pitch (-12 to 12, default: 0.0)
|
||||
emotion: Emotion (default: "happy")
|
||||
user_id: User ID for subscription checking (required)
|
||||
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
|
||||
|
||||
Returns:
|
||||
AudioGenerationResult: Generated audio result
|
||||
|
||||
Raises:
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
# VALIDATION: Check inputs before any processing or API calls
|
||||
if not text or not isinstance(text, str) or len(text.strip()) == 0:
|
||||
raise ValueError("Text input is required and cannot be empty")
|
||||
|
||||
text = text.strip() # Normalize whitespace
|
||||
|
||||
if len(text) > 10000:
|
||||
raise ValueError(f"Text is too long ({len(text)} characters). Maximum is 10,000 characters.")
|
||||
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
logger.info("[audio_gen] Starting audio generation")
|
||||
logger.debug(f"[audio_gen] Text length: {len(text)} characters, voice: {voice_id}")
|
||||
|
||||
# Calculate cost based on character count (every character is 1 token)
|
||||
# Pricing: $0.05 per 1,000 characters
|
||||
character_count = len(text)
|
||||
cost_per_1000_chars = 0.05
|
||||
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import UsageSummary, APIProvider
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Check limits using sync method from pricing service (strict enforcement)
|
||||
# Use AUDIO provider for audio generation
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
tokens_requested=character_count, # Use character count as "tokens" for audio
|
||||
actual_provider_name="wavespeed" # Actual provider is WaveSpeed
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[audio_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
error_detail = {
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': 'wavespeed',
|
||||
'usage_info': usage_info if usage_info else {}
|
||||
}
|
||||
raise HTTPException(status_code=429, detail=error_detail)
|
||||
|
||||
# Get current usage for limit checking
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
logger.error(f"[audio_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Generate audio using WaveSpeed
|
||||
try:
|
||||
# Avoid passing duplicate enable_sync_mode; allow override via kwargs
|
||||
enable_sync_mode = kwargs.pop("enable_sync_mode", True)
|
||||
|
||||
# Filter out None values from kwargs to prevent WaveSpeed validation errors
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
logger.info(f"[audio_gen] Filtered kwargs (removed None values): {filtered_kwargs}")
|
||||
|
||||
client = WaveSpeedClient()
|
||||
audio_bytes = client.generate_speech(
|
||||
text=text,
|
||||
voice_id=voice_id,
|
||||
speed=speed,
|
||||
volume=volume,
|
||||
pitch=pitch,
|
||||
emotion=emotion,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
**filtered_kwargs
|
||||
)
|
||||
|
||||
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[audio_gen] Audio generation API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Audio generation failed",
|
||||
"message": str(api_error)
|
||||
}
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
if audio_bytes:
|
||||
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "audio_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
|
||||
|
||||
# Update audio calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
# Import sqlalchemy.text with alias to avoid shadowing the 'text' parameter
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET audio_calls = :new_calls,
|
||||
audio_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
# Store the text parameter in a local variable before any imports to prevent shadowing
|
||||
text_param = text # Capture function parameter before any potential shadowing
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.AUDIO,
|
||||
endpoint="/audio-generation/wavespeed",
|
||||
method="POST",
|
||||
model_used="minimax/speech-02-hd",
|
||||
tokens_input=character_count,
|
||||
tokens_output=0,
|
||||
tokens_total=character_count,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(text_param.encode("utf-8")), # Use captured parameter
|
||||
response_size=len(audio_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_image_calls = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[audio_gen] ✅ Successfully tracked usage: user {user_id} -> audio -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Audio Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: wavespeed
|
||||
├─ Actual Provider: wavespeed
|
||||
├─ Model: minimax/speech-02-hd
|
||||
├─ Voice: {voice_id}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {audio_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Characters: {character_count}
|
||||
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[audio_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[audio_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return AudioGenerationResult(
|
||||
audio_bytes=audio_bytes,
|
||||
provider="wavespeed",
|
||||
model="minimax/speech-02-hd",
|
||||
voice_id=voice_id,
|
||||
text_length=character_count,
|
||||
file_size=len(audio_bytes),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[audio_gen] Error generating audio: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Audio generation failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
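# Usage sketch (user_id must be a real Clerk user id with remaining audio quota;
# the output container depends on WaveSpeed defaults, mp3 is assumed here):
#
#   result = generate_audio(
#       text="Welcome to ALwrity.",
#       voice_id="Wise_Woman",
#       emotion="happy",
#       user_id="user_abc123",
#   )
#   with open("welcome.mp3", "wb") as f:
#       f.write(result.audio_bytes)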
|
||||
|
||||
190
backend/services/llm_providers/main_image_editing.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
|
||||
|
||||
logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
|
||||
if explicit:
|
||||
return explicit
|
||||
# Default to huggingface for image editing (best support for image-to-image)
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
if provider_name == "huggingface":
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
|
||||
|
||||
def edit_image(
|
||||
input_image_bytes: bytes,
|
||||
prompt: str,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
mask_bytes: Optional[bytes] = None,
|
||||
) -> ImageGenerationResult:
|
||||
"""Edit image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
input_image_bytes: Input image as bytes (PNG/JPEG)
|
||||
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
|
||||
options: Image editing options (provider, model, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
mask_bytes: Optional mask image bytes for selective editing (grayscale, white=edit, black=preserve)
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image bytes and metadata
|
||||
|
||||
Best Practices for Prompts:
|
||||
- Use clear, specific language describing desired changes
|
||||
- Describe what should change and what should remain
|
||||
- Examples: "Turn the cat into a tiger", "Change background to forest",
|
||||
"Make it look like a watercolor painting"
|
||||
|
||||
Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image editing before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Editing] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image editing")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Image Editing] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
|
||||
# Validate input
|
||||
if not input_image_bytes:
|
||||
raise ValueError("input_image_bytes is required")
|
||||
if not prompt or not prompt.strip():
|
||||
raise ValueError("prompt is required for image editing")
|
||||
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
|
||||
|
||||
logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")
|
||||
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
params["guidance_scale"] = opts.get("guidance_scale")
|
||||
if opts.get("steps") is not None:
|
||||
params["num_inference_steps"] = opts.get("steps")
|
||||
if opts.get("seed") is not None:
|
||||
params["seed"] = opts.get("seed")
|
||||
|
||||
try:
|
||||
# Convert input image bytes to PIL Image for validation
|
||||
input_image = Image.open(io.BytesIO(input_image_bytes))
|
||||
width = input_image.width
|
||||
height = input_image.height
|
||||
|
||||
# Convert mask bytes to PIL Image if provided
|
||||
mask_image = None
|
||||
if mask_bytes:
|
||||
try:
|
||||
mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Convert to grayscale
|
||||
# Ensure mask dimensions match input image
|
||||
if mask_image.size != input_image.size:
|
||||
logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
|
||||
mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
|
||||
mask_image = None
|
||||
|
||||
# Use image_to_image method from Hugging Face InferenceClient
|
||||
# This follows the pattern from the Hugging Face documentation
|
||||
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
|
||||
# Note: Mask support depends on the model - some models may ignore it
|
||||
call_params = params.copy()
|
||||
if mask_image:
|
||||
call_params["mask_image"] = mask_image
|
||||
logger.info("[Image Editing] Using mask for selective editing")
|
||||
|
||||
edited_image: Image.Image = client.image_to_image(
|
||||
image=input_image,
|
||||
prompt=prompt.strip(),
|
||||
model=model,
|
||||
**call_params,
|
||||
)
|
||||
|
||||
# Convert edited image back to bytes
|
||||
with io.BytesIO() as buf:
|
||||
edited_image.save(buf, format="PNG")
|
||||
edited_image_bytes = buf.getvalue()
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
height=edited_image.height,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
seed=opts.get("seed"),
|
||||
metadata={
|
||||
"provider": "fal-ai",
|
||||
"operation": "image_editing",
|
||||
"original_width": width,
|
||||
"original_height": height,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Image editing failed: {str(e)}")
|
||||
|
||||
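# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of how the image-editing entry point above might be called,
# following the docstring's prompt guidance. The function name `edit_image` and
# the input file path are assumptions; the parameters (input_image_bytes, prompt,
# mask_bytes, options, user_id) and the returned ImageGenerationResult come from
# the code above. Whether white or black marks the editable region of the mask
# depends on the model, so treat the mask handling as a sketch only.
from io import BytesIO
from PIL import Image, ImageDraw


def _example_edit_with_mask(user_id: str) -> bytes:
    """Edit a local photo, restricting changes to a rectangular region via a mask."""
    with open("photo.png", "rb") as f:  # hypothetical input file
        original = f.read()

    # Build a grayscale mask the same size as the input image.
    img = Image.open(BytesIO(original))
    mask = Image.new("L", img.size, 0)
    ImageDraw.Draw(mask).rectangle([0, 0, img.width // 2, img.height], fill=255)
    buf = BytesIO()
    mask.save(buf, format="PNG")

    result = edit_image(  # assumed name of the function defined above
        input_image_bytes=original,
        prompt="Change the background to a forest, keep the subject unchanged",
        mask_bytes=buf.getvalue(),
        options={"steps": 30, "guidance_scale": 7.5, "seed": 42},
        user_id=user_id,
    )
    return result.image_bytes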
478
backend/services/llm_providers/main_image_generation.py
Normal file
@@ -0,0 +1,478 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.facade")
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
if gpt_provider.startswith("hf"):
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
# Fall back to Hugging Face so a provider path exists when HF is configured
|
||||
return "huggingface"
|
||||
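# Resolution examples for the selection logic above (illustrative):
#   explicit provider argument           -> returned as-is
#   GPT_PROVIDER=gemini                  -> "gemini"
#   GPT_PROVIDER=hf_response_api         -> "huggingface"
#   (unset, STABILITY_API_KEY set)       -> "stability"
#   (unset, WAVESPEED_API_KEY set only)  -> "wavespeed"
#   (nothing configured)                 -> "huggingface" (default path)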
|
||||
|
||||
def _get_provider(provider_name: str):
|
||||
if provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider()
|
||||
if provider_name == "gemini":
|
||||
return GeminiImageProvider()
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider()
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider()
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
prompt: Image generation prompt
|
||||
options: Image generation options (provider, model, width, height, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
image_options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
width=int(opts.get("width", 1024)),
|
||||
height=int(opts.get("height", 1024)),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
model=opts.get("model"),
|
||||
extra=opts,
|
||||
)
|
||||
|
||||
# Normalize obvious model/provider mismatches
|
||||
model_lower = (image_options.model or "").lower()
|
||||
if provider_name == "stability" and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Remapping provider to huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
if provider_name == "huggingface" and not image_options.model:
|
||||
# Provide a sensible default HF model if none specified
|
||||
image_options.model = "black-forest-labs/FLUX.1-Krea-dev"
|
||||
|
||||
if provider_name == "wavespeed" and not image_options.model:
|
||||
# Provide a sensible default WaveSpeed model if none specified
|
||||
image_options.model = "ideogram-v3-turbo"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
result = provider.generate(image_options)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(result.image_bytes) if result else False
|
||||
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
|
||||
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get cost from result metadata or calculate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
else:
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint="/image-generation",
|
||||
method="POST",
|
||||
model_used=result.model or "unknown",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(result.image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider_name}
|
||||
├─ Actual Provider: {provider_name}
|
||||
├─ Model: {result.model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
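# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of calling generate_image() above: subscription failures
# surface as HTTPException from the pre-flight validator before any provider
# call, while provider errors propagate from provider.generate(). The model
# name below matches the Hugging Face default used above; other option values
# are example choices only.
from fastapi import HTTPException


def _example_generate(user_id: str) -> bytes:
    try:
        result = generate_image(
            prompt="A watercolor painting of a lighthouse at dusk",
            options={
                "provider": "huggingface",
                "model": "black-forest-labs/FLUX.1-Krea-dev",
                "width": 1024,
                "height": 1024,
                "steps": 28,
                "seed": 7,
            },
            user_id=user_id,
        )
        return result.image_bytes
    except HTTPException:
        # Pre-flight validation rejected the request (e.g. plan limit reached);
        # let the API layer translate this into an error response.
        raise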
def generate_character_image(
|
||||
prompt: str,
|
||||
reference_image_bytes: bytes,
|
||||
user_id: Optional[str] = None,
|
||||
style: str = "Realistic",
|
||||
aspect_ratio: str = "16:9",
|
||||
rendering_speed: str = "Quality",
|
||||
timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""Generate character-consistent image with pre-flight validation and usage tracking.
|
||||
|
||||
Uses Ideogram Character API via WaveSpeed to maintain character consistency.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the scene/context for the character
|
||||
reference_image_bytes: Reference image bytes (base avatar)
|
||||
user_id: User ID for subscription checking (required)
|
||||
style: Character style type ("Auto", "Fiction", or "Realistic")
|
||||
aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4")
|
||||
rendering_speed: Rendering speed ("Default", "Turbo", "Quality")
|
||||
timeout: Total timeout in seconds for submission + polling (default: 180)
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1,
|
||||
)
|
||||
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
|
||||
# Generate character image via WaveSpeed
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
wavespeed_client = WaveSpeedClient()
|
||||
image_bytes = wavespeed_client.generate_character_image(
|
||||
prompt=prompt,
|
||||
reference_image_bytes=reference_image_bytes,
|
||||
style=style,
|
||||
aspect_ratio=aspect_ratio,
|
||||
rendering_speed=rendering_speed,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(image_bytes)
|
||||
image_bytes_len = len(image_bytes) if image_bytes else 0
|
||||
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
if user_id and image_bytes:
|
||||
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Character image cost (same as ideogram-v3-turbo)
|
||||
estimated_cost = 0.10
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
|
||||
endpoint="/image-generation/character",
|
||||
method="POST",
|
||||
model_used="ideogram-character",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation (Character)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: wavespeed
|
||||
├─ Actual Provider: wavespeed
|
||||
├─ Model: ideogram-character
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
|
||||
|
||||
return image_bytes
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Character Image Generation] Character image generation API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Character image generation failed",
|
||||
"message": str(api_error)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
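# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of generating a character-consistent scene from a stored
# avatar. The avatar path is hypothetical; the parameter names and values come
# from the signature and docstring of generate_character_image() above.
def _example_character_scene(user_id: str) -> bytes:
    with open("avatars/base_avatar.png", "rb") as f:  # hypothetical reference image
        reference = f.read()

    return generate_character_image(
        prompt="The character presenting a product demo on a conference stage",
        reference_image_bytes=reference,
        user_id=user_id,
        style="Realistic",
        aspect_ratio="16:9",
        rendering_speed="Quality",
        timeout=180,
    )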
896
backend/services/llm_providers/main_text_generation.py
Normal file
@@ -0,0 +1,896 @@
|
||||
"""Main Text Generation Service for ALwrity Backend.
|
||||
|
||||
This service provides the main LLM text generation functionality,
|
||||
migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
|
||||
|
||||
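# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of requesting structured output through llm_text_gen()
# defined below. The schema shape follows the providers' JSON-schema
# convention and user_id is a hypothetical Clerk ID; the declared return type
# is str, so callers may need to json.loads() the result depending on the
# underlying provider helper.
def _example_structured_outline(user_id: str) -> str:
    outline_schema = {
        "type": "object",
        "properties": {
            "sections": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "heading": {"type": "string"},
                        "summary": {"type": "string"},
                    },
                },
            }
        },
    }
    return llm_text_gen(
        prompt="Outline a 2000-word article on on-device LLM inference",
        system_prompt="You are a senior technical editor.",
        json_struct=outline_schema,
        user_id=user_id,
    )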
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate text from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
user_id (str): Clerk user ID for subscription checking (required).
|
||||
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
|
||||
|
||||
# Set default values for LLM parameters
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
frequency_penalty = 0.0
|
||||
presence_penalty = 0.0
|
||||
|
||||
# Check for GPT_PROVIDER environment variable
|
||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||
if env_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:groq"
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
blog_demographic = "Professional"
|
||||
blog_type = "Informational"
|
||||
blog_language = "English"
|
||||
blog_output_format = "markdown"
|
||||
blog_length = 2000
|
||||
|
||||
# Check which providers have API keys available using APIKeyManager
|
||||
api_key_manager = APIKeyManager()
|
||||
available_providers = []
|
||||
if api_key_manager.get_api_key("gemini"):
|
||||
available_providers.append("google")
|
||||
if api_key_manager.get_api_key("hf_token"):
|
||||
available_providers.append("huggingface")
|
||||
|
||||
# If no environment variable set, auto-detect based on available keys
|
||||
if not env_provider:
|
||||
# Prefer Google Gemini if available, otherwise use Hugging Face
|
||||
if "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:groq"
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||
else:
|
||||
# Environment variable was set, validate it's supported
|
||||
if gpt_provider not in available_providers:
|
||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
||||
if "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:groq"
|
||||
else:
|
||||
raise RuntimeError("No supported providers available.")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
# Map provider name to APIProvider enum (define at function scope for usage tracking)
|
||||
from models.subscription_models import APIProvider
|
||||
provider_enum = None
|
||||
# Store actual provider name for logging (e.g., "huggingface", "gemini")
|
||||
actual_provider_name = None
|
||||
if gpt_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini" # Use "gemini" for consistency in logs
|
||||
elif gpt_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
|
||||
if not provider_enum:
|
||||
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
|
||||
|
||||
# SUBSCRIPTION CHECK - Required and strict enforcement
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Estimate tokens from prompt (input tokens)
|
||||
# CRITICAL: Use worst-case scenario (input + max_tokens) for validation to prevent abuse
|
||||
# This ensures we block requests that would exceed limits even if response is longer than expected
|
||||
input_tokens = int(len(prompt.split()) * 1.3)
|
||||
# Worst-case estimate: assume maximum possible output tokens (max_tokens if specified)
|
||||
# This prevents abuse where actual response tokens exceed the estimate
|
||||
if max_tokens:
|
||||
estimated_output_tokens = max_tokens # Use maximum allowed output tokens
|
||||
else:
|
||||
# If max_tokens not specified, use conservative estimate (input * 1.5)
|
||||
estimated_output_tokens = int(input_tokens * 1.5)
|
||||
estimated_total_tokens = input_tokens + estimated_output_tokens
|
||||
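# Worked example of the estimate above: a 500-word prompt gives
# input_tokens = int(500 * 1.3) = 650; with the default max_tokens = 4000 the
# validated worst case is 650 + 4000 = 4650 tokens for this request.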
|
||||
# Check limits using sync method from pricing service (strict enforcement)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider_enum,
|
||||
tokens_requested=estimated_total_tokens,
|
||||
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
# Raise HTTPException(429) with usage info so frontend can display subscription modal
|
||||
error_detail = {
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': actual_provider_name or provider_enum.value,
|
||||
'usage_info': usage_info if usage_info else {}
|
||||
}
|
||||
raise HTTPException(status_code=429, detail=error_detail)
|
||||
|
||||
# Get current usage for limit checking only
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# No separate log here - we'll create unified log after API call and usage tracking
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
raise
|
||||
except RuntimeError:
|
||||
# Re-raise subscription limit errors
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
# STRICT: Fail on subscription check errors
|
||||
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
if system_prompt is None:
|
||||
system_instructions = f"""You are a highly skilled content writer with a knack for creating engaging and informative content.
|
||||
Your expertise spans various writing styles and formats.
|
||||
|
||||
Writing Style Guidelines:
|
||||
- Tone: {blog_tone}
|
||||
- Target Audience: {blog_demographic}
|
||||
- Content Type: {blog_type}
|
||||
- Language: {blog_language}
|
||||
- Output Format: {blog_output_format}
|
||||
- Target Length: {blog_length} words
|
||||
|
||||
Please provide responses that are:
|
||||
- Well-structured and easy to read
|
||||
- Engaging and informative
|
||||
- Tailored to the specified tone and audience
|
||||
- Professional yet accessible
|
||||
- Optimized for the target content type
|
||||
"""
|
||||
else:
|
||||
system_instructions = system_prompt
|
||||
|
||||
# Generate response based on provider
|
||||
response_text = None
|
||||
actual_provider_used = gpt_provider
|
||||
try:
|
||||
if gpt_provider == "google":
|
||||
if json_struct:
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=n,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "huggingface":
|
||||
if json_struct:
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
# This ensures we always get the absolute latest committed values, even across different sessions
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
record_count = 0 # Initialize to ensure it's always defined
|
||||
|
||||
# CRITICAL: First check if record exists using COUNT query
|
||||
try:
|
||||
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
|
||||
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
|
||||
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
|
||||
except Exception as count_error:
|
||||
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
|
||||
record_count = 0
|
||||
|
||||
if record_count and record_count > 0:
|
||||
# Record exists - read current values with raw SQL
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection (whitelist approach)
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
raw_calls = result[0] if result[0] is not None else 0
|
||||
raw_tokens = result[1] if result[1] is not None else 0
|
||||
current_calls_before = raw_calls
|
||||
current_tokens_before = raw_tokens
|
||||
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ℹ️ No record exists yet (will create new) - user={user_id}, period={current_period}")
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
# New record - values are already 0, no need to set
|
||||
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
|
||||
# Using raw SQL UPDATE ensures the change is persisted
|
||||
from sqlalchemy import text
|
||||
update_calls_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_calls = :new_calls
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_calls_query, {
|
||||
'new_calls': new_calls,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
|
||||
# The ORM object may have stale values, but raw SQL always has the latest committed values
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
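# Worked example of the cap above: with token_limit = 100_000,
# current_tokens_before = 99_200 and tokens_total = 2_500, the projection
# (101_700) exceeds the limit, so new_tokens is capped at 100_000 and
# tokens_total is reduced to 800 for this call's tracking.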
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
update_tokens_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_tokens = :new_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_tokens_query, {
|
||||
'new_tokens': new_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Determine tracked tokens (after any safety capping)
|
||||
tracked_tokens_input = min(tokens_input, tokens_total)
|
||||
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
|
||||
|
||||
# Calculate and persist cost for this call
|
||||
try:
|
||||
cost_info = pricing.calculate_api_cost(
|
||||
provider=provider_enum,
|
||||
model_name=model,
|
||||
tokens_input=tracked_tokens_input,
|
||||
tokens_output=tracked_tokens_output,
|
||||
request_count=1
|
||||
)
|
||||
cost_total = cost_info.get('cost_total', 0.0) or 0.0
|
||||
except Exception as cost_error:
|
||||
cost_total = 0.0
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
|
||||
|
||||
if cost_total > 0:
|
||||
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
|
||||
update_costs_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_costs_query, {
|
||||
'cost': cost_total,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Keep ORM object in sync for logging/debugging
|
||||
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
|
||||
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost_total
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
|
||||
|
||||
# Update totals using SQL UPDATE
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
new_total_calls = old_total_calls + 1
|
||||
new_total_tokens = old_total_tokens + tokens_total
|
||||
|
||||
update_totals_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET total_calls = :total_calls, total_tokens = :total_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_totals_query, {
|
||||
'total_calls': new_total_calls,
|
||||
'total_tokens': new_total_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
print(debug_msg, flush=True)
|
||||
sys.stdout.flush()
|
||||
logger.debug(f"[llm_text_gen] {debug_msg}")
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
|
||||
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
|
||||
|
||||
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
|
||||
try:
|
||||
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result:
|
||||
verified_calls = verify_result[0] if verify_result[0] is not None else 0
|
||||
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
|
||||
if verified_calls != new_calls or verified_tokens != new_tokens:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
|
||||
# Force another commit attempt
|
||||
db_track.commit()
|
||||
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result2:
|
||||
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
|
||||
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
|
||||
except Exception as verify_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
# DEBUG: Log the actual values being used
|
||||
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
|
||||
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return response_text
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
|
||||
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
|
||||
fallback_providers = ["google", "huggingface"]
|
||||
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
|
||||
|
||||
if fallback_providers:
|
||||
fallback_provider = fallback_providers[0] # Only try the first available
|
||||
try:
|
||||
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
|
||||
actual_provider_used = fallback_provider
|
||||
|
||||
# Update provider enum for fallback
|
||||
if fallback_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini"
|
||||
fallback_model = "gemini-2.0-flash-lite"
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "openai/gpt-oss-120b:groq"
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
response_text = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=n,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
response_text = gemini_text_response(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
                        max_tokens=max_tokens,
                        system_prompt=system_instructions
                    )
            elif fallback_provider == "huggingface":
                if json_struct:
                    response_text = huggingface_structured_json_response(
                        prompt=prompt,
                        schema=json_struct,
                        model="openai/gpt-oss-120b:groq",
                        temperature=temperature,
                        max_tokens=max_tokens,
                        system_prompt=system_instructions
                    )
                else:
                    response_text = huggingface_text_response(
                        prompt=prompt,
                        model="openai/gpt-oss-120b:groq",
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                        system_prompt=system_instructions
                    )

            # TRACK USAGE after successful fallback call
            if response_text:
                logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
                try:
                    db_track = next(get_db())
                    try:
                        # Estimate tokens from prompt and response
                        # Recalculate input tokens from prompt (consistent with pre-flight estimation)
                        tokens_input = int(len(prompt.split()) * 1.3)
                        tokens_output = int(len(str(response_text).split()) * 1.3)
                        tokens_total = tokens_input + tokens_output

                        # Get or create usage summary
                        from models.subscription_models import UsageSummary
                        from services.subscription import PricingService

                        pricing = PricingService(db_track)
                        current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                        # Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
                        provider_name = provider_enum.value
                        limits = pricing.get_user_limits(user_id)
                        token_limit = 0
                        if limits and limits.get('limits'):
                            token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0

                        # CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
                        from sqlalchemy import text
                        current_calls_before = 0
                        current_tokens_before = 0

                        try:
                            # Validate provider_name to prevent SQL injection
                            valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
                            if provider_name not in valid_providers:
                                raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")

                            # Read current values directly from database using raw SQL
                            sql_query = text(f"""
                                SELECT {provider_name}_calls, {provider_name}_tokens
                                FROM usage_summaries
                                WHERE user_id = :user_id AND billing_period = :period
                                LIMIT 1
                            """)
                            result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
                            if result:
                                current_calls_before = result[0] if result[0] is not None else 0
                                current_tokens_before = result[1] if result[1] is not None else 0
                                logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
                        except Exception as sql_error:
                            logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
                            # Fallback to ORM query if raw SQL fails
                            summary = db_track.query(UsageSummary).filter(
                                UsageSummary.user_id == user_id,
                                UsageSummary.billing_period == current_period
                            ).first()
                            if summary:
                                db_track.refresh(summary)
                                current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
                                current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0

                        # Get or create usage summary object (needed for ORM update)
                        summary = db_track.query(UsageSummary).filter(
                            UsageSummary.user_id == user_id,
                            UsageSummary.billing_period == current_period
                        ).first()

                        if not summary:
                            summary = UsageSummary(
                                user_id=user_id,
                                billing_period=current_period
                            )
                            db_track.add(summary)
                            db_track.flush()  # Ensure summary is persisted before updating
                        else:
                            # CRITICAL: Update the ORM object with values from raw SQL query
                            # This ensures the ORM object reflects the actual database state before we increment
                            setattr(summary, f"{provider_name}_calls", current_calls_before)
                            if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
                                setattr(summary, f"{provider_name}_tokens", current_tokens_before)
                            logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")

                        # Get "before" state for unified log (from raw SQL query)
                        logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")

                        # Update provider-specific counters (sync operation)
                        new_calls = current_calls_before + 1
                        setattr(summary, f"{provider_name}_calls", new_calls)

                        # Update token usage for LLM providers with safety check
                        # Use current_tokens_before from raw SQL query (most reliable)
                        if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
                            logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")

                            # SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
                            # This prevents abuse where actual response tokens exceed pre-flight validation estimate
                            projected_new_tokens = current_tokens_before + tokens_total

                            # If limit is set (> 0) and would be exceeded, cap at limit
                            if token_limit > 0 and projected_new_tokens > token_limit:
                                logger.warning(
                                    f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
                                    f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
                                    f"Capping tracked tokens at limit to prevent abuse."
                                )
                                # Cap at limit to prevent abuse
                                new_tokens = token_limit
                                # Adjust tokens_total for accurate total tracking
                                tokens_total = token_limit - current_tokens_before
                                if tokens_total < 0:
                                    tokens_total = 0
                            else:
                                new_tokens = projected_new_tokens

                            setattr(summary, f"{provider_name}_tokens", new_tokens)
                        else:
                            current_tokens_before = 0
                            new_tokens = 0

                        # Determine tracked tokens after any safety capping
                        tracked_tokens_input = min(tokens_input, tokens_total)
                        tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)

                        # Calculate and persist cost for this fallback call
                        cost_total = 0.0
                        try:
                            cost_info = pricing.calculate_api_cost(
                                provider=provider_enum,
                                model_name=fallback_model,
                                tokens_input=tracked_tokens_input,
                                tokens_output=tracked_tokens_output,
                                request_count=1
                            )
                            cost_total = cost_info.get('cost_total', 0.0) or 0.0
                        except Exception as cost_error:
                            logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)

                        if cost_total > 0:
                            update_costs_query = text(f"""
                                UPDATE usage_summaries
                                SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
                                    total_cost = COALESCE(total_cost, 0) + :cost
                                WHERE user_id = :user_id AND billing_period = :period
                            """)
                            db_track.execute(update_costs_query, {
                                'cost': cost_total,
                                'user_id': user_id,
                                'period': current_period
                            })
                            setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
                            summary.total_cost = (summary.total_cost or 0.0) + cost_total

                        # Update totals (using potentially capped tokens_total from safety check)
                        summary.total_calls = (summary.total_calls or 0) + 1
                        summary.total_tokens = (summary.total_tokens or 0) + tokens_total

                        # Get plan details for unified log (limits already retrieved above)
                        plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
                        tier = limits.get('tier', 'unknown') if limits else 'unknown'
                        call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0

                        # Get image stats for unified log
                        current_images_before = getattr(summary, "stability_calls", 0) or 0
                        image_limit = limits['limits'].get("stability_calls", 0) if limits else 0

                        # Get image editing stats for unified log
                        current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
                        image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0

                        # Get video stats for unified log
                        current_video_calls = getattr(summary, "video_calls", 0) or 0
                        video_limit = limits['limits'].get("video_calls", 0) if limits else 0

                        # Get audio stats for unified log
                        current_audio_calls = getattr(summary, "audio_calls", 0) or 0
                        audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
                        # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
                        audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'

                        # CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
                        db_track.flush()  # Flush to ensure changes are in DB (not just in transaction)
                        db_track.commit()  # Commit transaction to make changes visible to other sessions
                        logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")

                        # UNIFIED SUBSCRIPTION LOG for fallback
                        # Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
                        # Include image stats in the log
                        print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
                    except Exception as track_error:
                        logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
                        db_track.rollback()
                    finally:
                        db_track.close()
                except Exception as usage_error:
                    logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)

            return response_text
        except Exception as fallback_error:
            logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")

            # CIRCUIT BREAKER: Stop immediately to prevent expensive API calls
            logger.error("[llm_text_gen] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
            raise RuntimeError("All LLM providers failed to generate a response.")

    except Exception as e:
        logger.error(f"[llm_text_gen] Error during text generation: {str(e)}")
        raise


def check_gpt_provider(gpt_provider: str) -> bool:
    """Check if the specified GPT provider is supported."""
    supported_providers = ["google", "huggingface"]
    return gpt_provider in supported_providers


def get_api_key(gpt_provider: str) -> Optional[str]:
    """Get API key for the specified provider."""
    try:
        api_key_manager = APIKeyManager()
        provider_mapping = {
            "google": "gemini",
            "huggingface": "hf_token"
        }

        mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
        return api_key_manager.get_api_key(mapped_provider)
    except Exception as e:
        logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
        return None
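
# Usage sketch (illustrative only, not part of the original module): how the two
# helpers above might be combined to resolve a provider and its API key before a
# text-generation call. The "google" value and the error message are assumptions
# for the example.
#
#     provider = "google"
#     if check_gpt_provider(provider):
#         api_key = get_api_key(provider)
#         if api_key is None:
#             raise RuntimeError(f"No API key configured for {provider}")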
792
backend/services/llm_providers/main_video_generation.py
Normal file
792
backend/services/llm_providers/main_video_generation.py
Normal file
@@ -0,0 +1,792 @@
"""
Main Video Generation Service

Provides a unified interface for AI video generation providers.
Supports:
- Text-to-video: Hugging Face Inference Providers, WaveSpeed models
- Image-to-video: WaveSpeed WAN 2.5, Kandinsky 5 Pro
Stubs included for Gemini (Veo 3) and OpenAI (Sora) for future use.
"""
from __future__ import annotations

import os
import base64
import io
import sys
import asyncio
from typing import Any, Dict, Optional, Union, Callable

from fastapi import HTTPException

try:
    from huggingface_hub import InferenceClient
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    InferenceClient = None

from ..onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from utils.logger_utils import get_service_logger

logger = get_service_logger("video_generation_service")


class VideoProviderNotImplemented(Exception):
    pass


def _get_api_key(provider: str) -> Optional[str]:
    try:
        manager = APIKeyManager()
        mapping = {
            "huggingface": "hf_token",
            "wavespeed": "wavespeed",  # WaveSpeed API key
            "gemini": "gemini",  # placeholder for Veo 3
            "openai": "openai_api_key",  # placeholder for Sora
        }
        return manager.get_api_key(mapping.get(provider, provider))
    except Exception as e:
        logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
        return None


def _coerce_video_bytes(output: Any) -> bytes:
    """
    Normalizes the different return shapes that huggingface_hub may emit for video tasks.
    According to HF docs, text_to_video() should return bytes directly.
    """
    logger.debug(f"[video_gen] _coerce_video_bytes received type: {type(output)}")

    # Most common case: bytes directly
    if isinstance(output, (bytes, bytearray, memoryview)):
        logger.debug(f"[video_gen] Output is bytes: {len(output)} bytes")
        return bytes(output)

    # Handle file-like objects
    if hasattr(output, "read"):
        logger.debug("[video_gen] Output has read() method, reading...")
        data = output.read()
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        raise TypeError(f"File-like object returned non-bytes: {type(data)}")

    # Objects with direct attribute access
    if hasattr(output, "video"):
        logger.debug("[video_gen] Output has 'video' attribute")
        data = getattr(output, "video")
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        if hasattr(data, "read"):
            return bytes(data.read())

    if hasattr(output, "bytes"):
        logger.debug("[video_gen] Output has 'bytes' attribute")
        data = getattr(output, "bytes")
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        if hasattr(data, "read"):
            return bytes(data.read())

    # Dict handling - but this shouldn't happen with text_to_video()
    if isinstance(output, dict):
        logger.warning(f"[video_gen] Received dict output (unexpected): keys={list(output.keys())}")
        # Try to get video key safely - use .get() to avoid KeyError
        data = output.get("video")
        if data is not None:
            if isinstance(data, (bytes, bytearray, memoryview)):
                return bytes(data)
            if hasattr(data, "read"):
                return bytes(data.read())
        # Try other common keys
        for key in ["data", "content", "file", "result", "output"]:
            data = output.get(key)
            if data is not None:
                if isinstance(data, (bytes, bytearray, memoryview)):
                    return bytes(data)
                if hasattr(data, "read"):
                    return bytes(data.read())
        raise TypeError(f"Dict output has no recognized video key. Keys: {list(output.keys())}")

    # String handling (base64)
    if isinstance(output, str):
        logger.debug("[video_gen] Output is string, attempting base64 decode")
        if output.startswith("data:"):
            _, encoded = output.split(",", 1)
            return base64.b64decode(encoded)
        try:
            return base64.b64decode(output)
        except Exception as exc:
            raise TypeError(f"Unable to decode string video payload: {exc}") from exc

    # Fallback: try to use output directly
    logger.warning(f"[video_gen] Unexpected output type: {type(output)}, attempting direct conversion")
    try:
        if hasattr(output, "__bytes__"):
            return bytes(output)
    except Exception:
        pass

    raise TypeError(f"Unsupported video payload type: {type(output)}. Output: {str(output)[:200]}")


def _generate_with_huggingface(
    prompt: str,
    num_frames: int = 24 * 4,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 30,
    negative_prompt: Optional[str] = None,
    seed: Optional[int] = None,
    model: str = "tencent/HunyuanVideo",
) -> bytes:
    """
    Generates video bytes using Hugging Face's InferenceClient.
    """
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")

    token = _get_api_key("huggingface")
    if not token:
        raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")

    client = InferenceClient(
        provider="fal-ai",
        token=token,
    )
    logger.info("[video_gen] Using HuggingFace provider 'fal-ai'")

    params: Dict[str, Any] = {
        "num_frames": num_frames,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    if negative_prompt:
        params["negative_prompt"] = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
    if seed is not None:
        params["seed"] = seed

    logger.info(
        "[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=text-to-video",
        model,
        num_frames,
        num_inference_steps,
    )

    try:
        logger.info("[video_gen] Calling client.text_to_video()...")
        video_output = client.text_to_video(
            prompt=prompt,
            model=model,
            **params,
        )

        logger.info(f"[video_gen] text_to_video() returned type: {type(video_output)}")
        if isinstance(video_output, dict):
            logger.info(f"[video_gen] Dict keys: {list(video_output.keys())}")
        elif hasattr(video_output, "__dict__"):
            logger.info(f"[video_gen] Object attributes: {dir(video_output)}")

        video_bytes = _coerce_video_bytes(video_output)

        if not isinstance(video_bytes, bytes):
            raise TypeError(f"Expected bytes from text_to_video, got {type(video_bytes)}")

        if len(video_bytes) == 0:
            raise ValueError("Received empty video bytes from Hugging Face API")

        logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
        return video_bytes

    except KeyError as e:
        error_msg = str(e)
        logger.error(f"[video_gen] HF KeyError: {error_msg}", exc_info=True)
        logger.error(f"[video_gen] This suggests the API response format is unexpected. Check logs above for response type.")
        raise HTTPException(status_code=502, detail={
            "error": f"Hugging Face API returned unexpected response format: {error_msg}",
            "error_type": "KeyError",
            "hint": "The API response may have changed. Check server logs for details."
        })
    except Exception as e:
        error_msg = str(e)
        error_type = type(e).__name__
        logger.error(f"[video_gen] HF error ({error_type}): {error_msg}", exc_info=True)
        raise HTTPException(status_code=502, detail={
            "error": f"Hugging Face video generation failed: {error_msg}",
            "error_type": error_type
        })


async def _generate_image_to_video_wavespeed(
    image_data: Optional[bytes] = None,
    image_base64: Optional[str] = None,
    prompt: str = "",
    duration: int = 5,
    resolution: str = "720p",
    model: str = "alibaba/wan-2.5/image-to-video",
    negative_prompt: Optional[str] = None,
    seed: Optional[int] = None,
    audio_base64: Optional[str] = None,
    enable_prompt_expansion: bool = True,
    progress_callback: Optional[Callable[[float, str], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Generate video from image using WaveSpeed (WAN 2.5 or Kandinsky 5 Pro).

    Args:
        image_data: Image bytes (required if image_base64 not provided)
        image_base64: Image in base64 or data URI format (required if image_data not provided)
        prompt: Text prompt describing the video motion
        duration: Video duration in seconds (5 or 10)
        resolution: Output resolution (480p, 720p, 1080p)
        model: Model to use (alibaba/wan-2.5/image-to-video, wavespeed/kandinsky5-pro/image-to-video)
        negative_prompt: Optional negative prompt
        seed: Optional random seed
        audio_base64: Optional audio file for synchronization
        enable_prompt_expansion: Enable prompt optimization

    Returns:
        Dictionary with video_bytes and metadata (cost, duration, resolution, width, height, etc.)
    """
    # Import here to avoid circular dependencies
    from services.image_studio.wan25_service import WAN25Service

    logger.info(f"[video_gen] WaveSpeed image-to-video: model={model}, resolution={resolution}, duration={duration}s")

    # Validate inputs
    if not image_data and not image_base64:
        raise ValueError("Either image_data or image_base64 must be provided for image-to-video")

    # Convert image_data to base64 if needed
    if image_data and not image_base64:
        image_base64 = base64.b64encode(image_data).decode('utf-8')
        # Add data URI prefix if not present
        if not image_base64.startswith("data:"):
            image_base64 = f"data:image/png;base64,{image_base64}"

    # Initialize WAN25Service (handles both WAN 2.5 and Kandinsky 5 Pro)
    wan25_service = WAN25Service()

    try:
        # Generate video using WAN25Service (returns full metadata)
        result = await wan25_service.generate_video(
            image_base64=image_base64,
            prompt=prompt,
            audio_base64=audio_base64,
            resolution=resolution,
            duration=duration,
            negative_prompt=negative_prompt,
            seed=seed,
            enable_prompt_expansion=enable_prompt_expansion,
            progress_callback=progress_callback,
        )

        video_bytes = result.get("video_bytes")
        if not video_bytes:
            raise ValueError("WAN25Service returned no video bytes")

        if not isinstance(video_bytes, bytes):
            raise TypeError(f"Expected bytes from WAN25Service, got {type(video_bytes)}")

        if len(video_bytes) == 0:
            raise ValueError("Received empty video bytes from WaveSpeed API")

        logger.info(f"[video_gen] Successfully generated image-to-video: {len(video_bytes)} bytes")

        # Return video bytes with metadata
        return {
            "video_bytes": video_bytes,
            "prompt": result.get("prompt", prompt),
            "duration": result.get("duration", float(duration)),
            "model_name": result.get("model_name", model),
            "cost": result.get("cost", 0.0),
            "provider": result.get("provider", "wavespeed"),
            "resolution": result.get("resolution", resolution),
            "width": result.get("width", 1280),
            "height": result.get("height", 720),
            "metadata": result.get("metadata", {}),
            "source_video_url": result.get("source_video_url"),
            "prediction_id": result.get("prediction_id"),
        }

    except HTTPException:
        # Re-raise HTTPExceptions from WAN25Service
        raise
    except Exception as e:
        error_msg = str(e)
        error_type = type(e).__name__
        logger.error(f"[video_gen] WaveSpeed image-to-video error ({error_type}): {error_msg}", exc_info=True)
        raise HTTPException(
            status_code=502,
            detail={
                "error": f"WaveSpeed image-to-video generation failed: {error_msg}",
                "error_type": error_type
            }
        )


def _generate_with_gemini(prompt: str, **kwargs) -> bytes:
    raise VideoProviderNotImplemented("Gemini Veo 3 integration coming soon.")


def _generate_with_openai(prompt: str, **kwargs) -> bytes:
    raise VideoProviderNotImplemented("OpenAI Sora integration coming soon.")


async def _generate_text_to_video_wavespeed(
    prompt: str,
    duration: int = 5,
    resolution: str = "720p",
    model: str = "hunyuan-video-1.5",
    negative_prompt: Optional[str] = None,
    seed: Optional[int] = None,
    audio_base64: Optional[str] = None,
    enable_prompt_expansion: bool = True,
    progress_callback: Optional[Callable[[float, str], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Generate text-to-video using WaveSpeed models.

    Args:
        prompt: Text prompt describing the video
        duration: Video duration in seconds
        resolution: Output resolution (480p, 720p)
        model: Model identifier (e.g., "hunyuan-video-1.5")
        negative_prompt: Optional negative prompt
        seed: Optional random seed
        audio_base64: Optional audio (not supported by all models)
        enable_prompt_expansion: Enable prompt optimization (not supported by all models)
        progress_callback: Optional progress callback function
        **kwargs: Additional model-specific parameters

    Returns:
        Dictionary with video_bytes, prompt, duration, model_name, cost, etc.
    """
    from .video_generation.wavespeed_provider import get_wavespeed_text_to_video_service

    logger.info(f"[video_gen] WaveSpeed text-to-video: model={model}, resolution={resolution}, duration={duration}s")

    # Get the appropriate service for the model
    try:
        service = get_wavespeed_text_to_video_service(model)
    except ValueError as e:
        logger.error(f"[video_gen] Unsupported WaveSpeed text-to-video model: {model}")
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )

    # Generate video using the service
    try:
        result = await service.generate_video(
            prompt=prompt,
            duration=duration,
            resolution=resolution,
            negative_prompt=negative_prompt,
            seed=seed,
            audio_base64=audio_base64,
            enable_prompt_expansion=enable_prompt_expansion,
            progress_callback=progress_callback,
            **kwargs
        )

        logger.info(f"[video_gen] Successfully generated text-to-video: {len(result.get('video_bytes', b''))} bytes")
        return result

    except HTTPException:
        # Re-raise HTTPExceptions from service
        raise
    except Exception as e:
        error_msg = str(e)
        error_type = type(e).__name__
        logger.error(f"[video_gen] WaveSpeed text-to-video error ({error_type}): {error_msg}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": f"WaveSpeed text-to-video generation failed: {error_msg}",
                "type": error_type,
            }
        )


async def ai_video_generate(
    prompt: Optional[str] = None,
    image_data: Optional[bytes] = None,
    image_base64: Optional[str] = None,
    operation_type: str = "text-to-video",
    provider: str = "huggingface",
    user_id: Optional[str] = None,
    progress_callback: Optional[Callable[[float, str], None]] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Unified video generation entry point for ALL video operations.

    Supports:
    - text-to-video: prompt required, provider: 'huggingface', 'wavespeed', 'gemini' (stub), 'openai' (stub)
    - image-to-video: image_data or image_base64 required, provider: 'wavespeed'

    Args:
        prompt: Text prompt (required for text-to-video)
        image_data: Image bytes (required for image-to-video if image_base64 not provided)
        image_base64: Image base64 string (required for image-to-video if image_data not provided)
        operation_type: "text-to-video" or "image-to-video" (default: "text-to-video")
        provider: Provider name (default: "huggingface" for text-to-video, "wavespeed" for image-to-video)
        user_id: Required for subscription/usage tracking
        progress_callback: Optional function(progress: float, message: str) -> None
            Called at key stages: submission (10%), polling (20-80%), completion (100%)
        **kwargs: Model-specific parameters:
            - For text-to-video: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
            - For image-to-video: duration, resolution, negative_prompt, seed, audio_base64, enable_prompt_expansion, model

    Returns:
        Dictionary with:
        - video_bytes: Raw video bytes (mp4/webm depending on provider)
        - prompt: The prompt used (may be enhanced)
        - duration: Video duration in seconds
        - model_name: Model used for generation
        - cost: Cost of generation
        - provider: Provider name
        - resolution: Video resolution (for image-to-video)
        - width: Video width in pixels (for image-to-video)
        - height: Video height in pixels (for image-to-video)
        - metadata: Additional metadata dict
    """
    logger.info(f"[video_gen] operation={operation_type}, provider={provider}")

    # Enforce authentication usage like text gen does
    if not user_id:
        raise RuntimeError("user_id is required for subscription/usage tracking.")

    # Validate operation type and required inputs
    if operation_type == "text-to-video":
        if not prompt:
            raise ValueError("prompt is required for text-to-video generation")
        # Set default provider if not specified
        if provider == "huggingface" and "model" not in kwargs:
            kwargs.setdefault("model", "tencent/HunyuanVideo")
    elif operation_type == "image-to-video":
        if not image_data and not image_base64:
            raise ValueError("image_data or image_base64 is required for image-to-video generation")
        # Set default provider and model for image-to-video
        if provider not in ["wavespeed"]:
            logger.warning(f"[video_gen] Provider {provider} not supported for image-to-video, defaulting to wavespeed")
            provider = "wavespeed"
        if "model" not in kwargs:
            kwargs.setdefault("model", "alibaba/wan-2.5/image-to-video")
        # Set defaults for image-to-video
        kwargs.setdefault("duration", 5)
        kwargs.setdefault("resolution", "720p")
    else:
        raise ValueError(f"Invalid operation_type: {operation_type}. Must be 'text-to-video' or 'image-to-video'")

    # PRE-FLIGHT VALIDATION: Validate video generation before API call
    # MUST happen BEFORE any API calls - return immediately if validation fails
    from services.database import get_db
    from services.subscription.preflight_validator import validate_video_generation_operations
    from fastapi import HTTPException

    db = next(get_db())
    try:
        pricing_service = PricingService(db)
        # Raises HTTPException immediately if validation fails - frontend gets immediate response
        validate_video_generation_operations(
            pricing_service=pricing_service,
            user_id=user_id
        )
    except HTTPException:
        # Re-raise immediately - don't proceed with API call
        logger.error(f"[Video Generation] ❌ Pre-flight validation failed - blocking API call")
        raise
    finally:
        db.close()

    logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")

    # Progress callback: Initial submission
    if progress_callback:
        progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")

    # Generate video based on operation type
    model_name = kwargs.get("model", _get_default_model(operation_type, provider))
    try:
        if operation_type == "text-to-video":
            if provider == "huggingface":
                video_bytes = _generate_with_huggingface(
                    prompt=prompt,
                    **kwargs,
                )
                # For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
                result_dict = {
                    "video_bytes": video_bytes,
                    "prompt": prompt,
                    "duration": kwargs.get("duration", 5.0),
                    "model_name": model_name,
                    "cost": 0.10,  # Default cost, will be calculated in track_video_usage
                    "provider": provider,
                    "resolution": kwargs.get("resolution", "720p"),
                    "width": 1280,  # Default, actual may vary
                    "height": 720,  # Default, actual may vary
                    "metadata": {},
                }
            elif provider == "wavespeed":
                # WaveSpeed text-to-video - use unified service
                result_dict = await _generate_text_to_video_wavespeed(
                    prompt=prompt,
                    progress_callback=progress_callback,
                    **kwargs,
                )
            elif provider == "gemini":
                video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
                result_dict = {
                    "video_bytes": video_bytes,
                    "prompt": prompt,
                    "duration": kwargs.get("duration", 5.0),
                    "model_name": model_name,
                    "cost": 0.10,
                    "provider": provider,
                    "resolution": kwargs.get("resolution", "720p"),
                    "width": 1280,
                    "height": 720,
                    "metadata": {},
                }
            elif provider == "openai":
                video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
                result_dict = {
                    "video_bytes": video_bytes,
                    "prompt": prompt,
                    "duration": kwargs.get("duration", 5.0),
                    "model_name": model_name,
                    "cost": 0.10,
                    "provider": provider,
                    "resolution": kwargs.get("resolution", "720p"),
                    "width": 1280,
                    "height": 720,
                    "metadata": {},
                }
            else:
                raise RuntimeError(f"Unknown provider for text-to-video: {provider}")

        elif operation_type == "image-to-video":
            if provider == "wavespeed":
                # Progress callback: Starting generation
                if progress_callback:
                    progress_callback(20.0, "Video generation in progress...")

                # Run the async WaveSpeed helper safely whether or not an event loop is already running
                # (ai_video_generate is itself async, so a running loop is the common case; use a worker thread then)
                try:
                    loop = asyncio.get_event_loop()
                    if loop.is_running():
                        # We're in an async context - use ThreadPoolExecutor to run in new event loop
                        import concurrent.futures
                        with concurrent.futures.ThreadPoolExecutor() as executor:
                            future = executor.submit(
                                asyncio.run,
                                _generate_image_to_video_wavespeed(
                                    image_data=image_data,
                                    image_base64=image_base64,
                                    prompt=prompt or kwargs.get("prompt", ""),
                                    progress_callback=progress_callback,
                                    **kwargs
                                )
                            )
                            result_dict = future.result()
                    else:
                        # Event loop exists but not running - use it
                        result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
                            image_data=image_data,
                            image_base64=image_base64,
                            prompt=prompt or kwargs.get("prompt", ""),
                            progress_callback=progress_callback,
                            **kwargs
                        ))
                except RuntimeError:
                    # No event loop exists, create a new one
                    result_dict = asyncio.run(_generate_image_to_video_wavespeed(
                        image_data=image_data,
                        image_base64=image_base64,
                        prompt=prompt or kwargs.get("prompt", ""),
                        progress_callback=progress_callback,
                        **kwargs
                    ))
                video_bytes = result_dict["video_bytes"]
                model_name = result_dict.get("model_name", model_name)

                # Progress callback: Processing result
                if progress_callback:
                    progress_callback(90.0, "Processing video result...")
            else:
                raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")

        # Track usage (same pattern as text generation)
        # Use cost from result_dict if available, otherwise calculate
        cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
        track_video_usage(
            user_id=user_id,
            provider=provider,
            model_name=model_name,
            prompt=result_dict.get("prompt", prompt or ""),
            video_bytes=video_bytes,
            cost_override=cost_override,
        )

        # Progress callback: Complete
        if progress_callback:
            progress_callback(100.0, "Video generation complete!")

        return result_dict

    except HTTPException:
        # Re-raise HTTPExceptions (e.g., from validation or API errors)
        raise
    except Exception as e:
        logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail={"error": str(e)})

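# Usage sketch (illustrative, not part of the original file): calling the unified
# entry point from an async route. The prompt, user_id, and output path below are
# placeholders for the example.
#
#     async def demo():
#         result = await ai_video_generate(
#             prompt="A drone shot over a coastline at sunset",
#             operation_type="text-to-video",
#             provider="huggingface",
#             user_id="user-123",
#         )
#         with open("demo.mp4", "wb") as f:
#             f.write(result["video_bytes"])
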
def _get_default_model(operation_type: str, provider: str) -> str:
    """Get default model for operation type and provider."""
    defaults = {
        ("text-to-video", "huggingface"): "tencent/HunyuanVideo",
        ("text-to-video", "wavespeed"): "hunyuan-video-1.5",
        ("image-to-video", "wavespeed"): "alibaba/wan-2.5/image-to-video",
    }
    return defaults.get((operation_type, provider), "hunyuan-video-1.5")


def track_video_usage(
    *,
    user_id: str,
    provider: str,
    model_name: str,
    prompt: str,
    video_bytes: bytes,
    cost_override: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Track subscription usage for any video generation (text-to-video or image-to-video).
    """
    from datetime import datetime

    from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
    from services.database import get_db

    db_track = next(get_db())
    try:
        logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
        pricing_service_track = PricingService(db_track)
        current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
        logger.debug(f"[video_gen] Billing period: {current_period}")

        usage_summary = (
            db_track.query(UsageSummary)
            .filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period,
            )
            .first()
        )

        if not usage_summary:
            logger.debug(f"[video_gen] Creating new UsageSummary for user={user_id}, period={current_period}")
            usage_summary = UsageSummary(
                user_id=user_id,
                billing_period=current_period,
            )
            db_track.add(usage_summary)
            db_track.commit()
            db_track.refresh(usage_summary)
        else:
            logger.debug(f"[video_gen] Found existing UsageSummary: video_calls={getattr(usage_summary, 'video_calls', 0)}")

        cost_info = pricing_service_track.get_pricing_for_provider_model(
            APIProvider.VIDEO,
            model_name,
        )
        default_cost = 0.10
        if cost_info and cost_info.get("cost_per_request") is not None:
            default_cost = cost_info["cost_per_request"]
        cost_per_video = cost_override if cost_override is not None else default_cost
        logger.debug(f"[video_gen] Cost per video: ${cost_per_video} (override={cost_override}, default={default_cost})")

        current_video_calls_before = getattr(usage_summary, "video_calls", 0) or 0
        current_video_cost = getattr(usage_summary, "video_cost", 0.0) or 0.0
        usage_summary.video_calls = current_video_calls_before + 1
        usage_summary.video_cost = current_video_cost + cost_per_video
        usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
        usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
        # Ensure the object is in the session
        db_track.add(usage_summary)
        logger.debug(f"[video_gen] Updated usage_summary: video_calls={current_video_calls_before} → {usage_summary.video_calls}")

        limits = pricing_service_track.get_user_limits(user_id)
        plan_name = limits.get("plan_name", "unknown") if limits else "unknown"
        tier = limits.get("tier", "unknown") if limits else "unknown"
        video_limit = limits["limits"].get("video_calls", 0) if limits else 0
        current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
        image_limit = limits["limits"].get("stability_calls", 0) if limits else 0
        current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
        image_edit_limit = limits["limits"].get("image_edit_calls", 0) if limits else 0
        current_audio_calls = getattr(usage_summary, "audio_calls", 0) or 0
        audio_limit = limits["limits"].get("audio_calls", 0) if limits else 0
        # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
        audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'

        usage_log = APIUsageLog(
            user_id=user_id,
            provider=APIProvider.VIDEO,
            endpoint=f"/video-generation/{provider}",
            method="POST",
            model_used=model_name,
            tokens_input=0,
            tokens_output=0,
            tokens_total=0,
            cost_input=0.0,
            cost_output=0.0,
            cost_total=cost_per_video,
            response_time=0.0,
            status_code=200,
            request_size=len((prompt or "").encode("utf-8")),
            response_size=len(video_bytes),
            billing_period=current_period,
        )
        db_track.add(usage_log)
        logger.debug(f"[video_gen] Flushing changes before commit...")
        db_track.flush()
        logger.debug(f"[video_gen] Committing usage tracking changes...")
        db_track.commit()
        db_track.refresh(usage_summary)
        logger.debug(f"[video_gen] Commit successful. Final video_calls: {usage_summary.video_calls}, video_cost: {usage_summary.video_cost}")

        video_limit_display = video_limit if video_limit > 0 else '∞'

        log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before} → {usage_summary.video_calls} / {video_limit_display}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
"""
        logger.info(log_message)
        return {
            "previous_calls": current_video_calls_before,
            "current_calls": usage_summary.video_calls,
            "video_limit": video_limit,
            "video_limit_display": video_limit_display,
            "cost_per_video": cost_per_video,
            "total_video_cost": usage_summary.video_cost,
        }
    except Exception as track_error:
        logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
        logger.error(f"[video_gen] Exception type: {type(track_error).__name__}", exc_info=True)
        db_track.rollback()
    finally:
        db_track.close()

10
backend/services/llm_providers/video_generation/__init__.py
Normal file
10
backend/services/llm_providers/video_generation/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""
Video Generation Services

Modular services for text-to-video and image-to-video generation.
Each provider/model has its own service class for separation of concerns.
"""

from typing import Optional, Dict, Any

__all__ = []
53
backend/services/llm_providers/video_generation/base.py
Normal file
53
backend/services/llm_providers/video_generation/base.py
Normal file
@@ -0,0 +1,53 @@
"""
Base classes and interfaces for video generation services.

Provides common interfaces and data structures for video generation providers.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol, Callable


@dataclass
class VideoGenerationOptions:
    """Options for video generation."""
    prompt: str
    duration: int = 5
    resolution: str = "720p"
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
    audio_base64: Optional[str] = None
    enable_prompt_expansion: bool = True
    model: Optional[str] = None
    extra: Optional[Dict[str, Any]] = None


@dataclass
class VideoGenerationResult:
    """Result from video generation."""
    video_bytes: bytes
    prompt: str
    duration: float
    model_name: str
    cost: float
    provider: str
    resolution: str
    width: int
    height: int
    metadata: Dict[str, Any]
    source_video_url: Optional[str] = None
    prediction_id: Optional[str] = None


class VideoGenerationProvider(Protocol):
    """Protocol for video generation providers."""

    async def generate_video(
        self,
        options: VideoGenerationOptions,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> VideoGenerationResult:
        """Generate video with given options."""
        ...
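# Illustrative sketch (not part of the original file): a minimal class that would
# satisfy the VideoGenerationProvider protocol above. The empty byte payload and
# the "dummy" names are placeholders, not a real provider.
#
#     class DummyProvider:
#         async def generate_video(self, options, progress_callback=None):
#             if progress_callback:
#                 progress_callback(100.0, "done")
#             return VideoGenerationResult(
#                 video_bytes=b"", prompt=options.prompt, duration=float(options.duration),
#                 model_name=options.model or "dummy", cost=0.0, provider="dummy",
#                 resolution=options.resolution, width=1280, height=720, metadata={},
#             )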
File diff suppressed because it is too large