Added enhanced linguistic analyzer and persona quality improver
This commit is contained in:
Binary file not shown.
@@ -46,14 +46,17 @@ class GeminiGroundedProvider:
|
||||
# Initialize the Gemini client with timeout configuration
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.timeout = 60 # 60 second timeout for API calls (increased for research)
|
||||
self._cache: Dict[str, Any] = {}
|
||||
logger.info("✅ Gemini Grounded Provider initialized with native Google Search grounding")
|
||||
|
||||
async def generate_grounded_content(
|
||||
self,
|
||||
prompt: str,
|
||||
self,
|
||||
prompt: str,
|
||||
content_type: str = "linkedin_post",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048
|
||||
max_tokens: int = 2048,
|
||||
urls: Optional[List[str]] = None,
|
||||
mode: str = "polished"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate grounded content using native Google Search grounding.
|
||||
@@ -73,14 +76,29 @@ class GeminiGroundedProvider:
|
||||
# Build the grounded prompt
|
||||
grounded_prompt = self._build_grounded_prompt(prompt, content_type)
|
||||
|
||||
# Configure the grounding tool
|
||||
grounding_tool = types.Tool(
|
||||
google_search=types.GoogleSearch()
|
||||
)
|
||||
# Configure tools: Google Search and optional URL Context
|
||||
tools: List[Any] = [
|
||||
types.Tool(google_search=types.GoogleSearch())
|
||||
]
|
||||
if urls:
|
||||
try:
|
||||
# URL Context tool (ai.google.dev URL Context)
|
||||
tools.append(types.Tool(url_context=types.UrlContext()))
|
||||
logger.info(f"Enabled URL Context tool for {len(urls)} URLs")
|
||||
except Exception as tool_err:
|
||||
logger.warning(f"URL Context tool not available in SDK version: {tool_err}")
|
||||
|
||||
# Apply mode presets (Draft vs Polished)
|
||||
model_id = "gemini-2.5-flash"
|
||||
if mode == "draft":
|
||||
model_id = "gemini-2.5-flash-lite"
|
||||
temperature = min(1.0, max(0.0, temperature))
|
||||
else:
|
||||
model_id = "gemini-2.5-flash"
|
||||
|
||||
# Configure generation settings
|
||||
config = types.GenerateContentConfig(
|
||||
tools=[grounding_tool],
|
||||
tools=tools,
|
||||
max_output_tokens=max_tokens,
|
||||
temperature=temperature
|
||||
)
|
||||
@@ -90,20 +108,27 @@ class GeminiGroundedProvider:
|
||||
import concurrent.futures
|
||||
|
||||
try:
|
||||
# Run the synchronous generate_content in a thread pool to make it awaitable
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
# Cache first
|
||||
cache_key = self._make_cache_key(model_id, grounded_prompt, urls)
|
||||
if cache_key in self._cache:
|
||||
logger.info("Cache hit for grounded content request")
|
||||
response = self._cache[cache_key]
|
||||
else:
|
||||
# Run the synchronous generate_content in a thread pool to make it awaitable
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model=model_id,
|
||||
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self._cache[cache_key] = response
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Gemini API request timed out after {self.timeout} seconds")
|
||||
except Exception as api_error:
|
||||
@@ -112,14 +137,14 @@ class GeminiGroundedProvider:
|
||||
if "503" in error_str and "overloaded" in error_str:
|
||||
# Conservative retry for overloaded service (expensive API calls)
|
||||
response = await self._retry_with_backoff(
|
||||
lambda: self._make_api_request(grounded_prompt, config),
|
||||
lambda: self._make_api_request_with_model(grounded_prompt, config, model_id, urls),
|
||||
max_retries=1, # Only 1 retry to avoid excessive costs
|
||||
base_delay=5 # Longer delay
|
||||
)
|
||||
elif "429" in error_str:
|
||||
# Conservative retry for rate limits
|
||||
response = await self._retry_with_backoff(
|
||||
lambda: self._make_api_request(grounded_prompt, config),
|
||||
lambda: self._make_api_request_with_model(grounded_prompt, config, model_id, urls),
|
||||
max_retries=1, # Only 1 retry
|
||||
base_delay=10 # Much longer delay for rate limits
|
||||
)
|
||||
@@ -132,6 +157,15 @@ class GeminiGroundedProvider:
|
||||
|
||||
# Process the grounded response
|
||||
result = self._process_grounded_response(response, content_type)
|
||||
# Attach URL Context metadata if present
|
||||
try:
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
candidate0 = response.candidates[0]
|
||||
if hasattr(candidate0, 'url_context_metadata') and candidate0.url_context_metadata:
|
||||
result['url_context_metadata'] = candidate0.url_context_metadata
|
||||
logger.info("Attached url_context_metadata to result")
|
||||
except Exception as meta_err:
|
||||
logger.warning(f"Unable to attach url_context_metadata: {meta_err}")
|
||||
|
||||
logger.info(f"✅ Grounded content generated successfully with {len(result.get('sources', []))} sources")
|
||||
return result
|
||||
@@ -162,6 +196,41 @@ class GeminiGroundedProvider:
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
async def _make_api_request_with_model(self, grounded_prompt: str, config: Any, model_id: str, urls: Optional[List[str]] = None):
|
||||
"""Make the API request with explicit model id and optional URL injection."""
|
||||
import concurrent.futures
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model=model_id,
|
||||
contents=self._inject_urls_into_prompt(grounded_prompt, urls) if urls else grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self._cache[self._make_cache_key(model_id, grounded_prompt, urls)] = resp
|
||||
return resp
|
||||
|
||||
def _inject_urls_into_prompt(self, prompt: str, urls: Optional[List[str]]) -> str:
|
||||
"""Append URLs to the prompt for URL Context tool to pick up (as per docs)."""
|
||||
if not urls:
|
||||
return prompt
|
||||
safe_urls = [u for u in urls if isinstance(u, str) and u.startswith("http")]
|
||||
if not safe_urls:
|
||||
return prompt
|
||||
urls_block = "\n".join(safe_urls[:20])
|
||||
return f"{prompt}\n\nSOURCE URLS (use url_context to retrieve content):\n{urls_block}"
|
||||
|
||||
def _make_cache_key(self, model_id: str, prompt: str, urls: Optional[List[str]]) -> str:
|
||||
import hashlib
|
||||
u = "|".join((urls or [])[:20])
|
||||
base = f"{model_id}|{prompt}|{u}"
|
||||
return hashlib.sha256(base.encode("utf-8")).hexdigest()
|
||||
|
||||
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
|
||||
"""Retry a function with exponential backoff."""
|
||||
|
||||
@@ -390,11 +390,19 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
)
|
||||
|
||||
# Check for parsed content first (primary method for structured output)
|
||||
if hasattr(response, 'parsed') and response.parsed is not None:
|
||||
logger.info("Using response.parsed for structured output")
|
||||
return response.parsed
|
||||
if hasattr(response, 'parsed'):
|
||||
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
|
||||
if response.parsed is not None:
|
||||
logger.info("Using response.parsed for structured output")
|
||||
return response.parsed
|
||||
else:
|
||||
logger.warning("Response.parsed is None, falling back to text parsing")
|
||||
# Debug: Check if there's any text content
|
||||
if hasattr(response, 'text') and response.text:
|
||||
logger.info(f"Text response length: {len(response.text)}")
|
||||
logger.debug(f"Text response preview: {response.text[:200]}...")
|
||||
|
||||
# Check for text content as fallback
|
||||
# Check for text content as fallback (only if no parsed content)
|
||||
if hasattr(response, 'text') and response.text:
|
||||
logger.info("No parsed content, trying to parse text response")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user