Use system instructions to steer the behavior of a model

This commit is contained in:
ajaysi
2024-08-18 17:13:00 +05:30
parent f35649f129
commit b97ad5eb2b
7 changed files with 90 additions and 61 deletions

View File

@@ -2,7 +2,6 @@
import os
import sys
from pathlib import Path
import streamlit as st
import google.generativeai as genai
from dotenv import load_dotenv
@@ -21,7 +20,7 @@ from tenacity import (
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_text_response(prompt, temperature, top_p, n, max_tokens):
def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_prompt):
""" Common functiont to get response from gemini pro Text. """
#FIXME: Include : https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/System_instructions_REST.ipynb
try:
@@ -37,7 +36,9 @@ def gemini_text_response(prompt, temperature, top_p, n, max_tokens):
"max_output_tokens": max_tokens,
}
# FIXME: Expose model_name in main_config
model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest", generation_config=generation_config)
model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest",
generation_config=generation_config,
system_instruction=system_prompt)
try:
# text_response = []
response = model.generate_content(prompt, stream=True)
@@ -45,7 +46,6 @@ def gemini_text_response(prompt, temperature, top_p, n, max_tokens):
for chunk in response:
# text_response.append(chunk.text)
print(chunk.text)
#st.write(chunk.text)
else:
print(response)
logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")

View File

@@ -27,7 +27,21 @@ def llm_text_gen(prompt):
str: Generated text based on the prompt.
"""
try:
# Read the config param to create system instruction for the LLM.
gpt_provider, model, temperature, max_tokens, top_p, n, fp = read_return_config_section('llm_config')
blog_tone, blog_demographic, blog_type, blog_language, \
blog_output_format, blog_length = read_return_config_section('blog_characteristics')
# Construct the system prompt with the sidebar config params.
system_instructions = f"""
Below are the guidelines to follow:
1). You must respond in {blog_language} language.
2). Tone and Brand Alignment: Adjust your tone, voice, personality for {blog_tone} audience.
3). Make sure your response content length is of {blog_length} words.
4). The type of blog is {blog_type}, write accordingly.
5). The demographic for this content is {blog_demographic}.
6). Your response should be in {blog_output_format} format.
"""
#gpt_provider = check_gpt_provider(gpt_provider)
# Check if API key is provided for the given gpt_provider
@@ -37,7 +51,7 @@ def llm_text_gen(prompt):
if 'google' in gpt_provider.lower():
try:
logger.info("Using Google Gemini Pro text generation model.")
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens)
response = gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_instructions)
return response
except Exception as err:
logger.error(f"Failed to get response from gemini: {err}")
@@ -45,7 +59,7 @@ def llm_text_gen(prompt):
elif 'openai' in gpt_provider.lower():
try:
logger.info(f"Using OpenAI Model: {model} for text Generation.")
response = openai_chatgpt(prompt, model, temperature, max_tokens, top_p, n, fp)
response = openai_chatgpt(prompt, model, temperature, max_tokens, top_p, n, fp, system_instructions)
return response
except Exception as err:
logger.error(f"Failed to get response from Openai: {err}")

View File

@@ -14,7 +14,7 @@ from tenacity import (
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def openai_chatgpt(prompt, model, temperature, max_tokens, top_p, n, fp):
def openai_chatgpt(prompt, model, temperature, max_tokens, top_p, n, fp, system_prompt):
"""
Wrapper function for OpenAI's ChatGPT completion.
@@ -45,7 +45,8 @@ def openai_chatgpt(prompt, model, temperature, max_tokens, top_p, n, fp):
client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
messages=[{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}],
max_tokens=max_tokens,
n=n,
top_p=top_p,