This commit is contained in:
ajaysi
2025-04-15 12:37:25 +05:30
parent d4f1fc77a1
commit a2b92e27bc
6 changed files with 500 additions and 134 deletions

View File

@@ -21,6 +21,8 @@ from tenacity import (
)
import asyncio
import json
import re
# Configure standard logging
import logging
@@ -167,3 +169,64 @@ def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens
except Exception as e:
logger.error(f"Error in Gemini Pro text generation: {e}")
return str(e)
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048, system_prompt=None):
"""
Generate structured JSON response using Google's Gemini Pro model.
Args:
prompt (str): The input text to generate completion for
schema (dict): The JSON schema to follow for the response
temperature (float, optional): Controls randomness. Defaults to 0.7
top_p (float, optional): Controls diversity. Defaults to 0.9
top_k (int, optional): Controls vocabulary size. Defaults to 40
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
system_prompt (str, optional): System instructions for the model
Returns:
dict: The generated structured JSON response
"""
try:
# Configure the model
client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
# Set up generation config
generation_config = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"max_output_tokens": max_tokens,
}
# Generate content with structured response
response = client.models.generate_content(
model='gemini-2.0-flash',
contents=prompt,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
max_output_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
response_mime_type='application/json',
response_schema=schema
),
)
# Parse the response
try:
# First try to get the parsed response
if hasattr(response, 'parsed'):
return response.parsed
# If parsed is not available, try to parse the text
response_text = response.text
return json.loads(response_text)
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON response: {e}")
return {"error": f"Failed to parse JSON response: {e}", "raw_response": response_text}
except Exception as e:
logger.error(f"Error in Gemini Pro structured JSON generation: {e}")
return {"error": str(e)}

View File

@@ -19,12 +19,13 @@ from .deepseek_text_gen import deepseek_text_response
from ...utils.read_main_config_params import read_return_config_section
def llm_text_gen(prompt, system_prompt=None):
def llm_text_gen(prompt, system_prompt=None, json_struct=None):
"""
Generate text using Language Model (LLM) based on the provided prompt.
Args:
prompt (str): The prompt to generate text from.
system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses.
Returns:
str: Generated text based on the prompt.
"""
@@ -77,7 +78,10 @@ def llm_text_gen(prompt, system_prompt=None):
if 'google' in gpt_provider.lower():
try:
logger.info("Using Google Gemini Pro text generation model.")
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_instructions)
if json_struct:
response = gemini_structured_json_response(prompt, json_struct, temperature, top_p, n, max_tokens, system_instructions)
else:
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_instructions)
return response
except Exception as err:
logger.error(f"Failed to get response from gemini: {err}")