Files
ALwrity/backend/llm_providers/gemini_provider.py
2025-08-06 12:48:02 +05:30

339 lines
13 KiB
Python

# Gemini LLM provider: text and structured-JSON generation through the google-genai SDK.
import os
import sys
from pathlib import Path
from typing import Dict, Any
import time
import google.genai as genai
from google.genai import types
from dotenv import load_dotenv
# Load environment variables from a .env file. The path is relative, so it is
# resolved against the process working directory.
# NOTE(review): confirm this is intended rather than a path relative to this file.
load_dotenv(Path('../../../.env'))
from loguru import logger
# Drop loguru's existing handlers and install a single colorized stdout sink.
logger.remove()
logger.add(sys.stdout,
colorize=True,
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
)
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
import asyncio
import json
import re
# Configure standard-library logging.
# NOTE(review): the final assignment below rebinds `logger`, so everything
# defined after this point logs through stdlib `logging`; the loguru logger
# configured above is no longer reachable under this name. Confirm which of
# the two loggers is actually intended.
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s-%(levelname)s-%(module)s-%(lineno)d]- %(message)s')
logger = logging.getLogger(__name__)
def get_gemini_api_key():
    """Return the Gemini API key, preferring the API key manager over the environment.

    The key manager is imported lazily and queried for the ``"gemini"`` key.
    Any failure while doing so is logged as a warning and treated as
    "no key from the manager". If the manager yields nothing usable, the
    ``GEMINI_API_KEY`` environment variable is consulted.

    Returns:
        The first non-empty key found.

    Raises:
        ValueError: If neither source provides a non-empty key.
    """
    managed_key = None
    try:
        # Imported here so a missing/broken key-manager module degrades to
        # the environment fallback instead of failing at module import time.
        from services.api_key_manager import get_api_key_manager
        managed_key = get_api_key_manager().get_api_key("gemini")
    except Exception as e:
        logger.warning(f"Could not get API key from manager: {e}")
    else:
        if managed_key:
            return managed_key

    # Fallback: plain environment variable.
    env_key = os.getenv('GEMINI_API_KEY')
    if env_key:
        return env_key
    raise ValueError("Gemini API key not found in environment variables or API key manager")
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_text_response(prompt, temperature=0.7, top_p=0.9, n=40, max_tokens=2048, system_prompt=None):
    """Generate text for ``prompt`` with the ``gemini-2.0-flash-001`` model.

    The call is wrapped by the retry decorator above (random exponential
    wait configured with min=1 / max=60, stopping after 6 attempts); any
    exception raised in the body is logged and re-raised so the decorator
    can act on it.

    Args:
        prompt: Content sent to the model.
        temperature: Sampling temperature forwarded to the generation config.
        top_p: Nucleus-sampling value forwarded to the generation config.
        n: Forwarded as ``top_k`` (despite its name, it is not a candidate count).
        max_tokens: Forwarded as ``max_output_tokens``.
        system_prompt: Optional system instruction; only added to the
            generation config when truthy.

    Returns:
        ``response.text`` of the SDK response.
    """
    try:
        key = get_gemini_api_key()
        logger.info(f"Temp: {temperature}, MaxTokens: {max_tokens}, TopP: {top_p}, N: {n}")

        gemini = genai.Client(api_key=key)

        # Both variants share the same sampling settings; the system
        # instruction is only included when one was supplied.
        config_kwargs = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": n,
            "max_output_tokens": max_tokens,
        }
        if system_prompt:
            config_kwargs["system_instruction"] = system_prompt

        result = gemini.models.generate_content(
            model="gemini-2.0-flash-001",
            contents=prompt,
            config=types.GenerateContentConfig(**config_kwargs),
        )

        # NOTE(review): len() runs inside the try, so a response without
        # usable text would surface here as an exception and be retried.
        logger.info(f"[gemini_text_response] Generated response with {len(result.text)} characters")
        return result.text
    except Exception as err:
        logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
        raise
# JSON-Schema keywords that are stripped before a schema is handed to Gemini.
_GEMINI_UNSUPPORTED_SCHEMA_KEYS = frozenset(
    ("additionalProperties", "pattern", "format", "minLength", "maxLength")
)


def _clean_properties_for_gemini(properties):
    """Clean every sub-schema of a ``properties`` map while keeping all property names."""
    cleaned = {}
    for name, sub_schema in properties.items():
        sub_schema = _clean_schema_for_gemini(sub_schema)
        # An object-typed property without any properties of its own is
        # replaced by a plain string schema.
        if (
            isinstance(sub_schema, dict)
            and sub_schema.get("type") == "object"
            and not sub_schema.get("properties")
        ):
            sub_schema = {"type": "string"}
        cleaned[name] = sub_schema
    return cleaned


def _clean_schema_for_gemini(schema):
    """Return a copy of ``schema`` without the keywords Gemini does not accept.

    Behaviour:
      * Keys listed in ``_GEMINI_UNSUPPORTED_SCHEMA_KEYS`` are removed from
        every schema dict, recursively (including dicts found inside lists).
      * The value of a ``"properties"`` key is treated as a map of
        *property name -> sub-schema*: only the sub-schemas are cleaned, the
        names are never filtered. Previously a user property that happened
        to be called e.g. ``"format"`` or ``"pattern"`` was silently dropped.
      * An empty ``"properties"`` map is omitted altogether.
      * Object-typed properties that end up with no properties become
        ``{"type": "string"}``.

    The input is not modified; non-dict, non-list values are returned as-is.

    Args:
        schema: A JSON-schema fragment (dict), a list of fragments, or a scalar.

    Returns:
        The cleaned counterpart of ``schema``.
    """
    if isinstance(schema, list):
        return [
            _clean_schema_for_gemini(item) if isinstance(item, dict) else item
            for item in schema
        ]
    if not isinstance(schema, dict):
        return schema

    cleaned = {}
    for key, value in schema.items():
        if key in _GEMINI_UNSUPPORTED_SCHEMA_KEYS:
            continue
        if key == "properties" and isinstance(value, dict):
            properties = _clean_properties_for_gemini(value)
            if properties:  # skip an empty properties map entirely
                cleaned[key] = properties
        elif isinstance(value, (dict, list)):
            cleaned[key] = _clean_schema_for_gemini(value)
        else:
            cleaned[key] = value
    return cleaned
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt: str, schema: Dict[str, Any], model_name: str = "gemini-2.0-flash-001") -> str:
    """
    Ask Gemini for a JSON answer that follows ``schema`` and return it as a JSON string.

    The schema is cleaned and validated, embedded as text in the prompt, and
    the model's reply is parsed in two stages: first the whole reply (with a
    surrounding ```json fence removed), then - if that fails - the widest
    ``{...}`` span found in the reply.

    Args:
        prompt: The user prompt.
        schema: JSON schema the answer should follow.
        model_name: Gemini model identifier passed to the SDK.

    Returns:
        A JSON string. On success it is the re-serialized parsed answer; on
        failure it is an object with an ``"error"`` key (plus ``"raw_text"``
        when the reply could not be parsed).

    NOTE(review): every exception is caught below and turned into a JSON
    error string, so the retry decorator never sees a failure raised from
    this body - confirm whether retries were intended.
    """
    try:
        api_key = get_gemini_api_key()
        if not api_key:
            logger.error("Gemini API key not found")
            return json.dumps({"error": "API key not found"})

        # Clean first, then validate the cleaned copy.
        validated_schema = _validate_and_fix_schema(_clean_schema_for_gemini(schema))

        logger.info(f"🤖 Making Gemini API call to {model_name}")
        logger.info(f"📝 Prompt: {prompt[:200]}...")
        schema_json = json.dumps(validated_schema, indent=2)
        logger.info(f"🔧 Schema: {schema_json}")

        client = genai.Client(api_key=api_key)
        generation_config = types.GenerateContentConfig(
            temperature=0.7,
            top_p=0.8,
            top_k=40,
            max_output_tokens=8192,
        )

        # The schema travels inside the prompt text.
        full_prompt = f"""
{prompt}
Please respond with a valid JSON object that matches this schema:
{schema_json}
Ensure the response is valid JSON and matches the schema exactly.
"""

        logger.info("🚀 Sending request to Gemini API...")
        started_at = time.time()
        response = client.models.generate_content(
            model=model_name,
            contents=full_prompt,
            config=generation_config
        )
        finished_at = time.time()
        logger.info(f"⏱️ Gemini API response received in {finished_at - started_at:.2f} seconds")
        logger.info(f"📄 Raw response: {response.text[:500]}...")

        # Stage 1: whole reply, minus a leading ```json and trailing ``` fence.
        json_text = response.text.strip().removeprefix("```json").removesuffix("```").strip()
        try:
            parsed = json.loads(json_text)
        except json.JSONDecodeError as e:
            logger.warning(f"❌ JSON parsing failed: {e}")
            logger.warning(f"📄 Attempted to parse: {json_text}")
        else:
            logger.info(f"✅ Successfully parsed JSON response: {json.dumps(parsed, indent=2)}")
            return json.dumps(parsed)

        # Stage 2: greedy search for a {...} span anywhere in the raw reply.
        json_match = re.search(r'\{.*\}', response.text, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
            except json.JSONDecodeError:
                logger.warning("❌ Failed to parse extracted JSON")
            else:
                logger.info(f"✅ Found and parsed JSON in response: {json.dumps(parsed, indent=2)}")
                return json.dumps(parsed)

        logger.warning("❌ No valid JSON found in response, returning full text")
        return json.dumps({"error": "Invalid JSON response", "raw_text": response.text})
    except Exception as e:
        logger.error(f"❌ Gemini API error: {str(e)}")
        return json.dumps({"error": f"Gemini API error: {str(e)}"})
def _fix_property_schema(prop_schema):
    """Fix one entry of a ``properties`` map (see ``_validate_and_fix_schema``)."""
    if not isinstance(prop_schema, dict):
        return prop_schema
    if prop_schema.get("type") == "object" and not prop_schema.get("properties"):
        # Object-typed property with missing or empty properties.
        return {"type": "string"}
    return _validate_and_fix_schema(prop_schema)


def _validate_and_fix_schema(schema):
    """Return a Gemini-compatible copy of ``schema``.

    Every object-typed property that has no (or empty) ``properties`` is
    replaced by ``{"type": "string"}``. The rule is applied recursively
    through nested dict values (e.g. nested objects, array ``items``).
    List values are passed through unchanged.

    Unlike the previous implementation this function does not modify the
    dict passed in - it builds and returns new dicts - and it walks a
    ``properties`` map by property name instead of re-interpreting the map
    itself as a schema.

    Args:
        schema: A JSON-schema dict; any other value is returned unchanged.

    Returns:
        The fixed schema (a new dict when ``schema`` is a dict).
    """
    if not isinstance(schema, dict):
        return schema

    fixed = {}
    for key, value in schema.items():
        if key == "properties" and isinstance(value, dict):
            fixed[key] = {
                name: _fix_property_schema(prop_schema)
                for name, prop_schema in value.items()
            }
        elif isinstance(value, dict):
            fixed[key] = _validate_and_fix_schema(value)
        else:
            fixed[key] = value
    return fixed
async def test_gemini_api_key(api_key: str) -> tuple[bool, str]:
    """
    Check whether ``api_key`` is accepted by the Gemini API.

    A minimal generation request ("Hello", at most 50 output tokens) is sent
    to ``gemini-2.0-flash-001``. If the call returns, the key is considered
    valid; otherwise the exception text is classified into a user-facing
    message.

    NOTE(review): the function is declared ``async`` but the SDK call below
    is made without ``await``, so it appears to block the event loop while
    the request is in flight - confirm whether that is acceptable.

    Args:
        api_key (str): The Gemini API key to test

    Returns:
        tuple[bool, str]: A tuple containing (is_valid, message)
    """
    try:
        client = genai.Client(api_key=api_key)
        # The response content is irrelevant; only success/failure matters.
        client.models.generate_content(
            model="gemini-2.0-flash-001",
            contents="Hello",
            config=types.GenerateContentConfig(
                temperature=0.1,
                max_output_tokens=50
            )
        )
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()

        if "API_KEY_INVALID" in error_msg or "authentication" in lowered:
            return False, "Invalid Gemini API key"

        # "rate" must stand on its own (or be "ratelimit..."): a plain
        # substring test also fires on "generate"/"generateContent", which
        # turned unrelated failures into a bogus rate-limit message.
        rate_limited = (
            "quota" in lowered
            or "resource_exhausted" in lowered
            or re.search(r"(?<![a-z])rate(?:limit|(?![a-z]))", lowered) is not None
        )
        if rate_limited:
            return False, "Rate limit exceeded. Please try again later."

        return False, f"Error testing Gemini API key: {error_msg}"

    # Reaching this point means the request succeeded, so the key is valid.
    return True, "Gemini API key is valid"
def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048):
    """
    Produce a text completion for ``prompt`` using ``gemini-2.0-flash-001``.

    Args:
        prompt (str): Text sent to the model.
        temperature (float, optional): Forwarded as ``temperature``. Defaults to 0.7
        top_p (float, optional): Forwarded as ``top_p``. Defaults to 0.9
        top_k (int, optional): Forwarded as ``top_k``. Defaults to 40
        max_tokens (int, optional): Forwarded as ``max_output_tokens``. Defaults to 2048

    Returns:
        str: ``response.text`` on success. On any exception the error is
        logged and its string form is returned instead of being raised, so
        callers cannot tell a failure from generated text by type alone.
    """
    try:
        gemini = genai.Client(api_key=get_gemini_api_key())
        sampling = types.GenerateContentConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_output_tokens=max_tokens,
        )
        result = gemini.models.generate_content(
            model='gemini-2.0-flash-001',
            contents=prompt,
            config=sampling,
        )
        return result.text
    except Exception as e:
        logger.error(f"Error in Gemini Pro text generation: {e}")
        return str(e)