Import 9 alphaear finance skills

- alphaear-deepear-lite: DeepEar Lite API integration
- alphaear-logic-visualizer: Draw.io XML finance diagrams
- alphaear-news: Real-time finance news (10+ sources)
- alphaear-predictor: Kronos time-series forecasting
- alphaear-reporter: Professional financial reports
- alphaear-search: Web search + local RAG
- alphaear-sentiment: FinBERT/LLM sentiment analysis
- alphaear-signal-tracker: Signal evolution tracking
- alphaear-stock: A-Share/HK/US stock data

Updates:
- All scripts updated to use the universal .env path
- Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example
- Updated load_dotenv() to use ~/.config/opencode/.env
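The universal .env pattern referenced above, as a minimal sketch (this exact call appears in the scripts in this commit):

import os
from dotenv import load_dotenv

# All skill scripts now share one config file instead of per-skill .env copies
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))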
0
skills/alphaear-sentiment/scripts/__init__.py
Normal file
581
skills/alphaear-sentiment/scripts/database_manager.py
Normal file
@@ -0,0 +1,581 @@
import sqlite3
import json
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Any, Union

import pandas as pd
from loguru import logger


class DatabaseManager:
    """
    AlphaEar database manager - stores hot news, search caches, and stock price data.
    Uses SQLite for persistence.
    """

    def __init__(self, db_path: str = "data/signal_flux.db"):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self._init_db()
        logger.info(f"💾 Database initialized at {self.db_path}")

    def _init_db(self):
        """Initialize the table schema."""
        cursor = self.conn.cursor()

        # 1. Daily hot-news table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS daily_news (
                id TEXT PRIMARY KEY,
                source TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                analysis TEXT,
                meta_data TEXT
            )
        """)

        # Try to add the analysis column (for pre-existing tables that lack it)
        try:
            cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT")
        except sqlite3.OperationalError:
            pass  # Column already exists

        # 2. Search cache table (legacy JSON cache)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_cache (
                query_hash TEXT PRIMARY KEY,
                query TEXT,
                engine TEXT,
                results TEXT,
                timestamp TEXT
            )
        """)

        # 2.5 Search detail table (expanded search results)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_detail (
                id TEXT,
                query_hash TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                source TEXT,
                meta_data TEXT,
                PRIMARY KEY (query_hash, id)
            )
        """)

        # 3. Stock price table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS stock_prices (
                ticker TEXT,
                date TEXT,
                open REAL,
                close REAL,
                high REAL,
                low REAL,
                volume REAL,
                change_pct REAL,
                PRIMARY KEY (ticker, date)
            )
        """)

        # 4. Stock list table (for lookups)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code TEXT PRIMARY KEY,
                name TEXT
            )
        """)

        # 5. Investment signal table (ISQ Framework)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS signals (
                signal_id TEXT PRIMARY KEY,
                title TEXT,
                summary TEXT,
                transmission_chain TEXT,
                sentiment_score REAL,
                confidence REAL,
                intensity INTEGER,
                expected_horizon TEXT,
                price_in_status TEXT,
                impact_tickers TEXT,
                industry_tags TEXT,
                sources TEXT,
                user_id TEXT,
                created_at TEXT
            )
        """)

        # 6. Indexes to speed up common queries
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)")

        # Try to add the user_id column to signals (for pre-existing tables)
        try:
            cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT")
        except sqlite3.OperationalError:
            pass

        cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)")

        self.conn.commit()

    # --- News data operations ---

    def save_daily_news(self, news_list: List[Dict]) -> int:
        """Save hot news items, including publish time and crawl time."""
        cursor = self.conn.cursor()
        count = 0
        crawl_time = datetime.now().isoformat()

        for news in news_list:
            try:
                # ID generation compatible with different sources
                news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}"
                cursor.execute("""
                    INSERT OR REPLACE INTO daily_news
                    (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    news_id,
                    news.get('source'),
                    news.get('rank'),
                    news.get('title'),
                    news.get('url'),
                    news.get('content', ''),
                    news.get('publish_time'),
                    crawl_time,
                    news.get('sentiment_score'),
                    json.dumps(news.get('meta_data', {}))
                ))
                count += 1
            except sqlite3.Error as e:
                logger.error(f"Database error saving news item {news.get('title')}: {e}")
            except Exception as e:
                logger.error(f"Unexpected error saving news item {news.get('title')}: {e}")

        self.conn.commit()
        return count

    def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]:
        """Fetch hot news from the last N days."""
        cursor = self.conn.cursor()
        # Filter on crawl_time to guarantee freshness
        time_threshold = datetime.now().timestamp() - days * 86400
        time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat()

        query = "SELECT * FROM daily_news WHERE crawl_time >= ?"
        params = [time_threshold_str]

        if source:
            query += " AND source = ?"
            params.append(source)

        query += " ORDER BY crawl_time DESC, rank LIMIT ?"
        params.append(limit)

        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]:
        """Best-effort lookup of a source item by URL.

        This is used to render a stable bibliography from DB-backed metadata.
        It searches both `daily_news` and `search_detail`.
        """
        url = (url or "").strip()
        if not url:
            return None

        cursor = self.conn.cursor()

        for table in ("daily_news", "search_detail"):
            try:
                cursor.execute(
                    f"""
                    SELECT title, source, publish_time, crawl_time, url
                    FROM {table}
                    WHERE url = ?
                    ORDER BY crawl_time DESC
                    LIMIT 1
                    """,
                    (url,),
                )
                row = cursor.fetchone()
                if row:
                    return dict(row)
            except sqlite3.Error:
                pass

        return None

    def delete_news(self, news_id: str) -> bool:
        """Delete a specific news item."""
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,))
        self.conn.commit()
        return cursor.rowcount > 0

    def update_news_content(self, news_id: str, content: Optional[str] = None, analysis: Optional[str] = None) -> bool:
        """Update a news item's content or analysis result."""
        cursor = self.conn.cursor()
        updates = []
        params = []

        if content is not None:
            updates.append("content = ?")
            params.append(content)
        if analysis is not None:
            updates.append("analysis = ?")
            params.append(analysis)

        if not updates:
            return False

        params.append(news_id)
        query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?"
        cursor.execute(query, params)
        self.conn.commit()
        return cursor.rowcount > 0

    # --- Search cache helpers ---

    def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]:
        """Fetch a cached search result (search_detail takes precedence)."""
        cursor = self.conn.cursor()

        # 1. Try the expanded, structured rows in search_detail first
        cursor.execute("""
            SELECT * FROM search_detail
            WHERE query_hash = ?
            ORDER BY rank
        """, (query_hash,))
        details = [dict(row) for row in cursor.fetchall()]

        if details:
            # TTL check uses the first row's crawl_time. If the detailed rows
            # are expired, the summary cache almost certainly is too, so treat
            # the whole entry as a miss.
            first_time = datetime.fromisoformat(details[0]['crawl_time'])
            if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds:
                logger.info(f"⌛ Detailed cache expired for hash {query_hash}")
                return None

            logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)")
            # SearchTools expects a dict whose 'results' field is a JSON string
            # (it calls json.loads(cache['results'])), so rebuild that shape
            # here to avoid changing the caller.
            return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']}

        # 2. Fall back to the legacy search_cache table
        cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,))
        row = cursor.fetchone()

        if not row:
            return None

        row_dict = dict(row)
        if ttl_seconds:
            cache_time = datetime.fromisoformat(row_dict['timestamp'])
            if (datetime.now() - cache_time).total_seconds() > ttl_seconds:
                logger.info(f"⌛ Cache expired for hash {query_hash}")
                return None

        return row_dict

    def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]):
        """Save search results (to both search_cache and search_detail)."""
        cursor = self.conn.cursor()
        current_time = datetime.now().isoformat()

        results_str = results if isinstance(results, str) else json.dumps(results)

        # 1. Save the summary row to search_cache
        cursor.execute("""
            INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp)
            VALUES (?, ?, ?, ?, ?)
        """, (query_hash, query, engine, results_str, current_time))

        # 2. Save expanded rows to search_detail when results is a list
        if isinstance(results, list):
            for item in results:
                try:
                    item_id = item.get('id') or f"{hash(item.get('url', ''))}"
                    cursor.execute("""
                        INSERT OR REPLACE INTO search_detail
                        (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        str(item_id),
                        query_hash,
                        item.get('rank', 0),
                        item.get('title'),
                        item.get('url'),
                        item.get('content', ''),
                        item.get('publish_time'),
                        item.get('crawl_time') or current_time,
                        item.get('sentiment_score'),
                        item.get('source'),
                        json.dumps(item.get('meta_data', {}))
                    ))
                except sqlite3.Error as e:
                    logger.error(f"Database error saving search detail {item.get('title')}: {e}")
                except Exception as e:
                    logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}")

        self.conn.commit()

    def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]:
        """Fuzzy-match previously cached queries."""
        cursor = self.conn.cursor()

        # Simple fuzzy match: query contains cached query, or vice versa
        q_wild = f"%{query}%"
        cursor.execute("""
            SELECT query, query_hash, timestamp, results
            FROM search_cache
            WHERE query LIKE ? OR ? LIKE ('%' || query || '%')
            ORDER BY timestamp DESC
            LIMIT ?
        """, (q_wild, query, limit))

        return [dict(row) for row in cursor.fetchall()]

    def search_local_news(self, query: str, limit: int = 5) -> List[Dict]:
        """Search related news in the local daily_news table."""
        cursor = self.conn.cursor()
        q_wild = f"%{query}%"
        # Search both title and content
        cursor.execute("""
            SELECT * FROM daily_news
            WHERE title LIKE ? OR content LIKE ?
            ORDER BY crawl_time DESC
            LIMIT ?
        """, (q_wild, q_wild, limit))
        return [dict(row) for row in cursor.fetchall()]

    # --- Stock data operations ---

    def save_stock_list(self, df: pd.DataFrame):
        """Save the stock list to the stock_list table."""
        cursor = self.conn.cursor()
        try:
            # Clear the old table
            cursor.execute("DELETE FROM stock_list")

            # Bulk insert
            data = df[['code', 'name']].to_dict('records')
            cursor.executemany(
                "INSERT INTO stock_list (code, name) VALUES (:code, :name)",
                data
            )
            self.conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error saving stock list: {e}")
        except Exception as e:
            logger.error(f"Unexpected error saving stock list: {e}")

    def search_stock(self, query: str, limit: int = 5) -> List[Dict]:
        """Fuzzy-search stock codes or names."""
        cursor = self.conn.cursor()
        wild = f"%{query}%"
        cursor.execute("""
            SELECT code, name FROM stock_list
            WHERE code LIKE ? OR name LIKE ?
            LIMIT ?
        """, (wild, wild, limit))
        return [dict(row) for row in cursor.fetchall()]

    def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]:
        """Look up a stock by exact code.

        Args:
            code: Stock code (6 digits for A-shares / 5 for HK); must be a numeric string.

        Returns:
            dict: {"code": str, "name": str}, or None.
        """
        if not code:
            return None
        clean = "".join([c for c in str(code).strip() if c.isdigit()])
        if not clean:
            return None

        cursor = self.conn.cursor()
        cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,))
        row = cursor.fetchone()
        return dict(row) if row else None

    def save_stock_prices(self, ticker: str, df: pd.DataFrame):
        """Save historical stock price data."""
        if df.empty:
            return

        cursor = self.conn.cursor()

        # Ensure the DataFrame has the required columns
        required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
        for col in required_cols:
            if col not in df.columns:
                logger.warning(f"Missing column {col} in stock data for {ticker}")
                return

        try:
            for _, row in df.iterrows():
                cursor.execute("""
                    INSERT OR REPLACE INTO stock_prices
                    (ticker, date, open, close, high, low, volume, change_pct)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    ticker,
                    row['date'],
                    row['open'],
                    row['close'],
                    row['high'],
                    row['low'],
                    row['volume'],
                    row['change_pct']
                ))
            self.conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error saving stock prices for {ticker}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error saving stock prices for {ticker}: {e}")

    def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch stock price data for a date range."""
        cursor = self.conn.cursor()

        cursor.execute("""
            SELECT * FROM stock_prices
            WHERE ticker = ? AND date >= ? AND date <= ?
            ORDER BY date
        """, (ticker, start_date, end_date))

        rows = cursor.fetchall()
        if not rows:
            return pd.DataFrame()

        columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
        return pd.DataFrame([dict(row) for row in rows], columns=columns)

    def execute_query(self, query: str, params: tuple = ()) -> List[Any]:
        """Execute an ad-hoc SQL query."""
        try:
            cursor = self.conn.cursor()
            cursor.execute(query, params)
            if query.strip().upper().startswith("SELECT"):
                return cursor.fetchall()
            else:
                self.conn.commit()
                return []
        except sqlite3.Error as e:
            logger.error(f"SQL execution failed (Database error): {e}")
            return []
        except Exception as e:
            logger.error(f"SQL execution failed (Unexpected error): {e}")
            return []

    # --- Investment signal operations (ISQ Framework) ---

    def save_signal(self, signal: Dict[str, Any]):
        """Save an investment signal."""
        cursor = self.conn.cursor()
        created_at = datetime.now().isoformat()

        cursor.execute("""
            INSERT OR REPLACE INTO signals
            (signal_id, title, summary, transmission_chain, sentiment_score,
             confidence, intensity, expected_horizon, price_in_status,
             impact_tickers, industry_tags, sources, user_id, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            signal.get('signal_id'),
            signal.get('title'),
            signal.get('summary'),
            json.dumps(signal.get('transmission_chain', [])),
            signal.get('sentiment_score', 0.0),
            signal.get('confidence', 0.0),
            signal.get('intensity', 1),
            signal.get('expected_horizon', 'T+0'),
            signal.get('price_in_status', '未知'),
            json.dumps(signal.get('impact_tickers', [])),
            json.dumps(signal.get('industry_tags', [])),
            json.dumps(signal.get('sources', [])),
            signal.get('user_id'),
            created_at
        ))
        self.conn.commit()

    def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]:
        """Fetch the most recent investment signals."""
        cursor = self.conn.cursor()
        if user_id:
            cursor.execute("SELECT * FROM signals WHERE user_id = ? ORDER BY created_at DESC LIMIT ?", (user_id, limit))
        else:
            cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,))
        rows = cursor.fetchall()

        signals = []
        for row in rows:
            d = dict(row)
            # Decode JSON-encoded fields
            for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']:
                if d.get(field):
                    try:
                        d[field] = json.loads(d[field])
                    except (json.JSONDecodeError, TypeError):
                        pass
            signals.append(d)
        return signals

    def close(self):
        if self.conn:
            self.conn.close()
            logger.info("Database connection closed.")
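A minimal usage sketch for DatabaseManager (illustrative only; the data/demo.db path and the sample news item below are invented for this example, not part of the commit):

from database_manager import DatabaseManager

db = DatabaseManager(db_path="data/demo.db")  # hypothetical path
db.save_daily_news([{
    "source": "wallstreetcn",  # sample item for illustration
    "rank": 1,
    "title": "Example headline",
    "url": "https://example.com/a",
}])
print(db.get_daily_news(source="wallstreetcn", limit=5))
db.close()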
85
skills/alphaear-sentiment/scripts/llm/capability.py
Normal file
@@ -0,0 +1,85 @@
import os
from typing import Dict

from agno.agent import Agent
from agno.models.base import Model
from loguru import logger

from .factory import get_model


def test_tool_call_support(model: Model) -> bool:
    """
    Test whether a model supports native tool calling (function calling)
    by trying to run a trivial tool through an Agent.
    """

    def get_current_weather(location: str):
        """Get the weather for a given location."""
        return f"The weather in {location} is sunny, 25°C."

    test_agent = Agent(
        model=model,
        tools=[get_current_weather],
        instructions="Call the tool to look up the weather in Beijing and return the tool output directly.",
    )

    try:
        # Run a simple task and check whether a tool call was triggered
        response = test_agent.run("What is the weather like in Beijing?")

        # Agno's RunResponse carries the message history; inspect the
        # messages for a structured tool_calls field
        has_tool_call = False
        for msg in response.messages:
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                has_tool_call = True
                break

        if has_tool_call:
            logger.info(f"✅ Model {model.id} supports native tool calling.")
            return True
        else:
            # No tool_calls but a correct answer may mean the model simulated
            # the call in plain text (ReAct style), or skipped the tool
            # entirely. For "native support" we require the tool_calls structure.
            logger.warning(
                f"⚠️ Model {model.id} did NOT use native tool calling structure."
            )
            return False

    except Exception as e:
        logger.error(f"❌ Error testing tool call for {model.id}: {e}")
        return False


class ModelCapabilityRegistry:
    """
    Registry that caches capability-test results per model.
    """

    _cache = {}

    @classmethod
    def get_capabilities(
        cls, provider: str, model_id: str, **kwargs
    ) -> Dict[str, bool]:
        key = f"{provider}:{model_id}"
        if key not in cls._cache:
            logger.info(f"🔍 Testing capabilities for {key}...")
            model = get_model(provider, model_id, **kwargs)
            supports_tool_call = test_tool_call_support(model)
            cls._cache[key] = {"supports_tool_call": supports_tool_call}
        return cls._cache[key]


if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv(os.path.expanduser("~/.config/opencode/.env"))

    # Test the currently configured model
    p = os.getenv("LLM_PROVIDER", "ust")
    m = os.getenv("LLM_MODEL", "Qwen")

    print(f"Testing {p}/{m}...")
    res = ModelCapabilityRegistry.get_capabilities(p, m)
    print(f"Result: {res}")
114
skills/alphaear-sentiment/scripts/llm/factory.py
Normal file
@@ -0,0 +1,114 @@
import os

from agno.models.openai import OpenAIChat
from agno.models.ollama import Ollama
from agno.models.dashscope import DashScope
from agno.models.deepseek import DeepSeek
from agno.models.openrouter import OpenRouter


def get_model(model_provider: str, model_id: str, **kwargs):
    """
    Factory to get the appropriate LLM model.

    Args:
        model_provider: "openai", "ollama", "deepseek", "dashscope",
            "openrouter", "zai", or "ust"
        model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat")
        **kwargs: Additional arguments for the model constructor
    """
    if model_provider == "openai":
        return OpenAIChat(id=model_id, **kwargs)

    elif model_provider == "ollama":
        return Ollama(id=model_id, **kwargs)

    elif model_provider == "deepseek":
        # DeepSeek is OpenAI compatible
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            print("Warning: DEEPSEEK_API_KEY not set.")

        return DeepSeek(
            id=model_id,
            api_key=api_key,
            **kwargs
        )

    elif model_provider == "dashscope":
        api_key = os.getenv("DASHSCOPE_API_KEY")
        if not api_key:
            print("Warning: DASHSCOPE_API_KEY not set.")

        return DashScope(
            id=model_id,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=api_key,
            **kwargs
        )

    elif model_provider == 'openrouter':
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            print('Warning: OPENROUTER_API_KEY not set.')

        return OpenRouter(
            id=model_id,
            api_key=api_key,
            **kwargs
        )

    elif model_provider == 'zai':
        api_key = os.getenv("ZAI_KEY_API")
        if not api_key:
            print('Warning: ZAI_KEY_API not set.')

        # Provide an explicit role_map to ensure compatibility with the
        # standard OpenAI role names (Agno's default maps "system" -> "developer").
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)

        return OpenAIChat(
            id=model_id,
            base_url="https://api.z.ai/api/paas/v4",
            api_key=api_key,
            timeout=60,
            role_map=role_map,
            extra_body={"enable_thinking": False},  # TODO: expose a setting for thinking mode
            **kwargs
        )

    elif model_provider == 'ust':
        api_key = os.getenv("UST_KEY_API")
        if not api_key:
            print('Warning: UST_KEY_API not set.')

        # Some UST-compatible endpoints expect the standard OpenAI role names
        # (e.g. "system", "user", "assistant") rather than Agno's default
        # mapping which maps "system" -> "developer". Provide an explicit
        # role_map to ensure compatibility.
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)

        return OpenAIChat(
            id=model_id,
            api_key=api_key,
            base_url=os.getenv("UST_URL"),
            role_map=role_map,
            extra_body={"enable_thinking": False},  # TODO: expose a setting for thinking mode
            **kwargs
        )

    else:
        raise ValueError(f"Unknown model provider: {model_provider}")
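A minimal sketch of how get_model is typically called (assumptions: DEEPSEEK_API_KEY is set in ~/.config/opencode/.env, and the import path matches where the caller lives relative to the llm package):

from llm.factory import get_model  # or `from .factory import get_model` inside the package

model = get_model("deepseek", "deepseek-chat")  # returns an agno DeepSeek model
unknown = get_model("foo", "bar")  # raises ValueError: Unknown model provider: foo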
80
skills/alphaear-sentiment/scripts/llm/router.py
Normal file
@@ -0,0 +1,80 @@
import os

from agno.models.base import Model
from loguru import logger
from dotenv import load_dotenv

from .factory import get_model
from .capability import ModelCapabilityRegistry

load_dotenv(os.path.expanduser("~/.config/opencode/.env"))


class ModelRouter:
    """
    Model routing manager.

    Responsibilities:
    1. Manage the "reasoning/writing model" and the "tool-calling model".
    2. Pick the appropriate model automatically based on the task.
    """

    def __init__(self):
        # Defaults are read from environment variables
        self.reasoning_provider = os.getenv(
            "REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai")
        )
        self.reasoning_id = os.getenv(
            "REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o")
        )
        self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST"))

        self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider)
        self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id)
        self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host)

        self._reasoning_model = None
        self._tool_model = None

        logger.info(
            f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})"
        )

    def get_reasoning_model(self, **kwargs) -> Model:
        if not self._reasoning_model:
            # Prefer the host configured for this route
            if self.reasoning_host and "host" not in kwargs:
                kwargs["host"] = self.reasoning_host
            self._reasoning_model = get_model(
                self.reasoning_provider, self.reasoning_id, **kwargs
            )
        return self._reasoning_model

    def get_tool_model(self, **kwargs) -> Model:
        if not self._tool_model:
            # Prefer the host configured for this route
            if self.tool_host and "host" not in kwargs:
                kwargs["host"] = self.tool_host

            # Verify that the configured tool model actually supports tool calls
            caps = ModelCapabilityRegistry.get_capabilities(
                self.tool_provider, self.tool_id, **kwargs
            )
            if not caps["supports_tool_call"]:
                logger.warning(
                    f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model."
                )

            self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs)
        return self._tool_model

    def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model:
        """
        Return the appropriate model depending on whether the Agent has tools.
        """
        if has_tools:
            return self.get_tool_model(**kwargs)
        return self.get_reasoning_model(**kwargs)


# Global singleton
router = ModelRouter()
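A minimal routing sketch (assumes the LLM_* / TOOL_MODEL_* environment variables above are configured; first call to get_model_for_agent(has_tools=True) triggers a one-off capability test):

from llm.router import router

writer = router.get_model_for_agent(has_tools=False)  # reasoning/writing model
tool_model = router.get_model_for_agent(has_tools=True)  # checked for native tool calling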
205
skills/alphaear-sentiment/scripts/sentiment_tools.py
Normal file
@@ -0,0 +1,205 @@
import os
from typing import Dict, List, Union, Optional

from loguru import logger

# IMPORTS REMOVED: agno.agent, get_model
# Internal LLM logic has been removed to delegate analysis to the calling Agent.
from .database_manager import DatabaseManager

# Default sentiment-analysis mode, read from the environment
DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto")  # auto, bert, llm


class SentimentTools:
    """
    Sentiment analysis tools - supports both LLM and BERT modes.

    Modes:
    - "auto": pick automatically; prefer BERT (fast), fall back to LLM when unavailable
    - "bert": force the BERT model (requires the transformers library)
    - "llm": force LLM analysis (more accurate but slower)

    The default mode can be set via the SENTIMENT_MODE environment variable.
    """

    def __init__(self, db: DatabaseManager, mode: Optional[str] = None):
        """
        Initialize the sentiment analysis tools.

        Args:
            db: Database manager instance
            mode: Analysis mode, one of "auto", "bert", "llm". None uses the env default.
        """
        self.db = db
        self.mode = mode or DEFAULT_SENTIMENT_MODE
        self.bert_pipeline = None

        # LLM initialization removed. The Agent should perform analysis if needed.

        # Initialize BERT if needed
        if self.mode in ["bert", "auto"]:
            try:
                from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
                from transformers.utils import logging as transformers_logging
                transformers_logging.set_verbosity_error()  # Reduce noisy logs

                bert_model = os.getenv("BERT_SENTIMENT_MODEL", "uer/roberta-base-finetuned-chinanews-chinese")

                # Prefer the local cache
                try:
                    tokenizer = AutoTokenizer.from_pretrained(bert_model, local_files_only=True)
                    model = AutoModelForSequenceClassification.from_pretrained(bert_model, local_files_only=True)

                    self.bert_pipeline = pipeline(
                        "sentiment-analysis",
                        model=model,
                        tokenizer=tokenizer,
                        device=-1
                    )
                    logger.info(f"✅ BERT pipeline loaded from local cache: {bert_model}")
                except (OSError, ValueError, ImportError):
                    # Not cached locally; download from the network
                    logger.info(f"📡 Downloading BERT model: {bert_model}...")
                    tokenizer = AutoTokenizer.from_pretrained(bert_model)
                    model = AutoModelForSequenceClassification.from_pretrained(bert_model)

                    self.bert_pipeline = pipeline(
                        "sentiment-analysis",
                        model=model,
                        tokenizer=tokenizer,
                        device=-1
                    )
                    logger.info(f"✅ BERT Sentiment pipeline ({bert_model}) initialized.")
            except ImportError:
                logger.warning("Transformers library not installed. BERT sentiment analysis disabled.")
            except Exception as e:
                if self.mode == "bert":
                    logger.error(f"BERT mode requested but failed: {e}")
                else:
                    logger.warning(f"BERT unavailable, using LLM only. Error: {e}")
                self.bert_pipeline = None

    def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]:
        """
        Analyze the sentiment polarity of a text. BERT mode only.
        For LLM analysis, the Agent should execute the prompt from SKILL.md itself.

        Args:
            text: The text to analyze.

        Returns:
            The BERT analysis result, or an error record.
        """
        if self.bert_pipeline:
            results = self.analyze_sentiment_bert([text])
            return results[0] if results else {"score": 0.0, "label": "error"}
        else:
            return {
                "score": 0.0,
                "label": "error",
                "reason": "BERT pipeline not initialized. For LLM analysis, please manually execute the prompt in SKILL.md."
            }

    def update_single_news_sentiment(self, news_id: Union[str, int], score: float, reason: str = "") -> bool:
        """
        Let the Agent persist a manually produced analysis result to the database.

        Args:
            news_id: News ID
            score: -1.0 to 1.0
            reason: Analysis rationale

        Returns:
            Success bool
        """
        try:
            cursor = self.db.conn.cursor()
            cursor.execute("""
                UPDATE daily_news
                SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
                WHERE id = ?
            """, (score, reason, news_id))
            self.db.conn.commit()
            return True
        except Exception as e:
            logger.error(f"Failed to update sentiment for {news_id}: {e}")
            return False

    def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]:
        """
        Fast batch sentiment analysis with BERT.

        Args:
            texts: The texts to analyze.

        Returns:
            A list of results, one per input text.
        """
        if not self.bert_pipeline:
            return [{"score": 0.0, "label": "error", "reason": "BERT not available"}] * len(texts)

        try:
            results = self.bert_pipeline(texts, truncation=True, max_length=512)
            processed = []
            for r in results:
                label = r['label'].lower()
                score = r['score']

                # Normalize label formats across different models
                if 'negative' in label or 'neg' in label:
                    score = -score
                elif 'neutral' in label or 'neu' in label:
                    score = 0.0

                processed.append({
                    "score": float(round(score, 3)),
                    "label": "positive" if score > 0.1 else ("negative" if score < -0.1 else "neutral"),
                    "reason": "BERT automated analysis"
                })
            return processed
        except Exception as e:
            logger.error(f"BERT analysis failed: {e}")
            return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts)

    def batch_update_news_sentiment(self, source: Optional[str] = None, limit: int = 50, use_bert: Optional[bool] = None):
        """
        Batch-update sentiment scores for news in the database.

        Args:
            source: Restrict to a specific news source, e.g. "wallstreetcn". None processes all sources.
            limit: Maximum number of news items to process.
            use_bert: Whether to use BERT. None decides automatically from the initialization mode.

        Returns:
            Number of news items updated.
        """
        news_items = self.db.get_daily_news(source=source, limit=limit)
        to_analyze = [item for item in news_items if not item.get('sentiment_score')]

        if not to_analyze:
            return 0

        updated_count = 0
        cursor = self.db.conn.cursor()

        # Decide which method to use
        if self.bert_pipeline:
            logger.info(f"🚀 Using BERT for batch analysis of {len(to_analyze)} items...")
            titles = [item['title'] for item in to_analyze]
            results = self.analyze_sentiment_bert(titles)

            for item, analysis in zip(to_analyze, results):
                cursor.execute("""
                    UPDATE daily_news
                    SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
                    WHERE id = ?
                """, (analysis['score'], analysis['reason'], item['id']))
                updated_count += 1
        else:
            logger.warning("BERT pipeline not available. Batch update skipped. Please use Agentic analysis for high-quality results.")

        self.db.conn.commit()
        return updated_count
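A minimal usage sketch for SentimentTools (illustrative; assumes the transformers library is installed and the BERT model is reachable or cached, and that the sample headline and output shown in the comment are invented):

from database_manager import DatabaseManager
from sentiment_tools import SentimentTools

db = DatabaseManager()
tools = SentimentTools(db, mode="bert")
result = tools.analyze_sentiment("央行宣布降准,市场流动性改善")
print(result)  # e.g. {"score": 0.87, "label": "positive", "reason": "BERT automated analysis"}
updated = tools.batch_update_news_sentiment(limit=20)  # scores any unscored items
db.close()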