Link
This commit is contained in:
@@ -21,6 +21,8 @@ from tenacity import (
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
# Configure standard logging
|
||||
import logging
|
||||
@@ -167,3 +169,64 @@ def gemini_pro_text_gen(prompt, temperature=0.7, top_p=0.9, top_k=40, max_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gemini Pro text generation: {e}")
|
||||
return str(e)
|
||||
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=2048, system_prompt=None):
|
||||
"""
|
||||
Generate structured JSON response using Google's Gemini Pro model.
|
||||
|
||||
Args:
|
||||
prompt (str): The input text to generate completion for
|
||||
schema (dict): The JSON schema to follow for the response
|
||||
temperature (float, optional): Controls randomness. Defaults to 0.7
|
||||
top_p (float, optional): Controls diversity. Defaults to 0.9
|
||||
top_k (int, optional): Controls vocabulary size. Defaults to 40
|
||||
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 2048
|
||||
system_prompt (str, optional): System instructions for the model
|
||||
|
||||
Returns:
|
||||
dict: The generated structured JSON response
|
||||
"""
|
||||
try:
|
||||
# Configure the model
|
||||
client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
|
||||
|
||||
# Set up generation config
|
||||
generation_config = {
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"max_output_tokens": max_tokens,
|
||||
}
|
||||
|
||||
# Generate content with structured response
|
||||
response = client.models.generate_content(
|
||||
model='gemini-2.0-flash',
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
response_mime_type='application/json',
|
||||
response_schema=schema
|
||||
),
|
||||
)
|
||||
|
||||
# Parse the response
|
||||
try:
|
||||
# First try to get the parsed response
|
||||
if hasattr(response, 'parsed'):
|
||||
return response.parsed
|
||||
|
||||
# If parsed is not available, try to parse the text
|
||||
response_text = response.text
|
||||
return json.loads(response_text)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error parsing JSON response: {e}")
|
||||
return {"error": f"Failed to parse JSON response: {e}", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gemini Pro structured JSON generation: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -19,12 +19,13 @@ from .deepseek_text_gen import deepseek_text_response
|
||||
from ...utils.read_main_config_params import read_return_config_section
|
||||
|
||||
|
||||
def llm_text_gen(prompt, system_prompt=None):
|
||||
def llm_text_gen(prompt, system_prompt=None, json_struct=None):
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
Args:
|
||||
prompt (str): The prompt to generate text from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
"""
|
||||
@@ -77,7 +78,10 @@ def llm_text_gen(prompt, system_prompt=None):
|
||||
if 'google' in gpt_provider.lower():
|
||||
try:
|
||||
logger.info("Using Google Gemini Pro text generation model.")
|
||||
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_instructions)
|
||||
if json_struct:
|
||||
response = gemini_structured_json_response(prompt, json_struct, temperature, top_p, n, max_tokens, system_instructions)
|
||||
else:
|
||||
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_instructions)
|
||||
return response
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from gemini: {err}")
|
||||
|
||||
Reference in New Issue
Block a user