ALwrity Version 0.5.0 (FastAPI + React)
This commit is contained in:
339
backend/llm_providers/gemini_provider.py
Normal file
339
backend/llm_providers/gemini_provider.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# Using Gemini Pro LLM model
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
import time
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path('../../../.env'))
|
||||
from loguru import logger
|
||||
logger.remove()
|
||||
logger.add(sys.stdout,
|
||||
colorize=True,
|
||||
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
# Configure standard logging
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='[%(asctime)s-%(levelname)s-%(module)s-%(lineno)d]- %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_gemini_api_key():
    """Return the Gemini API key, preferring the API key manager over the environment.

    The key manager is asked for the "gemini" key first. If that lookup
    fails for any reason, or yields an empty value, the ``GEMINI_API_KEY``
    environment variable is used instead.

    Returns:
        The first non-empty key found.

    Raises:
        ValueError: If neither the key manager nor the environment provides
            a non-empty key.
    """
    managed_key = None
    try:
        # Imported inside the try block so that an import failure is handled
        # like any other manager failure and falls through to the environment.
        from services.api_key_manager import get_api_key_manager

        managed_key = get_api_key_manager().get_api_key("gemini")
    except Exception as exc:
        logger.warning(f"Could not get API key from manager: {exc}")

    if managed_key:
        return managed_key

    # Fallback to environment variable
    env_key = os.getenv('GEMINI_API_KEY')
    if env_key:
        return env_key

    raise ValueError("Gemini API key not found in environment variables or API key manager")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_text_response(prompt, temperature=0.7, top_p=0.9, n=40, max_tokens=2048, system_prompt=None):
    """Generate text for ``prompt`` with the "gemini-2.0-flash-001" model.

    Args:
        prompt: Passed to the SDK as ``contents``.
        temperature: Forwarded to the generation config as ``temperature``.
        top_p: Forwarded to the generation config as ``top_p``.
        n: Forwarded to the generation config as ``top_k`` (despite the
            name, it is not a candidate count).
        max_tokens: Forwarded to the generation config as ``max_output_tokens``.
        system_prompt: When truthy, forwarded as ``system_instruction``;
            otherwise no system instruction is set.

    Returns:
        ``response.text`` as produced by the SDK.

    Raises:
        Exception: Any failure (key lookup, SDK call, response handling) is
            logged and re-raised so the ``@retry`` decorator can act on it.
    """
    try:
        api_key = get_gemini_api_key()

        logger.info(f"Temp: {temperature}, MaxTokens: {max_tokens}, TopP: {top_p}, N: {n}")

        client = genai.Client(api_key=api_key)

        # Shared sampling settings; the system instruction is only attached
        # when the caller actually supplied one.
        config_kwargs = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": n,
            "max_output_tokens": max_tokens,
        }
        if system_prompt:
            config_kwargs["system_instruction"] = system_prompt

        response = client.models.generate_content(
            model="gemini-2.0-flash-001",
            contents=prompt,
            config=types.GenerateContentConfig(**config_kwargs),
        )

        logger.info(f"[gemini_text_response] Generated response with {len(response.text)} characters")
        return response.text

    except Exception as err:
        logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
        raise
|
||||
|
||||
def _clean_schema_for_gemini(schema):
|
||||
"""Clean schema to remove unsupported properties for Gemini API."""
|
||||
if isinstance(schema, dict):
|
||||
# Remove unsupported properties
|
||||
unsupported_props = ['additionalProperties', 'pattern', 'format', 'minLength', 'maxLength']
|
||||
cleaned = {}
|
||||
|
||||
for key, value in schema.items():
|
||||
if key not in unsupported_props:
|
||||
if isinstance(value, dict):
|
||||
cleaned_value = _clean_schema_for_gemini(value)
|
||||
# Skip empty objects or objects with empty properties
|
||||
if key == "properties" and not cleaned_value:
|
||||
continue
|
||||
if key == "properties" and isinstance(cleaned_value, dict):
|
||||
# Remove any properties that have empty object definitions
|
||||
non_empty_props = {}
|
||||
for prop_key, prop_value in cleaned_value.items():
|
||||
if isinstance(prop_value, dict):
|
||||
if prop_value.get("type") == "object":
|
||||
# If it's an object type, ensure it has properties or change to string
|
||||
if not prop_value.get("properties"):
|
||||
non_empty_props[prop_key] = {"type": "string"}
|
||||
else:
|
||||
non_empty_props[prop_key] = prop_value
|
||||
else:
|
||||
non_empty_props[prop_key] = prop_value
|
||||
else:
|
||||
non_empty_props[prop_key] = prop_value
|
||||
cleaned[key] = non_empty_props
|
||||
else:
|
||||
cleaned[key] = cleaned_value
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_schema_for_gemini(item) if isinstance(item, dict) else item for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
elif isinstance(schema, list):
|
||||
return [_clean_schema_for_gemini(item) if isinstance(item, dict) else item for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
def _strip_markdown_json_fence(text):
    """Return ``text`` without a surrounding Markdown code fence.

    Handles both a ```` ```json ```` opener and a bare ```` ``` ```` opener,
    plus a trailing ```` ``` ````; surrounding whitespace is trimmed.
    """
    stripped = text.strip()
    if stripped.startswith("```json"):
        stripped = stripped[len("```json"):]
    elif stripped.startswith("```"):
        stripped = stripped[len("```"):]
    if stripped.endswith("```"):
        stripped = stripped[:-len("```")]
    return stripped.strip()


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt: str, schema: Dict[str, Any], model_name: str = "gemini-2.0-flash-001") -> str:
    """Ask Gemini for a JSON answer shaped by ``schema`` and return it as a JSON string.

    The schema is cleaned and validated, embedded as text in the prompt, and
    the model's reply is parsed. Parsing is attempted twice: first on the
    reply with any Markdown code fence removed, then on the outermost
    ``{...}`` span found anywhere in the raw reply.

    Args:
        prompt: The user prompt; the schema instructions are appended to it.
        schema: JSON-schema-like dict describing the desired response.
        model_name: Model identifier passed to the SDK.

    Returns:
        A JSON string. On success it is the re-serialized parsed response.
        On failure it is an object with an ``"error"`` key (and ``"raw_text"``
        when the reply could not be parsed as JSON).

    NOTE(review): every exception is caught and converted into an error JSON
    string, so the ``@retry`` decorator presumably never sees a failure to
    retry - confirm whether retries are actually intended here.
    """
    try:
        api_key = get_gemini_api_key()
        if not api_key:
            logger.error("Gemini API key not found")
            return json.dumps({"error": "API key not found"})

        # Clean and validate schema
        cleaned_schema = _clean_schema_for_gemini(schema)
        validated_schema = _validate_and_fix_schema(cleaned_schema)

        logger.info(f"🤖 Making Gemini API call to {model_name}")
        logger.info(f"📝 Prompt: {prompt[:200]}...")
        logger.info(f"🔧 Schema: {json.dumps(validated_schema, indent=2)}")

        client = genai.Client(api_key=api_key)

        generation_config = types.GenerateContentConfig(
            temperature=0.7,
            top_p=0.8,
            top_k=40,
            max_output_tokens=8192,
        )

        # The schema is communicated to the model as part of the prompt text.
        full_prompt = f"""
        {prompt}

        Please respond with a valid JSON object that matches this schema:

        {json.dumps(validated_schema, indent=2)}

        Ensure the response is valid JSON and matches the schema exactly.
        """

        logger.info(f"🚀 Sending request to Gemini API...")
        start_time = time.time()

        response = client.models.generate_content(
            model=model_name,
            contents=full_prompt,
            config=generation_config
        )

        end_time = time.time()
        logger.info(f"⏱️ Gemini API response received in {end_time - start_time:.2f} seconds")

        raw_text = response.text
        if raw_text is None:
            # Report an explicit reason instead of failing later with a
            # confusing "'NoneType' object is not subscriptable".
            raise ValueError("Gemini returned no text content")

        logger.info(f"📄 Raw response: {raw_text[:500]}...")

        # First attempt: the reply itself, minus any Markdown code fence.
        json_text = _strip_markdown_json_fence(raw_text)
        try:
            parsed = json.loads(json_text)
        except json.JSONDecodeError as e:
            logger.warning(f"❌ JSON parsing failed: {e}")
            logger.warning(f"📄 Attempted to parse: {json_text}")
        else:
            logger.info(f"✅ Successfully parsed JSON response: {json.dumps(parsed, indent=2)}")
            return json.dumps(parsed)

        # Second attempt: the outermost {...} span anywhere in the raw reply.
        json_match = re.search(r'\{.*\}', raw_text, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
            except json.JSONDecodeError:
                logger.warning("❌ Failed to parse extracted JSON")
            else:
                logger.info(f"✅ Found and parsed JSON in response: {json.dumps(parsed, indent=2)}")
                return json.dumps(parsed)

        logger.warning("❌ No valid JSON found in response, returning full text")
        return json.dumps({"error": "Invalid JSON response", "raw_text": raw_text})

    except Exception as e:
        logger.error(f"❌ Gemini API error: {str(e)}")
        return json.dumps({"error": f"Gemini API error: {str(e)}"})
|
||||
|
||||
def _validate_and_fix_schema(schema):
|
||||
"""Validate and fix schema to ensure it's compatible with Gemini API."""
|
||||
if isinstance(schema, dict):
|
||||
# Check for empty object properties
|
||||
if "properties" in schema and isinstance(schema["properties"], dict):
|
||||
fixed_properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
if isinstance(value, dict):
|
||||
if value.get("type") == "object":
|
||||
# If object has no properties or empty properties, change to string
|
||||
if not value.get("properties") or not value["properties"]:
|
||||
fixed_properties[key] = {"type": "string"}
|
||||
else:
|
||||
# Recursively fix nested objects
|
||||
fixed_properties[key] = _validate_and_fix_schema(value)
|
||||
else:
|
||||
fixed_properties[key] = value
|
||||
else:
|
||||
fixed_properties[key] = value
|
||||
|
||||
schema["properties"] = fixed_properties
|
||||
|
||||
# Recursively fix nested objects
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, dict):
|
||||
schema[key] = _validate_and_fix_schema(value)
|
||||
|
||||
return schema
|
||||
|
||||
def _classify_gemini_key_error(error_msg: str) -> str:
    """Map an exception message from the key check to a user-facing message."""
    lowered = error_msg.lower()

    if "API_KEY_INVALID" in error_msg or "authentication" in lowered:
        return "Invalid Gemini API key"

    # "rate" must be a standalone token: a plain substring test would also
    # match words such as "generate" / "generateContent" and misreport
    # unrelated failures as rate limiting.
    mentions_rate = re.search(r"(?<![a-z])rate(?![a-z])", lowered) is not None
    if "quota" in lowered or "resource_exhausted" in lowered or mentions_rate:
        return "Rate limit exceeded. Please try again later."

    return f"Error testing Gemini API key: {error_msg}"


async def test_gemini_api_key(api_key: str) -> tuple[bool, str]:
    """
    Test if the provided Gemini API key is valid by issuing a tiny request.

    A short "Hello" prompt is sent to "gemini-2.0-flash-001". The blocking
    SDK call is run in a worker thread so the event loop stays responsive
    while the request is in flight.

    Args:
        api_key (str): The Gemini API key to test

    Returns:
        tuple[bool, str]: A tuple containing (is_valid, message). Failures
        never raise; they are reported as ``(False, message)``.
    """
    try:
        client = genai.Client(api_key=api_key)
        await asyncio.to_thread(
            client.models.generate_content,
            model="gemini-2.0-flash-001",
            contents="Hello",
            config=types.GenerateContentConfig(
                temperature=0.1,
                max_output_tokens=50
            ),
        )
    except Exception as e:
        return False, _classify_gemini_key_error(str(e))

    # The request completed without raising, so the key was accepted.
    return True, "Gemini API key is valid"
|
||||
|
||||
def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048):
    """
    Generate a text completion for ``prompt`` with "gemini-2.0-flash-001".

    Args:
        prompt (str): The input text to generate completion for
        temperature (float, optional): Forwarded as ``temperature``. Defaults to 0.7
        top_p (float, optional): Forwarded as ``top_p``. Defaults to 0.9
        top_k (int, optional): Forwarded as ``top_k``. Defaults to 40
        max_tokens (int, optional): Forwarded as ``max_output_tokens``. Defaults to 2048

    Returns:
        str: ``response.text`` on success. On any failure the exception is
        logged and its message is returned in place of generated text, so
        callers cannot tell an error message from a completion by type alone.
    """
    try:
        client = genai.Client(api_key=get_gemini_api_key())

        sampling_config = types.GenerateContentConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_output_tokens=max_tokens,
        )
        result = client.models.generate_content(
            model='gemini-2.0-flash-001',
            contents=prompt,
            config=sampling_config,
        )
        return result.text

    except Exception as e:
        logger.error(f"Error in Gemini Pro text generation: {e}")
        return str(e)
|
||||
Reference in New Issue
Block a user