Import 9 alphaear finance skills
- alphaear-deepear-lite: DeepEar Lite API integration
- alphaear-logic-visualizer: Draw.io XML finance diagrams
- alphaear-news: Real-time finance news (10+ sources)
- alphaear-predictor: Kronos time-series forecasting
- alphaear-reporter: Professional financial reports
- alphaear-search: Web search + local RAG
- alphaear-sentiment: FinBERT/LLM sentiment analysis
- alphaear-signal-tracker: Signal evolution tracking
- alphaear-stock: A-Share/HK/US stock data

Updates:
- All scripts updated to use universal .env path
- Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example
- Updated load_dotenv() to use ~/.config/opencode/.env
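For reference, the shared .env convention described above follows the pattern used in the scripts in this commit (a minimal sketch, not an exhaustive list of keys):

import os
from dotenv import load_dotenv

# Every skill script now loads the universal config path
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
jina_key = os.getenv("JINA_API_KEY")
deepseek_key = os.getenv("DEEPSEEK_API_KEY")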
skills/alphaear-search/scripts/__init__.py | 0 lines (new file)
skills/alphaear-search/scripts/content_extractor.py | 122 lines (new file)
@@ -0,0 +1,122 @@
import requests
from requests.exceptions import RequestException, Timeout, ConnectionError
import os
import time
import json
import threading
from typing import Optional
from loguru import logger


class ContentExtractor:
    """Content extraction utility - primarily backed by the Jina Reader API"""

    JINA_BASE_URL = "https://r.jina.ai/"

    # Rate-limit configuration (without an API key: 20 requests/minute)
    _rate_limit_no_key = 20  # max requests per minute
    _rate_window = 60.0  # time window (seconds)
    _min_interval = 3.0  # minimum interval between requests (seconds)

    # Class-level rate-limit state
    _request_times = []
    _last_request_time = 0.0
    _lock = threading.Lock()

    @classmethod
    def _wait_for_rate_limit(cls, has_api_key: bool) -> None:
        """Block until the rate limit allows another request"""
        if has_api_key:
            # With an API key, only a small minimum interval is needed
            time.sleep(0.5)
            return

        with cls._lock:
            current_time = time.time()

            # 1. Drop request records that fall outside the window
            cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]

            # 2. Check whether the rate limit has been reached
            if len(cls._request_times) >= cls._rate_limit_no_key:
                # Wait until the oldest request falls out of the window
                oldest = cls._request_times[0]
                wait_time = cls._rate_window - (current_time - oldest) + 1.0
                if wait_time > 0:
                    logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    current_time = time.time()
                    cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]

            # 3. Enforce the minimum interval between requests
            time_since_last = current_time - cls._last_request_time
            if time_since_last < cls._min_interval:
                sleep_time = cls._min_interval - time_since_last
                time.sleep(sleep_time)

            # 4. Record this request
            cls._request_times.append(time.time())
            cls._last_request_time = time.time()

    @classmethod
    def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]:
        """
        Extract the main content of a web page via Jina Reader (Markdown format).

        Without an API key, requests are automatically throttled: at most 20 per minute,
        at least 3 seconds apart.
        """
        if not url or not url.startswith("http"):
            return None

        logger.info(f"🕸️ Extracting content from: {url} via Jina...")

        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
            "Accept": "application/json"
        }

        # Use the shared JINA_API_KEY
        api_key = os.getenv("JINA_API_KEY")
        has_api_key = bool(api_key and api_key.strip())

        if has_api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Respect the rate limit
        cls._wait_for_rate_limit(has_api_key)

        try:
            # Jina Reader API
            full_url = f"{cls.JINA_BASE_URL}{url}"
            response = requests.get(full_url, headers=headers, timeout=timeout)

            if response.status_code == 200:
                try:
                    data = response.json()
                    # Jina's JSON response usually carries the text in data.content
                    if isinstance(data, dict) and "data" in data:
                        return data["data"].get("content", "")
                    return data.get("content", response.text)
                except (json.JSONDecodeError, TypeError):
                    return response.text
            elif response.status_code == 429:
                # Rate limited; wait and retry once
                logger.warning("⚠️ Jina rate limit (429), waiting 60s before retry...")
                time.sleep(60)
                return cls.extract_with_jina(url, timeout)
            else:
                logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}")
                return None

        except Timeout:
            logger.error(f"Timeout during Jina extraction for {url}")
            return None
        except ConnectionError:
            logger.error(f"Connection error during Jina extraction for {url}")
            return None
        except RequestException as e:
            logger.error(f"Request error during Jina extraction: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error during Jina extraction: {e}")
            return None
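A minimal usage sketch for the extractor above (not part of the commit); the import path is illustrative, and JINA_API_KEY is optional in the environment:

from scripts.content_extractor import ContentExtractor

# Returns Markdown text, or None on failure; throttled automatically without an API key
markdown = ContentExtractor.extract_with_jina("https://example.com/article", timeout=30)
if markdown:
    print(markdown[:200])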
skills/alphaear-search/scripts/database_manager.py | 159 lines (new file)
@@ -0,0 +1,159 @@
import sqlite3
import json
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Union
from loguru import logger


class DatabaseManager:
    """
    AlphaEar Search Database Manager
    Reduced version for alphaear-search skill
    """

    def __init__(self, db_path: str = "data/signal_flux.db"):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self._init_db()
        logger.debug(f"💾 Search Database initialized at {self.db_path}")

    def _init_db(self):
        cursor = self.conn.cursor()

        # 1. Daily News (Required for Local Search RAG)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS daily_news (
                id TEXT PRIMARY KEY,
                source TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                analysis TEXT,
                meta_data TEXT
            )
        """)

        # 2. Search Cache
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_cache (
                query_hash TEXT PRIMARY KEY,
                query TEXT,
                engine TEXT,
                results TEXT,
                timestamp TEXT
            )
        """)

        # 3. Search Details
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_detail (
                id TEXT,
                query_hash TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                source TEXT,
                meta_data TEXT,
                PRIMARY KEY (query_hash, id)
            )
        """)

        cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)")
        self.conn.commit()

    # --- Search Cache Operations ---

    def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]:
        cursor = self.conn.cursor()

        # Try detailed cache first
        cursor.execute("""
            SELECT * FROM search_detail
            WHERE query_hash = ?
            ORDER BY rank
        """, (query_hash,))
        details = [dict(row) for row in cursor.fetchall()]

        if details:
            first_time = datetime.fromisoformat(details[0]['crawl_time'])
            if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds:
                return None
            return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']}

        # Fallback to simple cache
        cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,))
        row = cursor.fetchone()

        if not row:
            return None
        row_dict = dict(row)
        if ttl_seconds:
            cache_time = datetime.fromisoformat(row_dict['timestamp'])
            if (datetime.now() - cache_time).total_seconds() > ttl_seconds:
                return None
        return row_dict

    def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]):
        cursor = self.conn.cursor()
        current_time = datetime.now().isoformat()
        results_str = results if isinstance(results, str) else json.dumps(results)

        cursor.execute("""
            INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp)
            VALUES (?, ?, ?, ?, ?)
        """, (query_hash, query, engine, results_str, current_time))

        if isinstance(results, list):
            for item in results:
                try:
                    item_id = item.get('id') or f"{hash(item.get('url', ''))}"
                    cursor.execute("""
                        INSERT OR REPLACE INTO search_detail
                        (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        str(item_id), query_hash, item.get('rank', 0), item.get('title'),
                        item.get('url'), item.get('content', ''), item.get('publish_time'),
                        item.get('crawl_time') or current_time, item.get('sentiment_score'),
                        item.get('source'), json.dumps(item.get('meta_data', {}))
                    ))
                except Exception as e:
                    logger.error(f"Error saving search detail: {e}")

        self.conn.commit()

    def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]:
        cursor = self.conn.cursor()
        q_wild = f"%{query}%"
        cursor.execute("""
            SELECT query, query_hash, timestamp, results
            FROM search_cache
            WHERE query LIKE ? OR ? LIKE ('%' || query || '%')
            ORDER BY timestamp DESC
            LIMIT ?
        """, (q_wild, query, limit))
        return [dict(row) for row in cursor.fetchall()]

    def search_local_news(self, query: str, limit: int = 5) -> List[Dict]:
        cursor = self.conn.cursor()
        q_wild = f"%{query}%"
        cursor.execute("""
            SELECT * FROM daily_news
            WHERE title LIKE ? OR content LIKE ?
            ORDER BY crawl_time DESC
            LIMIT ?
        """, (q_wild, q_wild, limit))
        return [dict(row) for row in cursor.fetchall()]

    def close(self):
        if self.conn:
            self.conn.close()
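Illustrative only (not in the commit): how the cache tables above are typically exercised; the import path and query hash are made up for the example:

from scripts.database_manager import DatabaseManager

db = DatabaseManager("data/signal_flux.db")
db.save_search_cache("abc123", "NVIDIA earnings", "ddg",
                     [{"id": "1", "rank": 1, "title": "Q3 beat", "url": "https://example.com", "content": "..."}])
cached = db.get_search_cache("abc123", ttl_seconds=3600)  # None once the TTL has expired
db.close()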
skills/alphaear-search/scripts/hybrid_search.py | 216 lines (new file)
@@ -0,0 +1,216 @@
import numpy as np
import os
from typing import List, Dict, Any, Optional, Union
from rank_bm25 import BM25Okapi
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class HybridSearcher:
    """
    Unified hybrid retrieval engine (Hybrid RAG).
    Fuses BM25 (lexical) and vector (semantic) search via RRF.
    """

    def __init__(self, data: List[Dict[str, Any]], text_fields: List[str] = ["title", "content"], model_name: str = None):
        """
        Initialize the searcher.

        Args:
            data: List of documents, each a Dict
            text_fields: Text fields used to build the index
            model_name: Embedding model name; defaults to paraphrase-multilingual-MiniLM-L12-v2
        """
        self.data = data
        self.text_fields = text_fields
        self._corpus = []
        self._bm25 = None
        self._vector_model = None
        self._embeddings = None
        self._fitted = False
        self._vector_fitted = False

        # Default model
        self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "paraphrase-multilingual-MiniLM-L12-v2")

        if data:
            self._prepare_corpus()
            self._fit_bm25()
            # Lazy-load the vector model; only fitted when needed or when called explicitly
            # self._fit_vector()

    def _prepare_corpus(self):
        """Prepare the corpus for tokenization"""
        import jieba  # jieba for Chinese word segmentation

        self._corpus = []
        self._full_texts = []
        for item in self.data:
            text = " ".join([str(item.get(field, "")) for field in self.text_fields])
            self._full_texts.append(text)
            # Chinese-aware tokenization
            tokens = list(jieba.cut(text))
            self._corpus.append(tokens)

    def _fit_bm25(self):
        """Fit the BM25 model"""
        if self._corpus:
            self._bm25 = BM25Okapi(self._corpus)
            self._fitted = True
            logger.info(f"✅ BM25 index fitted with {len(self.data)} documents")

    def _fit_vector(self):
        """Load the embedding model and compute embeddings"""
        if not self.data:
            return

        try:
            logger.info(f"📡 Loading embedding model: {self.model_name}...")
            self._vector_model = SentenceTransformer(self.model_name)
            logger.info(f"🧠 Encoding {len(self._full_texts)} documents...")
            self._embeddings = self._vector_model.encode(self._full_texts, show_progress_bar=False)
            self._vector_fitted = True
            logger.info("✅ Vector index fitted successfully")
        except Exception as e:
            logger.error(f"❌ Failed to fit vector index: {e}")
            self._vector_fitted = False

    def _compute_rrf(self, rank_lists: List[List[int]], k: int = 60) -> List[tuple]:
        """
        Compute Reciprocal Rank Fusion (RRF).

        Args:
            rank_lists: Several ranked lists of document indexes
            k: RRF constant, default 60
        """
        scores = {}
        for rank_list in rank_lists:
            for rank, idx in enumerate(rank_list):
                if idx not in scores:
                    scores[idx] = 0
                scores[idx] += 1.0 / (k + rank + 1)

        # Sort by fused score
        sorted_indices = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_indices

    def search(self, query: str, top_n: int = 5, use_vector: bool = False) -> List[Dict[str, Any]]:
        """
        Run a hybrid search.

        Args:
            query: Search keywords
            top_n: Number of results to return
            use_vector: Whether to enable vector search
        """
        if not self._fitted or not query:
            return []

        import jieba
        query_tokens = list(jieba.cut(query))

        # 1. BM25 results
        bm25_scores = self._bm25.get_scores(query_tokens)
        bm25_rank = np.argsort(bm25_scores)[::-1].tolist()

        rank_lists = [bm25_rank]

        # 2. Vector search
        if use_vector:
            if not self._vector_fitted:
                self._fit_vector()

            if self._vector_fitted:
                query_embedding = self._vector_model.encode([query], show_progress_bar=False)
                similarities = cosine_similarity(query_embedding, self._embeddings)[0]
                vector_rank = np.argsort(similarities)[::-1].tolist()
                rank_lists.append(vector_rank)
            else:
                logger.warning("Vector search requested but model not fitted, falling back to BM25")

        # 3. Fuse the rankings (RRF)
        if len(rank_lists) > 1:
            rrf_results = self._compute_rrf(rank_lists)
            # RRF returns a list of (idx, score) pairs
            final_rank = [idx for idx, score in rrf_results]
        else:
            final_rank = bm25_rank

        # Return the top_n results
        results = [self.data[idx].copy() for idx in final_rank[:top_n]]

        # Attach relevance scores to each result
        for i, res in enumerate(results):
            try:
                original_idx = final_rank[i]
                res["_search_score"] = bm25_scores[original_idx]
                if use_vector and self._vector_fitted:
                    res["_vector_score"] = float(similarities[original_idx])
            except Exception:
                res["_search_score"] = 0

        return results


class InMemoryRAG(HybridSearcher):
    """In-memory RAG dedicated to cross-section retrieval for ReportAgent"""

    def search(self, query: str, top_n: int = 3, use_vector: bool = True) -> List[Dict[str, Any]]:
        """In-memory retrieval with vector search enabled by default"""
        return super().search(query, top_n=top_n, use_vector=use_vector)

    def update_data(self, new_data: List[Dict[str, Any]]):
        """Dynamically update the data and refit the indexes"""
        self.data = new_data
        self._prepare_corpus()
        self._fit_bm25()
        # If the vector model was already loaded, refresh the vector index too
        if self._vector_model:
            self._fit_vector()
        logger.info(f"🔄 InMemoryRAG updated with {len(new_data)} items")


class LocalNewsSearch(HybridSearcher):
    """Persistent RAG: retrieves historical news from the database"""

    def __init__(self, db_manager):
        """
        Args:
            db_manager: DatabaseManager instance
        """
        self.db = db_manager
        # No data is loaded at construction time; call load_history first
        super().__init__([], ["title", "content"])

    def load_history(self, days: int = 30, limit: int = 1000):
        """Load recent news from the database and build the index"""
        try:
            # Assumes db_manager provides execute_query
            query = "SELECT title, content, publish_time, source FROM daily_news ORDER BY publish_time DESC LIMIT ?"
            results = self.db.execute_query(query, (limit,))

            data = []
            for row in results:
                # Convert Row to Dict
                if hasattr(row, 'keys'):
                    item = dict(row)
                else:
                    item = {
                        "title": row[0],
                        "content": row[1],
                        "publish_time": row[2],
                        "source": row[3]
                    }
                data.append(item)

            self.data = data
            self._prepare_corpus()
            self._fit_bm25()
            # Vectors are not fitted immediately; they are fitted on demand at the first search
            logger.info(f"📚 LocalNewsSearch loaded {len(data)} items from history")
        except Exception as e:
            logger.error(f"Failed to load history for search: {e}")

    def search(self, query: str, top_n: int = 5, use_vector: bool = True) -> List[Dict[str, Any]]:
        """Search local history; vector search is enabled by default"""
        if not self.data:
            self.load_history()
        return super().search(query, top_n=top_n, use_vector=use_vector)
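A short sketch of the hybrid retrieval flow above (BM25 plus optional vectors fused with RRF); the documents and import path are made up for illustration, and it assumes jieba, rank_bm25, and sentence-transformers are installed:

from scripts.hybrid_search import HybridSearcher

docs = [
    {"title": "宁德时代发布新电池", "content": "energy density improves"},
    {"title": "NVIDIA Q3 earnings beat", "content": "data center revenue grows"},
]
searcher = HybridSearcher(docs)
# BM25 only; pass use_vector=True to trigger lazy embedding and RRF fusion
hits = searcher.search("NVIDIA earnings", top_n=1, use_vector=False)
print(hits[0]["title"], hits[0]["_search_score"])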
skills/alphaear-search/scripts/llm/__init__.py | 0 lines (new file)
skills/alphaear-search/scripts/llm/capability.py | 85 lines (new file)
@@ -0,0 +1,85 @@
import os
from typing import Optional, List, Dict, Any
from agno.agent import Agent
from agno.models.base import Model
from loguru import logger
from .factory import get_model


def test_tool_call_support(model: Model) -> bool:
    """
    Test whether a model supports native tool calls (function calling).
    Verified by asking the model to call a simple weather-lookup tool.
    """

    def get_current_weather(location: str):
        """Get the weather for the given location"""
        return f"{location} 的天气是晴天,25度。"

    test_agent = Agent(
        model=model,
        tools=[get_current_weather],
        instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。",
    )

    try:
        # Run a simple task and observe whether a tool_call is triggered
        response = test_agent.run("北京天气怎么样?")

        # Check whether the response contains tool_calls.
        # Agno's RunResponse usually carries messages, so inspect them.
        has_tool_call = False
        for msg in response.messages:
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                has_tool_call = True
                break

        if has_tool_call:
            logger.info(f"✅ Model {model.id} supports native tool calling.")
            return True
        else:
            # If there are no tool_calls but the answer is correct, the model may have simulated
            # the call in plain text (ReAct) or skipped the tool entirely. For native support we
            # insist on an actual tool_calls structure.
            logger.warning(
                f"⚠️ Model {model.id} did NOT use native tool calling structure."
            )
            return False

    except Exception as e:
        logger.error(f"❌ Error testing tool call for {model.id}: {e}")
        return False


class ModelCapabilityRegistry:
    """
    Registry of model capabilities; caches and manages capability-test results per model.
    """

    _cache = {}

    @classmethod
    def get_capabilities(
        cls, provider: str, model_id: str, **kwargs
    ) -> Dict[str, bool]:
        key = f"{provider}:{model_id}"
        if key not in cls._cache:
            logger.info(f"🔍 Testing capabilities for {key}...")
            model = get_model(provider, model_id, **kwargs)
            supports_tool_call = test_tool_call_support(model)
            cls._cache[key] = {"supports_tool_call": supports_tool_call}
        return cls._cache[key]


if __name__ == "__main__":
    import os
    from dotenv import load_dotenv

    load_dotenv(os.path.expanduser("~/.config/opencode/.env"))

    # Test the currently configured model
    p = os.getenv("LLM_PROVIDER", "ust")
    m = os.getenv("LLM_MODEL", "Qwen")

    print(f"Testing {p}/{m}...")
    res = ModelCapabilityRegistry.get_capabilities(p, m)
    print(f"Result: {res}")
skills/alphaear-search/scripts/llm/factory.py | 114 lines (new file)
@@ -0,0 +1,114 @@
import os
from agno.models.openai import OpenAIChat
from agno.models.ollama import Ollama
from agno.models.dashscope import DashScope
from agno.models.deepseek import DeepSeek
from agno.models.openrouter import OpenRouter


def get_model(model_provider: str, model_id: str, **kwargs):
    """
    Factory to get the appropriate LLM model.

    Args:
        model_provider: "openai", "ollama", "deepseek", "dashscope", "openrouter", "zai", or "ust"
        model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat")
        **kwargs: Additional arguments for the model constructor
    """
    if model_provider == "openai":
        return OpenAIChat(id=model_id, **kwargs)

    elif model_provider == "ollama":
        return Ollama(id=model_id, **kwargs)

    elif model_provider == "deepseek":
        # DeepSeek is OpenAI compatible
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            print("Warning: DEEPSEEK_API_KEY not set.")

        return DeepSeek(
            id=model_id,
            api_key=api_key,
            **kwargs
        )

    elif model_provider == "dashscope":
        api_key = os.getenv("DASHSCOPE_API_KEY")
        if not api_key:
            print("Warning: DASHSCOPE_API_KEY not set.")

        return DashScope(
            id=model_id,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=api_key,
            **kwargs
        )

    elif model_provider == 'openrouter':
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            print('Warning: OPENROUTER_API_KEY not set.')

        return OpenRouter(
            id=model_id,
            api_key=api_key,
            **kwargs
        )

    elif model_provider == 'zai':
        api_key = os.getenv("ZAI_KEY_API")
        if not api_key:
            print('Warning: ZAI_KEY_API not set.')

        # Explicit role_map to ensure compatibility (see the "ust" branch below).
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)

        return OpenAIChat(
            id=model_id,
            base_url="https://api.z.ai/api/paas/v4",
            api_key=api_key,
            timeout=60,
            role_map=role_map,
            extra_body={"enable_thinking": False},  # TODO: one more setting for thinking
            **kwargs
        )

    elif model_provider == 'ust':
        api_key = os.getenv("UST_KEY_API")
        if not api_key:
            print('Warning: UST_KEY_API not set.')

        # Some UST-compatible endpoints expect the standard OpenAI role names
        # (e.g. "system", "user", "assistant") rather than Agno's default
        # mapping which maps "system" -> "developer". Provide an explicit
        # role_map to ensure compatibility.
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)

        return OpenAIChat(
            id=model_id,
            api_key=api_key,
            base_url=os.getenv("UST_URL"),
            role_map=role_map,
            extra_body={"enable_thinking": False},  # TODO: one more setting for thinking
            **kwargs
        )

    else:
        raise ValueError(f"Unknown model provider: {model_provider}")
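For reference (not in the commit), a hedged sketch of how the factory is meant to be called; the provider/model names are examples and the corresponding API keys are assumed to be present in the environment:

from scripts.llm.factory import get_model

model = get_model("deepseek", "deepseek-chat")  # reads DEEPSEEK_API_KEY
# model = get_model("ust", "Qwen")              # OpenAI-compatible endpoint at UST_URL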
skills/alphaear-search/scripts/llm/router.py | 80 lines (new file)
@@ -0,0 +1,80 @@
import os
from typing import Optional, List, Dict, Any, Union
from agno.models.base import Model
from loguru import logger
from dotenv import load_dotenv
from .factory import get_model
from .capability import ModelCapabilityRegistry

load_dotenv(os.path.expanduser("~/.config/opencode/.env"))


class ModelRouter:
    """
    Model routing manager.

    Responsibilities:
    1. Manage the reasoning/writing model and the tool-calling model.
    2. Automatically pick the appropriate model for the task at hand.
    """

    def __init__(self):
        # Defaults are read from environment variables
        self.reasoning_provider = os.getenv(
            "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai")
        )
        self.reasoning_id = os.getenv(
            "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o")
        )
        self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST"))

        self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider)
        self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id)
        self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host)

        self._reasoning_model = None
        self._tool_model = None

        logger.info(
            f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})"
        )

    def get_reasoning_model(self, **kwargs) -> Model:
        if not self._reasoning_model:
            # Prefer the host configured for this route
            if self.reasoning_host and "host" not in kwargs:
                kwargs["host"] = self.reasoning_host
            self._reasoning_model = get_model(
                self.reasoning_provider, self.reasoning_id, **kwargs
            )
        return self._reasoning_model

    def get_tool_model(self, **kwargs) -> Model:
        if not self._tool_model:
            # Prefer the host configured for this route
            if self.tool_host and "host" not in kwargs:
                kwargs["host"] = self.tool_host

            # Check whether the tool model really supports tool calls
            caps = ModelCapabilityRegistry.get_capabilities(
                self.tool_provider, self.tool_id, **kwargs
            )
            if not caps["supports_tool_call"]:
                logger.warning(
                    f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model."
                )

            self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs)
        return self._tool_model

    def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model:
        """
        Return the appropriate model depending on whether the Agent uses tools.
        """
        if has_tools:
            return self.get_tool_model(**kwargs)
        return self.get_reasoning_model(**kwargs)


# Global singleton
router = ModelRouter()
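A usage sketch for the router above (illustrative, not part of the diff); note that importing the module builds the singleton from the environment, and requesting the tool model triggers a one-off live capability test, so credentials are assumed to be configured:

from scripts.llm.router import router

writer_model = router.get_model_for_agent(has_tools=False)  # reasoning/writing model
tool_model = router.get_model_for_agent(has_tools=True)     # tool-call model, capability-checked once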
skills/alphaear-search/scripts/search_tools.py | 479 lines (new file)
@@ -0,0 +1,479 @@
import os
import hashlib
import json
import re
import requests
import time
import threading
from typing import List, Dict, Optional, Any
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.tools.baidusearch import BaiduSearchTools
from datetime import datetime
from loguru import logger
from .database_manager import DatabaseManager
from .content_extractor import ContentExtractor
from .hybrid_search import LocalNewsSearch

# Default search-cache TTL (seconds); can be overridden via environment variable
DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600"))  # default: 1 hour


class JinaSearchEngine:
    """Wrapper around the Jina Search API - uses s.jina.ai for web search"""

    JINA_SEARCH_URL = "https://s.jina.ai/"

    # Rate-limit configuration
    _rate_limit_no_key = 10  # max requests per minute without an API key
    _rate_window = 60.0
    _min_interval = 2.0
    _request_times = []
    _last_request_time = 0.0
    _lock = threading.Lock()

    def __init__(self):
        self.api_key = os.getenv("JINA_API_KEY", "").strip()
        self.has_api_key = bool(self.api_key)
        if self.has_api_key:
            logger.info("✅ Jina Search API key configured")

    @classmethod
    def _wait_for_rate_limit(cls, has_api_key: bool) -> None:
        """Block until the rate limit allows another request"""
        if has_api_key:
            time.sleep(0.3)
            return

        with cls._lock:
            current_time = time.time()
            cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]

            if len(cls._request_times) >= cls._rate_limit_no_key:
                oldest = cls._request_times[0]
                wait_time = cls._rate_window - (current_time - oldest) + 1.0
                if wait_time > 0:
                    logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    current_time = time.time()
                    cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]

            time_since_last = current_time - cls._last_request_time
            if time_since_last < cls._min_interval:
                time.sleep(cls._min_interval - time_since_last)

            cls._request_times.append(time.time())
            cls._last_request_time = time.time()

    def search(self, query: str, max_results: int = 5) -> List[Dict]:
        """
        Run a search via the Jina Search API.

        Args:
            query: Search keywords
            max_results: Number of results to return

        Returns:
            A list of results, each containing title, url, content
        """
        if not query:
            return []

        logger.info(f"🔍 Jina Search: {query}")

        # Respect the rate limit
        self._wait_for_rate_limit(self.has_api_key)

        headers = {
            "Accept": "application/json",
            "X-Retain-Images": "none",
        }

        if self.has_api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            # Jina Search API: https://s.jina.ai/{query}
            import urllib.parse
            encoded_query = urllib.parse.quote(query)
            url = f"{self.JINA_SEARCH_URL}{encoded_query}"

            response = requests.get(url, headers=headers, timeout=30)

            if response.status_code == 429:
                logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...")
                time.sleep(30)
                return self.search(query, max_results)

            if response.status_code != 200:
                logger.warning(f"Jina Search failed (Status {response.status_code})")
                return []

            # Parse the response
            try:
                data = response.json()
            except json.JSONDecodeError:
                # Plain-text response; wrap it as a single result
                data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]}

            results = []

            # Jina may return {"data": [...]} or a bare list
            items = data.get("data", []) if isinstance(data, dict) else data
            if not isinstance(items, list):
                items = [items] if items else []

            for i, item in enumerate(items[:max_results]):
                if isinstance(item, dict):
                    results.append({
                        "title": item.get("title", f"Result {i+1}"),
                        "url": item.get("url", ""),
                        "href": item.get("url", ""),  # compatibility alias
                        "content": item.get("content", item.get("description", "")),
                        "body": item.get("content", item.get("description", "")),  # compatibility alias
                    })
                elif isinstance(item, str):
                    results.append({
                        "title": f"Result {i+1}",
                        "url": "",
                        "content": item
                    })

            logger.info(f"✅ Jina Search returned {len(results)} results")
            return results

        except requests.exceptions.Timeout:
            logger.error("Jina Search timeout")
            return []
        except requests.exceptions.RequestException as e:
            logger.error(f"Jina Search request error: {e}")
            return []
        except Exception as e:
            logger.error(f"Jina Search unexpected error: {e}")
            return []


class SearchTools:
    """Extensible search toolkit - supports multi-engine aggregation and content caching"""

    def __init__(self, db: DatabaseManager):
        self.db = db

        # Check whether a Jina API key is configured
        jina_api_key = os.getenv("JINA_API_KEY", "").strip()
        self._jina_enabled = bool(jina_api_key)

        self._engines = {
            "ddg": DuckDuckGoTools(),
            "baidu": BaiduSearchTools(),
            "local": LocalNewsSearch(db)
        }

        # Register the Jina engine if an API key is configured
        if self._jina_enabled:
            self._engines["jina"] = JinaSearchEngine()
            logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)")

        # Pick the default search engine
        self._default_engine = "jina" if self._jina_enabled else "ddg"

    def _generate_hash(self, query: str, engine: str, max_results: int) -> str:
        return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest()

    def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str:
        """
        Run a web search with the given engine; results are cached for efficiency.

        Args:
            query: Search keywords, e.g. "英伟达财报" (NVIDIA earnings) or "光伏行业政策" (solar industry policy).
            engine: Engine to use. Options:
                "jina" (Jina Search, requires JINA_API_KEY, LLM-friendly output),
                "ddg" (DuckDuckGo, recommended for English/international queries),
                "baidu" (Baidu, recommended for Chinese/domestic queries),
                "local" (local historical news search, vector + BM25).
                Default: "jina" if JINA_API_KEY is configured, otherwise "ddg".
            max_results: Desired number of results, default 5.
            ttl: Cache lifetime in seconds. Entries older than this trigger a fresh search.
                Defaults to the SEARCH_CACHE_TTL environment variable or 3600 seconds.
                Set to 0 to force a refresh.

        Returns:
            A textual description of the results, including titles, snippets, and links.
        """
        # Use the default engine (Jina preferred when configured)
        if engine is None:
            engine = self._default_engine

        if engine not in self._engines:
            return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}"

        query_hash = self._generate_hash(query, engine, max_results)
        effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL

        # 1. Try the cache first (the "local" engine is never cached since it already queries the DB)
        if engine != "local":
            cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
            if cache and effective_ttl != 0:
                logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})")
                return cache['results']

        # 2. Run the real search
        logger.info(f"📡 Searching {engine} for: {query}")
        try:
            tool = self._engines[engine]
            if engine == "jina":
                # Jina Search returns List[Dict]
                jina_results = tool.search(query, max_results=max_results)
                results = []
                for r in jina_results:
                    results.append({
                        "title": r.get("title", ""),
                        "href": r.get("url", ""),
                        "body": r.get("content", "")
                    })
            elif engine == "ddg":
                results = tool.duckduckgo_search(query, max_results=max_results)
            elif engine == "baidu":
                results = tool.baidu_search(query, max_results=max_results)
            elif engine == "local":
                # LocalNewsSearch returns List[Dict]
                local_results = tool.search(query, top_n=max_results)
                results = []
                for r in local_results:
                    results.append({
                        "title": r.get("title"),
                        "href": r.get("url", "local"),
                        "body": r.get("content", "")
                    })
            else:
                results = "Search not implemented for this engine."

            results_str = str(results)
            if engine != "local":
                self.db.save_search_cache(query_hash, query, engine, results_str)
            return results_str

        except Exception as e:
            # Fallback strategy when a search fails
            if engine == "jina":
                logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})")
                try:
                    return self.search(query, engine="ddg", max_results=max_results, ttl=ttl)
                except Exception as e2:
                    logger.error(f"❌ DDG fallback also failed for {query}: {e2}")
            elif engine == "ddg":
                logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})")
                try:
                    return self.search(query, engine="baidu", max_results=max_results, ttl=ttl)
                except Exception as e2:
                    logger.error(f"❌ Baidu fallback also failed for {query}: {e2}")

            logger.error(f"❌ Search failed for {query}: {e}")
            return f"Error occurred during search: {str(e)}"

    def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]:
        """
        Run a search and return a structured list (List[Dict]).
        Each dict contains: title, href (or url), body (or snippet).

        Args:
            engine: Search engine; defaults to the configured default engine (Jina preferred)
            enrich: Whether to fetch full page content (default True)
        """
        # Use the default engine
        if engine is None:
            engine = self._default_engine

        if engine not in self._engines:
            logger.error(f"Unsupported engine {engine}")
            return []

        # Use a distinct hash so enriched and non-enriched results are cached separately
        enrich_suffix = ":enriched" if enrich else ""
        query_hash = self._generate_hash(query, engine + enrich_suffix, max_results)
        effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL

        # 1. Try the cache first
        cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
        if cache and effective_ttl != 0:
            try:
                cached_data = json.loads(cache['results'])
                if isinstance(cached_data, list):
                    logger.info(f"ℹ️ Found structured search cache for: {query}")
                    return cached_data
            except Exception:
                pass

        # 1.5 Smart cache (delegated to the Agent)
        # The Agent should call list_similar_searches and judge relevance using PROMPTS.md

        # 2. Run the search
        logger.info(f"📡 Searching {engine} (structured) for: {query}")
        try:
            tool = self._engines[engine]
            results = []
            if engine == "jina":
                # Jina Search already returns structured data
                jina_results = tool.search(query, max_results=max_results)
                for r in jina_results:
                    results.append({
                        "title": r.get("title", ""),
                        "url": r.get("url", ""),
                        "href": r.get("url", ""),
                        "body": r.get("content", ""),
                        "content": r.get("content", ""),
                        "source": "Jina Search"
                    })
            elif engine == "ddg":
                results = tool.duckduckgo_search(query, max_results=max_results)
            elif engine == "baidu":
                results = tool.baidu_search(query, max_results=max_results)
            elif engine == "local":
                # LocalNewsSearch returns List[Dict]
                local_results = tool.search(query, top_n=max_results)
                results = []
                for r in local_results:
                    results.append({
                        "title": r.get("title"),
                        "url": r.get("url", "local"),
                        "body": r.get("content", "")[:500],
                        "source": f"Local ({r.get('source', 'db')})",
                        "publish_time": r.get("publish_time")
                    })

            # Handle JSON returned as a string (Baidu often returns a JSON string)
            if isinstance(results, str) and engine not in ["local", "jina"]:
                try:
                    results = json.loads(results)
                except Exception:
                    pass

            # Normalize to a unified format
            normalized_results = []
            if isinstance(results, list):
                for i, r in enumerate(results, 1):
                    title = r.get('title', '')
                    url = r.get('href') or r.get('url') or r.get('link', '')
                    content = r.get('body') or r.get('snippet') or r.get('abstract', '')

                    if title and url:
                        normalized_results.append({
                            "id": self._generate_hash(url + query, "search_item", i),
                            "rank": i,
                            "title": title,
                            "url": url,
                            "content": content,
                            "original_snippet": content,  # keep the snippet
                            "source": f"Search ({engine})",
                            "publish_time": datetime.now().isoformat(),  # placeholder: current time
                            "crawl_time": datetime.now().isoformat(),
                            "meta_data": {"query": query, "engine": engine}
                        })

            # Fallback if still string and failed to parse
            elif isinstance(results, str) and results:
                normalized_results.append({"title": query, "url": "", "content": results, "source": engine})

            # 3. Fetch full text & compute sentiment (enrichment)
            # Note: Jina Search content is already LLM-friendly, so the extra fetch can be skipped
            skip_content_enrichment = (engine == "jina")

            if enrich and normalized_results:
                logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...")
                extractor = ContentExtractor()

                # Lazy load sentiment tool
                if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None:
                    from .sentiment_tools import SentimentTools
                    self.sentiment_tool = SentimentTools(self.db)

                for item in normalized_results:
                    if item.get("url"):
                        try:
                            # With Jina Search the content is already good enough; skip the extra fetch
                            if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100:
                                full_content = item["content"]
                            else:
                                # Use Jina Reader to get full content
                                full_content = extractor.extract_with_jina(item["url"], timeout=60)

                            if full_content and len(full_content) > 100:
                                item["content"] = full_content

                                # Calculate sentiment
                                # Use title + snippet of content for efficiency
                                text_to_analyze = f"{item['title']} {full_content[:500]}"
                                sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
                                score = sent_result.get('score', 0.0)
                                item["sentiment_score"] = float(score)

                                logger.info(f" ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})")
                            else:
                                # Fallback: Use snippet for sentiment
                                logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.")
                                text_to_analyze = f"{item['title']} {item['content']}"  # content is snippet here
                                sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
                                score = sent_result.get('score', 0.0)
                                item["sentiment_score"] = float(score)

                        except Exception as e:
                            # Fallback: Use snippet for sentiment on error
                            logger.warning(f"Failed to enrich {item['url']}: {e}. Using snippet.")
                            text_to_analyze = f"{item['title']} {item['content']}"
                            sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
                            score = sent_result.get('score', 0.0)
                            item["sentiment_score"] = float(score)

            # Cache the result list
            if normalized_results:
                # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details
                # Only cache if NOT from local news reuse (though this logic path is for fresh search)
                self.db.save_search_cache(query_hash, query, engine, normalized_results)

            return normalized_results

        except Exception as e:
            # Fallback strategy when a search fails
            if engine == "jina":
                logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})")
                try:
                    return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich)
                except Exception as e2:
                    logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}")
            elif engine == "ddg":
                logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})")
                try:
                    return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich)
                except Exception as e2:
                    logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}")

            logger.error(f"❌ Structured search failed for {query}: {e}")
            return []

    def list_similar_queries(self, query: str, limit: int = 5) -> List[Dict]:
        """
        Find previously cached queries similar to the current one.
        The Agent can use this to fetch cache candidates and judge reuse with PROMPTS.md.
        """
        return self.db.find_similar_queries(query, limit=limit)

    def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str:
        """
        Search with several engines at once and aggregate the results for broader coverage.

        Args:
            query: Search keywords.
            engines: Engines to use, e.g. ["ddg", "baidu"].
                Defaults to both ddg and baidu.
            max_results: Desired number of results per engine.

        Returns:
            Aggregated results, grouped by engine.
        """
        engines = engines or ["ddg", "baidu"]
        aggregated_results = []
        for engine in engines:
            res = self.search(query, engine=engine, max_results=max_results)
            aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}")

        return "\n\n".join(aggregated_results)
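Sketch of the intended call pattern for SearchTools (assumed, not in the commit); it needs a DatabaseManager, reads JINA_API_KEY to pick the default engine, and the import paths are illustrative:

from scripts.database_manager import DatabaseManager
from scripts.search_tools import SearchTools

tools = SearchTools(DatabaseManager())
# Structured results without full-text fetch or sentiment enrichment
items = tools.search_list("光伏行业政策", max_results=3, enrich=False)
for it in items:
    print(it["rank"], it["title"], it["url"])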
skills/alphaear-search/scripts/sentiment_tools.py | 231 lines (new file)
@@ -0,0 +1,231 @@
import os
from typing import Dict, List, Union, Optional
import json
from loguru import logger
from agno.agent import Agent
from .llm.factory import get_model
from .database_manager import DatabaseManager

# Default sentiment-analysis mode, read from the environment
DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto")  # auto, bert, llm


class SentimentTools:
    """
    Sentiment analysis tools - supports both LLM and BERT modes

    Modes:
    - "auto": pick automatically; prefer BERT (fast) and fall back to the LLM when unavailable
    - "bert": force the BERT model (requires the transformers library)
    - "llm": force the LLM (more accurate but slower)

    The default mode can be set via the SENTIMENT_MODE environment variable.
    """

    def __init__(self, db: DatabaseManager, mode: Optional[str] = None,
                 model_provider: str = "openai", model_id: str = "gpt-4o"):
        """
        Initialize the sentiment analysis tools.

        Args:
            db: DatabaseManager instance
            mode: Analysis mode, one of "auto", "bert", "llm". None uses the environment default.
            model_provider: LLM provider, e.g. "openai", "ust", "deepseek"
            model_id: Model identifier
        """
        self.db = db
        self.mode = mode or DEFAULT_SENTIMENT_MODE
        self.llm_model = None
        self.bert_pipeline = None

        # Initialize LLM
        try:
            provider = "ust" if os.getenv("UST_KEY_API") else model_provider
            m_id = "Qwen" if provider == "ust" else model_id
            self.llm_model = get_model(provider, m_id)
        except Exception as e:
            logger.warning(f"LLM initialization skipped: {e}")

        # Initialize BERT if needed
        if self.mode in ["bert", "auto"]:
            try:
                from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
                from transformers.utils import logging as transformers_logging
                transformers_logging.set_verbosity_error()  # reduce noisy logs

                bert_model = os.getenv("BERT_SENTIMENT_MODEL", "uer/roberta-base-finetuned-chinanews-chinese")

                # Prefer the local cache
                try:
                    tokenizer = AutoTokenizer.from_pretrained(bert_model, local_files_only=True)
                    model = AutoModelForSequenceClassification.from_pretrained(bert_model, local_files_only=True)

                    self.bert_pipeline = pipeline(
                        "sentiment-analysis",
                        model=model,
                        tokenizer=tokenizer,
                        device=-1
                    )
                    logger.info(f"✅ BERT pipeline loaded from local cache: {bert_model}")
                except (OSError, ValueError, ImportError):
                    # Not cached locally; download from the network
                    logger.info(f"📡 Downloading BERT model: {bert_model}...")
                    tokenizer = AutoTokenizer.from_pretrained(bert_model)
                    model = AutoModelForSequenceClassification.from_pretrained(bert_model)

                    self.bert_pipeline = pipeline(
                        "sentiment-analysis",
                        model=model,
                        tokenizer=tokenizer,
                        device=-1
                    )
                    logger.info(f"✅ BERT Sentiment pipeline ({bert_model}) initialized.")
            except ImportError:
                logger.warning("Transformers library not installed. BERT sentiment analysis disabled.")
            except Exception as e:
                if self.mode == "bert":
                    logger.error(f"BERT mode requested but failed: {e}")
                else:
                    logger.warning(f"BERT unavailable, using LLM only. Error: {e}")
                self.bert_pipeline = None

    def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]:
        """
        Analyze the sentiment polarity of a text. The method is chosen by the mode set at initialization.

        Args:
            text: Text to analyze, e.g. a news title or summary.

        Returns:
            A dict with the following fields:
            - score: sentiment score from -1.0 (very negative) to 1.0 (very positive); 0.0 is neutral
            - label: sentiment label, "positive"/"negative"/"neutral"
            - reason: explanation (a detailed reason is only produced in LLM mode)
        """
        if self.mode == "bert" and self.bert_pipeline:
            results = self.analyze_sentiment_bert([text])
            return results[0] if results else {"score": 0.0, "label": "error"}
        elif self.mode == "llm" or (self.mode == "auto" and not self.bert_pipeline):
            return self.analyze_sentiment_llm(text)
        else:
            # auto mode with BERT available
            results = self.analyze_sentiment_bert([text])
            return results[0] if results else {"score": 0.0, "label": "error"}

    def analyze_sentiment_llm(self, text: str) -> Dict[str, Union[float, str]]:
        """
        Use an LLM for deeper sentiment analysis, including a detailed reason.

        Args:
            text: Text to analyze; only the first 1000 characters are used.

        Returns:
            A dict with score, label, reason.
        """
        if not self.llm_model:
            return {"score": 0.0, "label": "neutral", "error": "LLM not initialized"}

        analyzer = Agent(model=self.llm_model, markdown=True)
        prompt = f"""请分析以下金融/新闻文本的情绪极性。
返回严格的 JSON 格式:
{{"score": <float: -1.0到1.0>, "label": "<positive/negative/neutral>", "reason": "<简短理由>"}}

文本: {text[:1000]}"""

        try:
            response = analyzer.run(prompt)
            content = response.content
            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()
            return json.loads(content)
        except Exception as e:
            logger.error(f"LLM sentiment failed: {e}")
            return {"score": 0.0, "label": "error", "reason": str(e)}

    def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]:
        """
        Run fast batch sentiment analysis with BERT.

        Args:
            texts: List of texts to analyze.

        Returns:
            A list of results with the same length as the input.
        """
        if not self.bert_pipeline:
            return [{"score": 0.0, "label": "error", "reason": "BERT not available"}] * len(texts)

        try:
            results = self.bert_pipeline(texts, truncation=True, max_length=512)
            processed = []
            for r in results:
                label = r['label'].lower()
                score = r['score']

                # Normalize label formats across different models
                if 'negative' in label or 'neg' in label:
                    score = -score
                elif 'neutral' in label or 'neu' in label:
                    score = 0.0

                processed.append({
                    "score": float(round(score, 3)),
                    "label": "positive" if score > 0.1 else ("negative" if score < -0.1 else "neutral"),
                    "reason": "BERT automated analysis"
                })
            return processed
        except Exception as e:
            logger.error(f"BERT analysis failed: {e}")
            return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts)

    def batch_update_news_sentiment(self, source: Optional[str] = None, limit: int = 50, use_bert: Optional[bool] = None):
        """
        Batch-update sentiment scores for news stored in the database.

        Args:
            source: Restrict to a specific news source, e.g. "wallstreetcn". None processes all sources.
            limit: Maximum number of news items to process.
            use_bert: Whether to use BERT. None decides automatically based on the initialized mode.

        Returns:
            The number of news items updated.
        """
        news_items = self.db.get_daily_news(source=source, limit=limit)
        to_analyze = [item for item in news_items if not item.get('sentiment_score')]

        if not to_analyze:
            return 0

        # Decide which method to use
        should_use_bert = use_bert if use_bert is not None else (self.bert_pipeline is not None and self.mode != "llm")

        updated_count = 0
        cursor = self.db.conn.cursor()

        if should_use_bert and self.bert_pipeline:
            logger.info(f"🚀 Using BERT for batch analysis of {len(to_analyze)} items...")
            titles = [item['title'] for item in to_analyze]
            results = self.analyze_sentiment_bert(titles)

            for item, analysis in zip(to_analyze, results):
                cursor.execute("""
                    UPDATE daily_news
                    SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
                    WHERE id = ?
                """, (analysis['score'], analysis['reason'], item['id']))
                updated_count += 1
        else:
            logger.info(f"🚶 Using LLM for analysis of {len(to_analyze)} items...")
            for item in to_analyze:
                analysis = self.analyze_sentiment_llm(item['title'])
                cursor.execute("""
                    UPDATE daily_news
                    SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
                    WHERE id = ?
                """, (analysis.get('score', 0.0), analysis.get('reason', ''), item['id']))
                updated_count += 1

        self.db.conn.commit()
        return updated_count
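Illustrative usage of the sentiment tools above (not part of the commit); mode selection mirrors the SENTIMENT_MODE logic described in the class docstring, and either the transformers library or LLM credentials are assumed to be available:

from scripts.database_manager import DatabaseManager
from scripts.sentiment_tools import SentimentTools

st = SentimentTools(DatabaseManager(), mode="auto")
result = st.analyze_sentiment("公司季度利润大幅增长,超出市场预期")
print(result["score"], result["label"])  # e.g. a positive score with label "positive"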