Import 9 AlphaEar finance skills

- alphaear-deepear-lite: DeepEar Lite API integration
- alphaear-logic-visualizer: Draw.io XML finance diagrams
- alphaear-news: Real-time finance news (10+ sources)
- alphaear-predictor: Kronos time-series forecasting
- alphaear-reporter: Professional financial reports
- alphaear-search: Web search + local RAG
- alphaear-sentiment: FinBERT/LLM sentiment analysis
- alphaear-signal-tracker: Signal evolution tracking
- alphaear-stock: A-Share/HK/US stock data

Updates:
- All scripts updated to use universal .env path
- Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example
- Updated load_dotenv() to use ~/.config/opencode/.env (pattern sketched below)
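All nine skills share the loading pattern below; a minimal sketch (only keys actually present in your .env will resolve):

import os
from dotenv import load_dotenv

# Every skill script reads the same universal config file
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
jina_api_key = os.getenv("JINA_API_KEY")          # optional; raises the Jina Reader rate limit
deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")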
Author: Kunthawat Greethong
Date: 2026-03-27 10:11:37 +07:00
parent 7edf5bc4d0
commit 58f9380ec4
149 changed files with 26867 additions and 0 deletions

View File

@@ -0,0 +1 @@
# AlphaEar utils package

View File

@@ -0,0 +1,122 @@
import requests
from requests.exceptions import RequestException, Timeout, ConnectionError
import os
import time
import json
import threading
from typing import Optional
from loguru import logger
class ContentExtractor:
"""内容提取工具 - 主要接入 Jina Reader API"""
JINA_BASE_URL = "https://r.jina.ai/"
# 速率限制配置 (无 API Key 时20 次/分钟)
_rate_limit_no_key = 20 # 每分钟最大请求数
_rate_window = 60.0 # 时间窗口(秒)
_min_interval = 3.0 # 请求最小间隔(秒)
# 类级别的速率限制状态
_request_times = []
_last_request_time = 0.0
_lock = threading.Lock()
@classmethod
def _wait_for_rate_limit(cls, has_api_key: bool) -> None:
"""等待以满足速率限制要求"""
if has_api_key:
# 有 API Key 时,只需保持最小间隔
time.sleep(0.5)
return
with cls._lock:
current_time = time.time()
# 1. Drop request timestamps that have fallen outside the window
cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
# 2. Check whether the per-minute limit has been reached
if len(cls._request_times) >= cls._rate_limit_no_key:
# Wait until the oldest request leaves the window
oldest = cls._request_times[0]
wait_time = cls._rate_window - (current_time - oldest) + 1.0
if wait_time > 0:
logger.warning(f"⏳ Jina rate limit reached, waiting {wait_time:.1f}s...")
time.sleep(wait_time)
current_time = time.time()
cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
# 3. Enforce the minimum gap between consecutive requests
time_since_last = current_time - cls._last_request_time
if time_since_last < cls._min_interval:
sleep_time = cls._min_interval - time_since_last
time.sleep(sleep_time)
# 4. Record this request
cls._request_times.append(time.time())
cls._last_request_time = time.time()
@classmethod
def extract_with_jina(cls, url: str, timeout: int = 30) -> Optional[str]:
"""
使用 Jina Reader 提取网页正文内容 (Markdown 格式)
无 API Key 时自动限速:每分钟最多 20 次请求,每次间隔至少 3 秒
"""
if not url or not url.startswith("http"):
return None
logger.info(f"🕸️ Extracting content from: {url} via Jina...")
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Accept": "application/json"
}
# Use the unified JINA_API_KEY
api_key = os.getenv("JINA_API_KEY")
has_api_key = bool(api_key and api_key.strip())
if has_api_key:
headers["Authorization"] = f"Bearer {api_key}"
# Honor the rate limit before sending
cls._wait_for_rate_limit(has_api_key)
try:
# Jina Reader API
full_url = f"{cls.JINA_BASE_URL}{url}"
response = requests.get(full_url, headers=headers, timeout=timeout)
if response.status_code == 200:
try:
data = response.json()
# Jina's JSON response usually nests the text under data.content
if isinstance(data, dict) and "data" in data:
return data["data"].get("content", "")
return data.get("content", response.text)
except (json.JSONDecodeError, TypeError):
return response.text
elif response.status_code == 429:
# Hit the rate limit; wait, then retry
logger.warning("⚠️ Jina rate limit (429), waiting 60s before retry...")
time.sleep(60)
return cls.extract_with_jina(url, timeout)
else:
logger.warning(f"Jina extraction failed (Status {response.status_code}) for {url}")
return None
except Timeout:
logger.error(f"Timeout during Jina extraction for {url}")
return None
except ConnectionError:
logger.error(f"Connection error during Jina extraction for {url}")
return None
except RequestException as e:
logger.error(f"Request error during Jina extraction: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error during Jina extraction: {e}")
return None
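A minimal usage sketch for the extractor above (the URL is a placeholder; without JINA_API_KEY the class-level limiter enforces 20 requests/minute with a 3-second gap):

# In-package import: from .content_extractor import ContentExtractor
markdown = ContentExtractor.extract_with_jina("https://example.com/article", timeout=30)
print(markdown[:300] if markdown else "extraction failed")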

View File

@@ -0,0 +1,581 @@
import sqlite3
import json
from datetime import datetime, date
from pathlib import Path
from typing import List, Dict, Optional, Any, Union
import pandas as pd
from loguru import logger
class DatabaseManager:
"""
AlphaEar 数据库管理器 - 负责存储热点数据、搜索缓存和股价数据
使用 SQLite 进行持久化存储
"""
def __init__(self, db_path: str = "data/signal_flux.db"):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self._init_db()
logger.info(f"💾 Database initialized at {self.db_path}")
def _init_db(self):
"""初始化表结构"""
cursor = self.conn.cursor()
# 1. Daily hot-news table
cursor.execute("""
CREATE TABLE IF NOT EXISTS daily_news (
id TEXT PRIMARY KEY,
source TEXT,
rank INTEGER,
title TEXT,
url TEXT,
content TEXT,
publish_time TEXT,
crawl_time TEXT,
sentiment_score REAL,
analysis TEXT,
meta_data TEXT
)
""")
# Add the analysis column if an older table is missing it
try:
cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT")
except sqlite3.OperationalError:
pass # column already exists
# 2. Search cache table (legacy JSON cache)
cursor.execute("""
CREATE TABLE IF NOT EXISTS search_cache (
query_hash TEXT PRIMARY KEY,
query TEXT,
engine TEXT,
results TEXT,
timestamp TEXT
)
""")
# 2.5 Search detail table (expanded per-result rows)
cursor.execute("""
CREATE TABLE IF NOT EXISTS search_detail (
id TEXT,
query_hash TEXT,
rank INTEGER,
title TEXT,
url TEXT,
content TEXT,
publish_time TEXT,
crawl_time TEXT,
sentiment_score REAL,
source TEXT,
meta_data TEXT,
PRIMARY KEY (query_hash, id)
)
""")
# 3. Stock price table
cursor.execute("""
CREATE TABLE IF NOT EXISTS stock_prices (
ticker TEXT,
date TEXT,
open REAL,
close REAL,
high REAL,
low REAL,
volume REAL,
change_pct REAL,
PRIMARY KEY (ticker, date)
)
""")
# 4. Stock list table (used for lookups)
cursor.execute("""
CREATE TABLE IF NOT EXISTS stock_list (
code TEXT PRIMARY KEY,
name TEXT
)
""")
# 5. Investment signal table (ISQ Framework)
cursor.execute("""
CREATE TABLE IF NOT EXISTS signals (
signal_id TEXT PRIMARY KEY,
title TEXT,
summary TEXT,
transmission_chain TEXT,
sentiment_score REAL,
confidence REAL,
intensity INTEGER,
expected_horizon TEXT,
price_in_status TEXT,
impact_tickers TEXT,
industry_tags TEXT,
sources TEXT,
user_id TEXT,
created_at TEXT
)
""")
# 6. Indexes to speed up common queries
cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)")
# Add the user_id column to signals if an older table is missing it
try:
cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT")
except sqlite3.OperationalError:
pass # column already exists
cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)")
self.conn.commit()
# --- News operations ---
def save_daily_news(self, news_list: List[Dict]) -> int:
"""保存热点新闻,包含发布时间与抓取时间"""
cursor = self.conn.cursor()
count = 0
crawl_time = datetime.now().isoformat()
for news in news_list:
try:
# ID generation compatible with different sources
news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}"
cursor.execute("""
INSERT OR REPLACE INTO daily_news
(id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
news_id,
news.get('source'),
news.get('rank'),
news.get('title'),
news.get('url'),
news.get('content', ''),
news.get('publish_time'), # publish_time is now persisted
crawl_time,
news.get('sentiment_score'),
json.dumps(news.get('meta_data', {}))
))
count += 1
except sqlite3.Error as e:
logger.error(f"Database error saving news item {news.get('title')}: {e}")
except Exception as e:
logger.error(f"Unexpected error saving news item {news.get('title')}: {e}")
self.conn.commit()
return count
def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]:
"""获取最近 N 天的热点新闻"""
cursor = self.conn.cursor()
# Filter on crawl_time to keep results fresh
time_threshold = (datetime.now().timestamp() - days * 86400)
time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat()
query = "SELECT * FROM daily_news WHERE crawl_time >= ?"
params = [time_threshold_str]
if source:
query += " AND source = ?"
params.append(source)
query += " ORDER BY crawl_time DESC, rank LIMIT ?"
params.append(limit)
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]:
"""Best-effort lookup of a source item by URL.
This is used to render a stable bibliography from DB-backed metadata.
It searches both `daily_news` and `search_detail`.
"""
url = (url or "").strip()
if not url:
return None
cursor = self.conn.cursor()
try:
cursor.execute(
"""
SELECT title, source, publish_time, crawl_time, url
FROM daily_news
WHERE url = ?
ORDER BY crawl_time DESC
LIMIT 1
""",
(url,),
)
row = cursor.fetchone()
if row:
return dict(row)
except Exception:
pass
try:
cursor.execute(
"""
SELECT title, source, publish_time, crawl_time, url
FROM search_detail
WHERE url = ?
ORDER BY crawl_time DESC
LIMIT 1
""",
(url,),
)
row = cursor.fetchone()
if row:
return dict(row)
except Exception:
pass
return None
def delete_news(self, news_id: str) -> bool:
"""删除特定新闻"""
cursor = self.conn.cursor()
cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,))
self.conn.commit()
return cursor.rowcount > 0
def update_news_content(self, news_id: str, content: Optional[str] = None, analysis: Optional[str] = None) -> bool:
"""Update the stored content and/or analysis of a news item."""
cursor = self.conn.cursor()
updates = []
params = []
if content is not None:
updates.append("content = ?")
params.append(content)
if analysis is not None:
updates.append("analysis = ?")
params.append(analysis)
if not updates:
return False
params.append(news_id)
query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?"
cursor.execute(query, params)
self.conn.commit()
return cursor.rowcount > 0
# --- Search cache helpers ---
def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]:
"""获取搜索缓存 (优先查 search_detail)"""
cursor = self.conn.cursor()
# 1. Try the expanded, structured rows in search_detail first
cursor.execute("""
SELECT * FROM search_detail
WHERE query_hash = ?
ORDER BY rank
""", (query_hash,))
details = [dict(row) for row in cursor.fetchall()]
if details:
# Check the TTL against the first row's crawl time
first_time = datetime.fromisoformat(details[0]['crawl_time'])
if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds:
logger.info(f"⌛ Detailed cache expired for hash {query_hash}")
# If the detailed rows are expired, the summary cache is almost certainly stale too, so treat this as a miss.
return None
logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)")
# SearchTools expects a dict whose 'results' field is a JSON string (it calls json.loads(cache['results'])),
# so wrap the detail rows in that legacy shape instead of changing the caller.
return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']}
# 2. Fallback to old table
cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,))
row = cursor.fetchone()
if not row:
return None
row_dict = dict(row)
if ttl_seconds:
cache_time = datetime.fromisoformat(row_dict['timestamp'])
if (datetime.now() - cache_time).total_seconds() > ttl_seconds:
logger.info(f"⌛ Cache expired for hash {query_hash}")
return None
return row_dict
def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]):
"""保存搜索结果 (同时保存到 search_cache 和 search_detail)"""
cursor = self.conn.cursor()
current_time = datetime.now().isoformat()
results_str = results if isinstance(results, str) else json.dumps(results)
# 1. Save summary to search_cache
cursor.execute("""
INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp)
VALUES (?, ?, ?, ?, ?)
""", (query_hash, query, engine, results_str, current_time))
# 2. Save details to search_detail if results is a list
if isinstance(results, list):
for item in results:
try:
item_id = item.get('id') or f"{hash(item.get('url', ''))}"
cursor.execute("""
INSERT OR REPLACE INTO search_detail
(id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
str(item_id),
query_hash,
item.get('rank', 0),
item.get('title'),
item.get('url'),
item.get('content', ''),
item.get('publish_time'),
item.get('crawl_time') or current_time,
item.get('sentiment_score'),
item.get('source'),
json.dumps(item.get('meta_data', {}))
))
except sqlite3.Error as e:
logger.error(f"Database error saving search detail {item.get('title')}: {e}")
except Exception as e:
logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}")
self.conn.commit()
def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]:
"""模糊搜索相似的已缓存查询"""
cursor = self.conn.cursor()
# Simple fuzzy match: query in cached OR cached in query
q_wild = f"%{query}%"
cursor.execute("""
SELECT query, query_hash, timestamp, results
FROM search_cache
WHERE query LIKE ? OR ? LIKE ('%' || query || '%')
ORDER BY timestamp DESC
LIMIT ?
""", (q_wild, query, limit))
return [dict(row) for row in cursor.fetchall()]
def search_local_news(self, query: str, limit: int = 5) -> List[Dict]:
"""从本地 daily_news 搜索相关新闻"""
cursor = self.conn.cursor()
q_wild = f"%{query}%"
# Search title and content
cursor.execute("""
SELECT * FROM daily_news
WHERE title LIKE ? OR content LIKE ?
ORDER BY crawl_time DESC
LIMIT ?
""", (q_wild, q_wild, limit))
return [dict(row) for row in cursor.fetchall()]
# --- Stock data operations ---
def save_stock_list(self, df: pd.DataFrame):
"""Persist the stock list into the stock_list table."""
cursor = self.conn.cursor()
try:
# Clear the old rows
cursor.execute("DELETE FROM stock_list")
# Bulk insert
data = df[['code', 'name']].to_dict('records')
cursor.executemany(
"INSERT INTO stock_list (code, name) VALUES (:code, :name)",
data
)
self.conn.commit()
except sqlite3.Error as e:
logger.error(f"Database error saving stock list: {e}")
except Exception as e:
logger.error(f"Unexpected error saving stock list: {e}")
def search_stock(self, query: str, limit: int = 5) -> List[Dict]:
"""模糊搜索股票代码或名称"""
cursor = self.conn.cursor()
wild = f"%{query}%"
cursor.execute("""
SELECT code, name FROM stock_list
WHERE code LIKE ? OR name LIKE ?
LIMIT ?
""", (wild, wild, limit))
return [dict(row) for row in cursor.fetchall()]
def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]:
"""精确按代码获取股票信息。
Args:
code: 股票代码A股6位 / 港股5位必须为纯数字字符串。
Returns:
dict: {"code": str, "name": str} 或 None。
"""
if not code:
return None
clean = "".join([c for c in str(code).strip() if c.isdigit()])
if not clean:
return None
cursor = self.conn.cursor()
cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,))
row = cursor.fetchone()
return dict(row) if row else None
def save_stock_prices(self, ticker: str, df: pd.DataFrame):
"""保存股价历史数据"""
if df.empty:
return
cursor = self.conn.cursor()
# Make sure the DataFrame has the required columns
required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
for col in required_cols:
if col not in df.columns:
logger.warning(f"Missing column {col} in stock data for {ticker}")
return
try:
for _, row in df.iterrows():
cursor.execute("""
INSERT OR REPLACE INTO stock_prices
(ticker, date, open, close, high, low, volume, change_pct)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (
ticker,
row['date'],
row['open'],
row['close'],
row['high'],
row['low'],
row['volume'],
row['change_pct']
))
self.conn.commit()
except sqlite3.Error as e:
logger.error(f"Database error saving stock prices for {ticker}: {e}")
except Exception as e:
logger.error(f"Unexpected error saving stock prices for {ticker}: {e}")
def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame:
"""获取指定日期范围的股价数据"""
cursor = self.conn.cursor()
cursor.execute("""
SELECT * FROM stock_prices
WHERE ticker = ? AND date >= ? AND date <= ?
ORDER BY date
""", (ticker, start_date, end_date))
rows = cursor.fetchall()
if not rows:
return pd.DataFrame()
columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
return pd.DataFrame([dict(row) for row in rows], columns=columns)
def execute_query(self, query: str, params: tuple = ()) -> List[Any]:
"""执行自定义 SQL 查询"""
try:
cursor = self.conn.cursor()
cursor.execute(query, params)
if query.strip().upper().startswith("SELECT"):
return cursor.fetchall()
else:
self.conn.commit()
return []
except sqlite3.Error as e:
logger.error(f"SQL execution failed (Database error): {e}")
return []
except Exception as e:
logger.error(f"SQL execution failed (Unexpected error): {e}")
return []
# --- Investment signal operations (ISQ Framework) ---
def save_signal(self, signal: Dict[str, Any]):
"""保存投资信号"""
cursor = self.conn.cursor()
created_at = datetime.now().isoformat()
cursor.execute("""
INSERT OR REPLACE INTO signals
(signal_id, title, summary, transmission_chain, sentiment_score,
confidence, intensity, expected_horizon, price_in_status,
impact_tickers, industry_tags, sources, user_id, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
signal.get('signal_id'),
signal.get('title'),
signal.get('summary'),
json.dumps(signal.get('transmission_chain', [])),
signal.get('sentiment_score', 0.0),
signal.get('confidence', 0.0),
signal.get('intensity', 1),
signal.get('expected_horizon', 'T+0'),
signal.get('price_in_status', '未知'),
json.dumps(signal.get('impact_tickers', [])),
json.dumps(signal.get('industry_tags', [])),
json.dumps(signal.get('sources', [])),
signal.get('user_id'),
created_at
))
self.conn.commit()
def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]:
"""获取最近的投资信号"""
cursor = self.conn.cursor()
if user_id:
cursor.execute("SELECT * FROM signals WHERE user_id = ? ORDER BY created_at DESC LIMIT ?", (user_id, limit))
else:
cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,))
rows = cursor.fetchall()
signals = []
for row in rows:
d = dict(row)
# Decode JSON-encoded fields
for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']:
if d.get(field):
try:
d[field] = json.loads(d[field])
except (json.JSONDecodeError, TypeError):
pass
signals.append(d)
return signals
def close(self):
if self.conn:
self.conn.close()
logger.info("Database connection closed.")

View File

@@ -0,0 +1,216 @@
import numpy as np
import os
from typing import List, Dict, Any, Optional, Union
from rank_bm25 import BM25Okapi
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class HybridSearcher:
"""
统一混合检索引擎 (Hybrid RAG)
实现 BM25 (文本) + 向量 (语义) 的融合搜索 (RRF)
"""
def __init__(self, data: List[Dict[str, Any]], text_fields: List[str] = ["title", "content"], model_name: str = None):
"""
初始化搜索器
Args:
data: 数据列表,每个元素为 Dict
text_fields: 用于建立索引的文本字段
model_name: 向量模型名称,默认使用 paraphrase-multilingual-MiniLM-L12-v2
"""
self.data = data
self.text_fields = text_fields
self._corpus = []
self._bm25 = None
self._vector_model = None
self._embeddings = None
self._fitted = False
self._vector_fitted = False
# Default embedding model
self.model_name = model_name or os.getenv("EMBEDDING_MODEL", "paraphrase-multilingual-MiniLM-L12-v2")
if data:
self._prepare_corpus()
self._fit_bm25()
# Load the embedding model lazily; it is fitted on demand (or by calling _fit_vector() explicitly)
# self._fit_vector()
def _prepare_corpus(self):
"""准备语料库用于分词"""
import jieba # 使用 jieba 进行中文分词
self._corpus = []
self._full_texts = []
for item in self.data:
text = " ".join([str(item.get(field, "")) for field in self.text_fields])
self._full_texts.append(text)
# Chinese word segmentation
tokens = list(jieba.cut(text))
self._corpus.append(tokens)
def _fit_bm25(self):
"""训练 BM25 模型"""
if self._corpus:
self._bm25 = BM25Okapi(self._corpus)
self._fitted = True
logger.info(f"✅ BM25 index fitted with {len(self.data)} documents")
def _fit_vector(self):
"""训练向量模型并生成 Embeddings"""
if not self.data:
return
try:
logger.info(f"📡 Loading embedding model: {self.model_name}...")
self._vector_model = SentenceTransformer(self.model_name)
logger.info(f"🧠 Encoding {len(self._full_texts)} documents...")
self._embeddings = self._vector_model.encode(self._full_texts, show_progress_bar=False)
self._vector_fitted = True
logger.info("✅ Vector index fitted successfully")
except Exception as e:
logger.error(f"❌ Failed to fit vector index: {e}")
self._vector_fitted = False
def _compute_rrf(self, rank_lists: List[List[int]], k: int = 60) -> List[tuple]:
"""
计算 Reciprocal Rank Fusion (RRF)
Args:
rank_lists: 多个排序后的索引列表
k: RRF 常数,默认 60
"""
scores = {}
for rank_list in rank_lists:
for rank, idx in enumerate(rank_list):
if idx not in scores:
scores[idx] = 0
scores[idx] += 1.0 / (k + rank + 1)
# Sort by fused score
sorted_indices = sorted(scores.items(), key=lambda x: x[1], reverse=True)
return sorted_indices
def search(self, query: str, top_n: int = 5, use_vector: bool = False) -> List[Dict[str, Any]]:
"""
执行混合搜索
Args:
query: 搜索关键词
top_n: 返回结果数量
use_vector: 是否启用向量搜索
"""
if not self._fitted or not query:
return []
import jieba
query_tokens = list(jieba.cut(query))
# 1. BM25 ranking
bm25_scores = self._bm25.get_scores(query_tokens)
bm25_rank = np.argsort(bm25_scores)[::-1].tolist()
rank_lists = [bm25_rank]
# 2. Vector (semantic) ranking
if use_vector:
if not self._vector_fitted:
self._fit_vector()
if self._vector_fitted:
query_embedding = self._vector_model.encode([query], show_progress_bar=False)
similarities = cosine_similarity(query_embedding, self._embeddings)[0]
vector_rank = np.argsort(similarities)[::-1].tolist()
rank_lists.append(vector_rank)
else:
logger.warning("Vector search requested but model not fitted, falling back to BM25")
# 3. Fuse the rankings with RRF
if len(rank_lists) > 1:
rrf_results = self._compute_rrf(rank_lists)
# RRF returns a list of (idx, score) tuples
final_rank = [idx for idx, score in rrf_results]
else:
final_rank = bm25_rank
# Take the top_n results
results = [self.data[idx].copy() for idx in final_rank[:top_n]]
# Attach relevance scores to each result
for i, res in enumerate(results):
try:
original_idx = final_rank[i]
res["_search_score"] = bm25_scores[original_idx]
if use_vector and self._vector_fitted:
res["_vector_score"] = float(similarities[original_idx])
except Exception:
res["_search_score"] = 0
return results
class InMemoryRAG(HybridSearcher):
"""专门用于 ReportAgent 跨章节检索的内存态 RAG"""
def search(self, query: str, top_n: int = 3, use_vector: bool = True) -> List[Dict[str, Any]]:
"""默认开启向量搜索的内存检索"""
return super().search(query, top_n=top_n, use_vector=use_vector)
def update_data(self, new_data: List[Dict[str, Any]]):
"""动态更新数据并重新训练索引"""
self.data = new_data
self._prepare_corpus()
self._fit_bm25()
# If the embedding model was already loaded, refresh the vector index too
if self._vector_model:
self._fit_vector()
logger.info(f"🔄 InMemoryRAG updated with {len(new_data)} items")
class LocalNewsSearch(HybridSearcher):
"""持久态 RAG检索数据库中的历史新闻"""
def __init__(self, db_manager):
"""
Args:
db_manager: A DatabaseManager instance.
"""
self.db = db_manager
# No data is loaded at construction time; call load_history() first
super().__init__([], ["title", "content"])
def load_history(self, days: int = 30, limit: int = 1000):
"""从数据库加载最近 N 天的新闻构建索引"""
try:
# 假设 db_manager 有 execute_query
query = f"SELECT title, content, publish_time, source FROM daily_news ORDER BY publish_time DESC LIMIT ?"
results = self.db.execute_query(query, (limit,))
data = []
for row in results:
# Convert Row to dict
if hasattr(row, 'keys'):
item = dict(row)
else:
item = {
"title": row[0],
"content": row[1],
"publish_time": row[2],
"source": row[3]
}
data.append(item)
self.data = data
self._prepare_corpus()
self._fit_bm25()
# Do not fit the vector index yet; it is built on demand at the first vector search
logger.info(f"📚 LocalNewsSearch loaded {len(data)} items from history")
except Exception as e:
logger.error(f"Failed to load history for search: {e}")
def search(self, query: str, top_n: int = 5, use_vector: bool = True) -> List[Dict[str, Any]]:
"""执行本地历史搜索,默认开启向量搜索"""
if not self.data:
self.load_history()
return super().search(query, top_n=top_n, use_vector=use_vector)
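A minimal usage sketch for HybridSearcher (documents are made up; with use_vector=True the BM25 and cosine rankings are fused by RRF, i.e. each document scores the sum of 1/(k + rank + 1) with k=60, and the first vector search downloads the embedding model):

docs = [
    {"title": "央行宣布降息", "content": "货币政策转向宽松"},
    {"title": "Fed holds rates", "content": "monetary policy unchanged"},
]
searcher = HybridSearcher(docs)
print(searcher.search("降息", top_n=1))                   # BM25 only
print(searcher.search("降息", top_n=1, use_vector=True))  # BM25 + embeddings fused via RRF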

View File

@@ -0,0 +1,180 @@
import ast
import json
import re
from typing import Optional, Any
from loguru import logger
def _strip_comments(text: str) -> str:
"""
Safely remove C-style comments (// and /* */) from JSON-like text,
preserving strings (including URLs like http://).
"""
result = []
i = 0
n = len(text)
in_string = False
escape = False
while i < n:
char = text[i]
if in_string:
if char == '\\':
escape = not escape
elif char == '"' and not escape:
in_string = False
else:
escape = False
result.append(char)
i += 1
continue
# Not in string
if char == '"':
in_string = True
result.append(char)
i += 1
continue
# Check for // comment
if i + 1 < n and text[i:i+2] == '//':
i += 2
while i < n and text[i] != '\n':
i += 1
continue
# Check for /* comment
if i + 1 < n and text[i:i+2] == '/*':
i += 2
while i + 1 < n and text[i:i+2] != '*/':
i += 1
i += 2
continue
result.append(char)
i += 1
return ''.join(result)
def extract_json(text: str) -> Optional[Any]:
"""
更加鲁棒的 JSON 提取工具。
处理:
1. Markdown 代码块 (```json ... ```)
2. 首尾多余字符
3. 同一个文本中多个 JSON 对象 (仅提取第一个)
4. 简单的 JSON 修复 (末尾逗号等)
5. C 风格注释 (// 和 /* */)
"""
if not text:
return None
# 1. Strip obvious Markdown wrapping
text = text.strip()
# First try an exact fenced match: ```json ... ``` or ``` ... ```
md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if md_match:
text = md_match.group(1).strip()
elif text.startswith("```"):
# Fallback: the text starts with ``` but there is no complete fence match
text = re.sub(r'^```[a-z]*\n?', '', text)
text = re.sub(r'\n?```\s*$', '', text)
# 2. Find the first JSON opener: { or [
start_brace = text.find('{')
start_bracket = text.find('[')
if start_brace == -1 and start_bracket == -1:
return None
start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket
# 2.5 Preprocess: repair some very common LLM mistakes
potential_json = text[start_idx:].strip()
# remove comments safely
potential_json = _strip_comments(potential_json)
# b. Repair keys missing the opening quote: nodes": [ -> "nodes": [
# Pattern: ({ or ,) plus whitespace, a bare word, then a quote and colon
potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json)
# c. Repair keys missing the closing quote: "nodes: [ -> "nodes": [
potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)
# d. Repair fully unquoted keys: nodes: [ -> "nodes": [
# Limited to positions right after { or , so that things like http:// are not touched
potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)
# 3. Try to parse with raw_decode
decoder = json.JSONDecoder()
# First try parsing as-is (before the trailing-comma fix)
try:
obj = json.loads(potential_json)
return obj
except json.JSONDecodeError:
pass
# Simple repair: drop trailing commas before } or ]
processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json)
try:
obj, end_pos = decoder.raw_decode(processed_json)
return obj
except json.JSONDecodeError:
pass
# e. Repair unterminated string literals by removing literal newlines inside values
# (LLMs sometimes emit real newlines inside string values, which is invalid JSON)
def fix_multiline_strings(s):
# Simple strategy: merge lines that fall inside an open string value
lines = s.split('\n')
result = []
in_string = False
for line in lines:
# Count unescaped quotes on this line
quote_count = line.count('"') - line.count('\\"')
if in_string:
result[-1] += ' ' + line.strip()
else:
result.append(line)
if quote_count % 2 == 1:
in_string = not in_string
return '\n'.join(result)
fixed_json = fix_multiline_strings(processed_json)
try:
obj, end_pos = decoder.raw_decode(fixed_json)
return obj
except json.JSONDecodeError:
try:
# 4. Handle single quotes (JSON requires double quotes, but LLMs often emit single quotes)
# A simple substitution aimed at structures like {'key': 'value'}
# Note: this can corrupt string values containing single quotes, so it is a late fallback
fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # fix keys
fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes) # fix simple values
obj, end_pos = decoder.raw_decode(fix_quotes)
return obj
except (json.JSONDecodeError, TypeError):
try:
# 5. Last resort: ast.literal_eval (handles Python dict syntax)
# Find the first balanced { ... } block and evaluate it
stack = []
for i, char in enumerate(potential_json):
if char == '{': stack.append('{')
elif char == '}':
if stack: stack.pop()
if not stack:
content = potential_json[:i+1]
return ast.literal_eval(content)
except (ValueError, SyntaxError, MemoryError) as e:
logger.warning(f"All JSON extraction attempts failed: {e}")
except Exception as e:
logger.error(f"Unexpected error during JSON extraction: {e}")
return None
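A minimal sketch of the extractor above on a typical messy LLM reply (fenced block, trailing comma, and a C-style comment):

raw = 'Sure, here it is:\n```json\n{"nodes": [1, 2, 3],}  // trailing comma on purpose\n```'
print(extract_json(raw))  # -> {'nodes': [1, 2, 3]}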

View File

@@ -0,0 +1,85 @@
import os
from typing import Optional, List, Dict, Any
from agno.agent import Agent
from agno.models.base import Model
from loguru import logger
from ..llm.factory import get_model
def test_tool_call_support(model: Model) -> bool:
"""
测试模型是否支持原生的 Tool Call (Function Calling)。
通过尝试执行一个简单的加法工具来验证。
"""
def get_current_weather(location: str):
"""获取指定地点的天气"""
return f"{location} 的天气是晴天25度。"
test_agent = Agent(
model=model,
tools=[get_current_weather],
instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。",
)
try:
# Run a simple task and observe whether a tool call is triggered
response = test_agent.run("北京天气怎么样?")
# Check whether the response contains tool_calls.
# Agno's RunResponse carries the exchanged messages, so inspect them for tool calls.
has_tool_call = False
for msg in response.messages:
if hasattr(msg, "tool_calls") and msg.tool_calls:
has_tool_call = True
break
if has_tool_call:
logger.info(f"✅ Model {model.id} supports native tool calling.")
return True
else:
# No tool_calls but a plausible answer: the model may have simulated the call in plain text (ReAct)
# or skipped the tool entirely. For native support we insist on an actual tool_calls structure.
logger.warning(
f"⚠️ Model {model.id} did NOT use native tool calling structure."
)
return False
except Exception as e:
logger.error(f"❌ Error testing tool call for {model.id}: {e}")
return False
class ModelCapabilityRegistry:
"""
模型能力注册表,用于缓存和管理不同模型的能力测试结果。
"""
_cache = {}
@classmethod
def get_capabilities(
cls, provider: str, model_id: str, **kwargs
) -> Dict[str, bool]:
key = f"{provider}:{model_id}"
if key not in cls._cache:
logger.info(f"🔍 Testing capabilities for {key}...")
model = get_model(provider, model_id, **kwargs)
supports_tool_call = test_tool_call_support(model)
cls._cache[key] = {"supports_tool_call": supports_tool_call}
return cls._cache[key]
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
# Test the currently configured model
p = os.getenv("LLM_PROVIDER", "ust")
m = os.getenv("LLM_MODEL", "Qwen")
print(f"Testing {p}/{m}...")
res = ModelCapabilityRegistry.get_capabilities(p, m)
print(f"Result: {res}")

View File

@@ -0,0 +1,114 @@
import os
from agno.models.openai import OpenAIChat
from agno.models.ollama import Ollama
from agno.models.dashscope import DashScope
from agno.models.deepseek import DeepSeek
from agno.models.openrouter import OpenRouter
def get_model(model_provider: str, model_id: str, **kwargs):
"""
Factory to get the appropriate LLM model.
Args:
model_provider: "openai", "ollama", "deepseek"
model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat")
**kwargs: Additional arguments for the model constructor
"""
if model_provider == "openai":
return OpenAIChat(id=model_id, **kwargs)
elif model_provider == "ollama":
return Ollama(id=model_id, **kwargs)
elif model_provider == "deepseek":
# DeepSeek is OpenAI compatible
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
print("Warning: DEEPSEEK_API_KEY not set.")
return DeepSeek(
id=model_id,
api_key=api_key,
**kwargs
)
elif model_provider == "dashscope":
api_key = os.getenv("DASHSCOPE_API_KEY")
if not api_key:
print("Warning: DASHSCOPE_API_KEY not set.")
return DashScope(
id=model_id,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
api_key=api_key,
**kwargs
)
elif model_provider == 'openrouter':
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
print('Warning: OPENROUTER_API_KEY not set.')
return OpenRouter(
id=model_id,
api_key=api_key,
**kwargs
)
elif model_provider == 'zai':
api_key = os.getenv("ZAI_KEY_API")
if not api_key:
print('Warning: ZAI_KEY_API not set.')
# role_map to ensure compatibility.
default_role_map = {
"system": "system",
"user": "user",
"assistant": "assistant",
"tool": "tool",
"model": "assistant",
}
# Allow callers to override role_map via kwargs, otherwise use default
role_map = kwargs.pop("role_map", default_role_map)
return OpenAIChat(
id=model_id,
base_url="https://api.z.ai/api/paas/v4",
api_key=api_key,
timeout=60,
role_map=role_map,
extra_body={"enable_thinking": False}, # TODO: one more setting for thinking
**kwargs
)
elif model_provider == 'ust':
api_key = os.getenv("UST_KEY_API")
if not api_key:
print('Warning: UST_KEY_API not set.')
# Some UST-compatible endpoints expect the standard OpenAI role names
# (e.g. "system", "user", "assistant") rather than Agno's default
# mapping which maps "system" -> "developer". Provide an explicit
# role_map to ensure compatibility.
default_role_map = {
"system": "system",
"user": "user",
"assistant": "assistant",
"tool": "tool",
"model": "assistant",
}
# Allow callers to override role_map via kwargs, otherwise use default
role_map = kwargs.pop("role_map", default_role_map)
return OpenAIChat(
id=model_id,
api_key=api_key,
base_url=os.getenv("UST_URL"),
role_map=role_map,
extra_body={"enable_thinking": False}, # TODO: one more setting for thinking
**kwargs
)
else:
raise ValueError(f"Unknown model provider: {model_provider}")
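A minimal usage sketch for the factory above (assumes DEEPSEEK_API_KEY is already set in ~/.config/opencode/.env):

# In-package import: from ..llm.factory import get_model
model = get_model("deepseek", "deepseek-chat")
# model = get_model("ust", "Qwen")  # UST-compatible endpoint; needs UST_KEY_API and UST_URL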

View File

@@ -0,0 +1,80 @@
import os
from typing import Optional, List, Dict, Any, Union
from agno.models.base import Model
from loguru import logger
from dotenv import load_dotenv
from ..llm.factory import get_model
from ..llm.capability import ModelCapabilityRegistry
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
class ModelRouter:
"""
模型路由管理器
功能:
1. 管理“推理/写作模型” (Reasoning Model) 和“工具调用模型” (Tool Model)。
2. 根据任务需求自动选择合适的模型。
"""
def __init__(self):
# Defaults are read from environment variables
self.reasoning_provider = os.getenv(
"REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai")
)
self.reasoning_id = os.getenv(
"REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o")
)
self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST"))
self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider)
self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id)
self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host)
self._reasoning_model = None
self._tool_model = None
logger.info(
f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})"
)
def get_reasoning_model(self, **kwargs) -> Model:
if not self._reasoning_model:
# Prefer the host configured for this route
if self.reasoning_host and "host" not in kwargs:
kwargs["host"] = self.reasoning_host
self._reasoning_model = get_model(
self.reasoning_provider, self.reasoning_id, **kwargs
)
return self._reasoning_model
def get_tool_model(self, **kwargs) -> Model:
if not self._tool_model:
# Prefer the host configured for this route
if self.tool_host and "host" not in kwargs:
kwargs["host"] = self.tool_host
# Verify that the tool model actually supports tool calls
caps = ModelCapabilityRegistry.get_capabilities(
self.tool_provider, self.tool_id, **kwargs
)
if not caps["supports_tool_call"]:
logger.warning(
f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model."
)
self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs)
return self._tool_model
def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model:
"""
根据 Agent 是否包含工具来返回合适的模型。
"""
if has_tools:
return self.get_tool_model(**kwargs)
return self.get_reasoning_model(**kwargs)
# Global singleton
router = ModelRouter()
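A minimal sketch of how an agent picks up a model from the router above (Agent usage mirrors the capability test earlier in this commit):

from agno.agent import Agent

writer = Agent(model=router.get_model_for_agent(has_tools=False))  # reasoning/writing model
# A tool-using agent would pass has_tools=True, which first probes the tool model
# for native tool-call support via ModelCapabilityRegistry.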

View File

@@ -0,0 +1,45 @@
import os
import sys
from datetime import datetime
from typing import Optional
from loguru import logger
def setup_file_logging(
run_id: str,
log_dir: str = "logs",
level: str = "INFO",
retention: str = "10 days",
rotation: str = "20 MB",
) -> str:
"""Configure Loguru to log to stderr + a per-run file.
Returns the log file path.
"""
os.makedirs(log_dir, exist_ok=True)
# Remove default handler to avoid duplicate logs.
logger.remove()
# Console
logger.add(sys.stderr, level=level, backtrace=False, diagnose=False)
# File (safe for multi-thread via enqueue)
log_path = os.path.join(log_dir, f"signalflux_{run_id}.log")
logger.add(
log_path,
level=level,
rotation=rotation,
retention=retention,
enqueue=True,
backtrace=True,
diagnose=False,
encoding="utf-8",
)
return log_path
def make_run_id(prefix: Optional[str] = None) -> str:
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{prefix}_{ts}" if prefix else ts

View File

@@ -0,0 +1,256 @@
import requests
from requests.exceptions import RequestException, Timeout
import json
import time
from datetime import datetime
from typing import List, Dict, Optional
from loguru import logger
from .database_manager import DatabaseManager
from .content_extractor import ContentExtractor
class NewsNowTools:
"""热点新闻获取工具 - 接入 NewsNow API 与 Jina 内容提取"""
BASE_URL = "https://newsnow.busiyi.world"
SOURCES = {
# Finance
"cls": "财联社",
"wallstreetcn": "华尔街见闻",
"xueqiu": "雪球热榜",
# General / social
"weibo": "微博热搜",
"zhihu": "知乎热榜",
"baidu": "百度热搜",
"toutiao": "今日头条",
"douyin": "抖音热榜",
"thepaper": "澎湃新闻",
# Tech
"36kr": "36氪",
"ithome": "IT之家",
"v2ex": "V2EX",
"juejin": "掘金",
"hackernews": "Hacker News",
}
def __init__(self, db: DatabaseManager):
self.db = db
self.user_agent = (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
)
self.extractor = ContentExtractor()
# Simple in-memory cache: source_id -> {"time": timestamp, "data": []}
self._cache = {}
def fetch_hot_news(self, source_id: str, count: int = 15, fetch_content: bool = False) -> List[Dict]:
"""
从指定新闻源获取热点新闻列表支持5分钟缓存
"""
# 1. Check cache validity (5 minutes)
cache_key = f"{source_id}_{count}"
cached = self._cache.get(cache_key)
now = time.time()
if cached and (now - cached["time"] < 300):
logger.info(f"⚡ Using cached news for {source_id} (Age: {int(now - cached['time'])}s)")
return cached["data"]
try:
url = f"{self.BASE_URL}/api/s?id={source_id}"
response = requests.get(url, headers={"User-Agent": self.user_agent}, timeout=30)
if response.status_code == 200:
data = response.json()
items = data.get("items", [])[:count]
processed_items = []
for i, item in enumerate(items, 1):
item_url = item.get("url", "")
content = ""
if fetch_content and item_url:
content = self.extractor.extract_with_jina(item_url) or ""
processed_items.append({
"id": item.get("id") or f"{source_id}_{int(time.time())}_{i}",
"source": source_id,
"rank": i,
"title": item.get("title", ""),
"url": item_url,
"content": content,
"publish_time": item.get("publish_time"),
"meta_data": item.get("extra", {})
})
# Update Cache
self._cache[cache_key] = {"time": now, "data": processed_items}
logger.info(f"✅ Fetched and cached news for {source_id}")
self.db.save_daily_news(processed_items)
return processed_items
else:
logger.error(f"NewsNow API Error: {response.status_code}")
# Fallback to stale cache if available
if cached:
logger.warning(f"⚠️ API failed, using stale cache for {source_id}")
return cached["data"]
return []
except Timeout:
logger.error(f"Timeout fetching hot news from {source_id}")
if cached:
logger.warning(f"⚠️ Timeout, using stale cache for {source_id}")
return cached["data"]
return []
except RequestException as e:
logger.error(f"Network error fetching hot news from {source_id}: {e}")
if cached:
logger.warning(f"⚠️ Network check failed, using stale cache for {source_id}")
return cached["data"]
return []
except json.JSONDecodeError:
logger.error(f"Failed to parse JSON response from NewsNow for {source_id}")
return []
except Exception as e:
logger.error(f"Unexpected error fetching hot news from {source_id}: {e}")
return []
def fetch_news_content(self, url: str) -> Optional[str]:
"""
使用 Jina Reader 抓取指定 URL 的网页正文内容。
Args:
url: 需要抓取内容的完整网页 URL必须以 http:// 或 https:// 开头。
Returns:
提取的网页正文内容 (Markdown 格式),如果失败则返回 None。
"""
return self.extractor.extract_with_jina(url)
def get_unified_trends(self, sources: Optional[List[str]] = None) -> str:
"""
获取多平台综合热点报告,自动聚合多个新闻源的热门内容。
Args:
sources: 要扫描的新闻源列表。可选值按类别:
**金融类**: "cls", "wallstreetcn", "xueqiu"
**综合类**: "weibo", "zhihu", "baidu", "toutiao", "douyin", "thepaper"
**科技类**: "36kr", "ithome", "v2ex", "juejin", "hackernews"
Returns:
格式化的 Markdown 热点汇总报告,包含各平台 Top 10 热点标题和链接。
"""
sources = sources or ["weibo", "zhihu", "wallstreetcn"]
all_news = []
for src in sources:
all_news.extend(self.fetch_hot_news(src))
time.sleep(0.2)
if not all_news:
return "❌ 未能获取到热点数据"
report = f"# 实时全网热点汇总 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n"
for src in sources:
src_name = self.SOURCES.get(src, src)
report += f"### 🔥 {src_name}\n"
src_news = [n for n in all_news if n['source'] == src]
for n in src_news[:10]:
report += f"- {n['title']} ([链接]({n['url']}))\n"
report += "\n"
return report
class PolymarketTools:
"""Polymarket 预测市场数据工具 - 获取热门预测市场反映公众情绪和预期"""
BASE_URL = "https://gamma-api.polymarket.com"
def __init__(self, db: DatabaseManager):
self.db = db
self.user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
def get_active_markets(self, limit: int = 20) -> List[Dict]:
"""
获取活跃的预测市场,用于分析公众情绪和预期。
预测市场数据可以反映:
- 公众对重大事件的预期概率
- 市场情绪和风险偏好
- 热门话题的关注度
Args:
limit: 获取的市场数量,默认 20 个。
Returns:
包含预测市场信息的列表,每个市场包含:
- question: 预测问题
- outcomes: 可能的结果
- outcomePrices: 各结果的概率价格
- volume: 交易量
"""
try:
response = requests.get(
f"{self.BASE_URL}/markets",
params={"active": "true", "closed": "false", "limit": limit},
headers={"User-Agent": self.user_agent, "Accept": "application/json"},
timeout=30
)
if response.status_code == 200:
markets = response.json()
result = []
for m in markets:
result.append({
"id": m.get("id"),
"question": m.get("question"),
"slug": m.get("slug"),
"outcomes": m.get("outcomes"),
"outcomePrices": m.get("outcomePrices"),
"volume": m.get("volume"),
"liquidity": m.get("liquidity"),
})
logger.info(f"✅ 获取 {len(result)} 个预测市场")
return result
else:
logger.warning(f"⚠️ Polymarket API 返回 {response.status_code}")
return []
except Timeout:
logger.error("Timeout fetching Polymarket markets")
return []
except RequestException as e:
logger.error(f"Network error fetching Polymarket markets: {e}")
return []
except json.JSONDecodeError:
logger.error("Failed to parse JSON response from Polymarket")
return []
except Exception as e:
logger.error(f"Unexpected error fetching Polymarket markets: {e}")
return []
def get_market_summary(self, limit: int = 10) -> str:
"""
获取预测市场摘要报告,用于了解当前热门话题和公众预期。
Args:
limit: 获取的市场数量
Returns:
格式化的预测市场报告
"""
markets = self.get_active_markets(limit)
if not markets:
return "❌ 无法获取 Polymarket 数据"
report = f"# 🔮 Polymarket 热门预测 ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n"
for i, m in enumerate(markets, 1):
question = m.get("question", "Unknown")
prices = m.get("outcomePrices", [])
volume = m.get("volume", 0)
report += f"**{i}. {question}**\n"
if prices:
report += f" 概率: {prices}\n"
if volume:
report += f" 交易量: ${float(volume):,.0f}\n"
report += "\n"
return report
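A minimal sketch combining the two tool classes above (source ids come from the SOURCES table; network access is required):

db = DatabaseManager()
news = NewsNowTools(db)
print(news.get_unified_trends(["cls", "wallstreetcn"]))

poly = PolymarketTools(db)
print(poly.get_market_summary(limit=5))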

View File

@@ -0,0 +1,137 @@
import os
import sys
import torch
import pandas as pd
import numpy as np
import glob
from loguru import logger
from datetime import datetime, timedelta
# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
sys.path.insert(0, SRC_DIR)
from ..kronos.auto_synthesis_training import AutoSynthesisTrainer
from ..kronos.model import KronosPredictor
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
class NewsModelEvaluator:
def __init__(self, model_path=None):
self.trainer = AutoSynthesisTrainer()
self.device = self.trainer.device
if model_path is None:
# Try to find the latest model in exports/models
model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt"))
if not model_files:
logger.warning("⚠️ No trained models found in exports/models/. Using base model (zero-init proj).")
else:
model_path = max(model_files, key=os.path.getctime)
if model_path:
self.load_weights(model_path)
def load_weights(self, path):
logger.info(f"🔄 Loading model weights from {path}...")
checkpoint = torch.load(path, map_location=self.device)
self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict'])
logger.success("✅ News projection layer loaded.")
def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5):
# 1. Fetch Tickers
res = self.trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row['code'] for row in res]
test_tickers = all_tickers[start_idx:end_idx]
if not test_tickers:
logger.error(f"No tickers found in range {start_idx}-{end_idx}")
return
logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...")
# 2. Discover Shocks
shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len)
# 3. Associate News & Predict
self.trainer.model.eval()
predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device)
save_dir = os.path.join(SRC_DIR, "exports/evaluation_results")
os.makedirs(save_dir, exist_ok=True)
count = 0
for shock in shocks:
summary = self.trainer.find_reason_and_verify(shock)
if not summary:
continue
logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}")
# Embedding news
news_emb = self.trainer.embedder.encode(summary)
# Prediction
h = shock['history']
t = shock['target']
actuals = t['close'].values[:pred_len]
x_ts = pd.to_datetime(h['date'])
future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B')
y_ts = pd.Series(future_dates)
# A. Base Prediction (No news)
p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False)
# B. News-Aware Prediction
p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False)
# Calculate Improvement
b_preds = p_base['close'].values[:len(actuals)]
n_preds = p_news['close'].values[:len(actuals)]
b_mae = np.mean(np.abs(b_preds - actuals))
n_mae = np.mean(np.abs(n_preds - actuals))
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
# C. Visualize
try:
def to_kp_list(preds_df):
points = []
for idx, row in preds_df.iterrows():
points.append(KLinePoint(
date=str(idx)[:10], open=row['open'], high=row['high'],
low=row['low'], close=row['close'], volume=row.get('volume', 0)
))
return points
forecast_obj = ForecastResult(
ticker=shock['ticker'],
base_forecast=to_kp_list(p_base),
adjusted_forecast=to_kp_list(p_news),
rationale=summary
)
chart = VisualizerTools.generate_stock_chart(
df=h, ticker=shock['ticker'],
title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%",
forecast=forecast_obj,
ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']]
)
safe_date = shock['date'].replace("-", "")
filename = f"test_{shock['ticker']}_{safe_date}.html"
VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename))
logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}")
count += 1
except Exception as e:
logger.error(f"Visualization failed: {e}")
logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}")
if __name__ == "__main__":
# If you have a specific model, pass the path here. Otherwise it picks the latest.
evaluator = NewsModelEvaluator()
evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1)

View File

@@ -0,0 +1,196 @@
# Ref: https://github.com/shiyu-coder/Kronos
from model import Kronos, KronosTokenizer, KronosPredictor
import pandas as pd
import sqlite3
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pandas.tseries.offsets import BusinessDay
import numpy as np
def get_device():
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
return device
def load_predictor():
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
device = get_device()
tokenizer = tokenizer.to(device)
model = model.to(device)
return KronosPredictor(model, tokenizer, device=device, max_context=512)
def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"):
with sqlite3.connect(db_path) as conn:
df = pd.read_sql_query("SELECT * FROM stock_prices WHERE ticker = ?", conn, params=(ticker,))
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values('date').reset_index(drop=True)
return df
def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False):
"""
绘制 K 线图和成交量
"""
# X axis mapping to integers for consistent spacing
x = np.arange(len(dates))
# K-line data
opens = df['open'].values
closes = df['close'].values
highs = df['high'].values
lows = df['low'].values
volumes = df['volume'].values
# Width of the candlestick
width = 0.6
for i in range(len(x)):
color = color_up if closes[i] >= opens[i] else color_down
linestyle = '--' if is_prediction else '-'
# Wick
ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle)
# Body
rect_bottom = min(opens[i], closes[i])
rect_height = abs(opens[i] - closes[i])
if rect_height == 0: rect_height = 0.001 # hairline body so flat (doji) candles stay visible
ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height,
edgecolor=color, facecolor=color if not is_prediction else 'none',
alpha=alpha, linewidth=1, linestyle=linestyle))
# Volume
ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width)
def render_comparison_chart(history_df, actual_df, pred_df, title):
"""
渲染组合图:历史 K 线 + 真值 K 线 + 预测 K 线
"""
# Combine all dates for X axis
all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique()
all_dates = sorted(all_dates)
date_to_idx = {date: i for i, date in enumerate(all_dates)}
fig = plt.figure(figsize=(14, 8), facecolor='white')
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1)
ax_main = fig.add_subplot(gs[0])
ax_vol = fig.add_subplot(gs[1], sharex=ax_main)
# 1. Plot History
hist_indices = [date_to_idx[d] for d in history_df['date']]
# We use a custom x for plotting to ensure continuity
plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8)
offset = len(history_df)
# 2. Plot Actual if exists
if actual_df is not None:
# Shift indices
actual_x = np.arange(len(actual_df)) + offset
# Plotting manually to handle offset
for i in range(len(actual_df)):
idx = actual_x[i]
row = actual_df.iloc[i]
color = '#ef4444' if row['close'] >= row['open'] else '#22c55e'
ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9)
ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']),
edgecolor=color, facecolor=color, alpha=0.9))
ax_vol.bar(idx, row['volume'], color=color, alpha=0.4)
# 3. Plot Prediction
pred_x = np.arange(len(pred_df)) + offset
for i in range(len(pred_df)):
idx = pred_x[i]
row = pred_df.iloc[i]
color = '#ff8c00' # Orange for prediction to distinguish
ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--')
ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']),
edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--'))
# Plot secondary prediction line for close
if i == 0:
# Connect to history
ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6)
elif i > 0:
ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6)
# Styling
ax_main.set_title(title, fontsize=14, fontweight='bold')
ax_main.grid(True, linestyle=':', alpha=0.6)
ax_vol.grid(True, linestyle=':', alpha=0.6)
ax_vol.set_ylabel('Volume')
ax_main.set_ylabel('Price')
# Set X ticks
step = max(1, len(all_dates) // 10)
ax_vol.set_xticks(np.arange(0, len(all_dates), step))
ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45)
plt.tight_layout()
plt.show()
plt.close()
def run_backtest(df, predictor, lookback, pred_len, start_index=0):
total_len = len(df)
history_start = start_index
history_end = start_index + lookback
pred_start = history_end
available_pred_len = total_len - pred_start
if available_pred_len <= 0: return
actual_pred_len = min(pred_len, available_pred_len)
pred_end = pred_start + actual_pred_len
x_df = df.iloc[history_start : history_end].copy()
y_true_df = df.iloc[pred_start : pred_end].copy()
y_timestamp = y_true_df['date']
print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}")
pred_df = predictor.predict(
df=x_df[['open', 'high', 'low', 'close', 'volume']],
x_timestamp=x_df['date'],
y_timestamp=y_timestamp,
pred_len=actual_pred_len,
T=1.0, top_p=0.9, sample_count=1
)
render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison")
def run_forecast(df, predictor, lookback, pred_len):
if len(df) < lookback: return
x_df = df.iloc[-lookback:].copy()
last_date = x_df['date'].iloc[-1]
future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B')
future_dates = pd.Series(future_dates)
print(f"Forecasting: Starting from {future_dates.iloc[0].date()}")
pred_df = predictor.predict(
df=x_df[['open', 'high', 'low', 'close', 'volume']],
x_timestamp=x_df['date'],
y_timestamp=future_dates,
pred_len=pred_len,
T=1.0, top_p=0.9, sample_count=1
)
render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line")
if __name__ == "__main__":
LOOKBACK = 20
PRED_LEN = 10
TICKER = '002111'
pred_model = load_predictor()
stock_data = load_data(TICKER)
total_rows = len(stock_data)
backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend
print("\n--- Running Backtest ---")
run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start)
print("\n--- Running Forecast ---")
run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN)

View File

@@ -0,0 +1,16 @@
from .kronos import KronosTokenizer, Kronos, KronosPredictor
model_dict = {
'kronos_tokenizer': KronosTokenizer,
'kronos': Kronos,
'kronos_predictor': KronosPredictor
}
def get_model_class(model_name):
if model_name in model_dict:
return model_dict[model_name]
else:
print(f"Model {model_name} not found in model_dict")
raise NotImplementedError
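A minimal usage sketch for the registry above:

KronosCls = get_model_class("kronos")                # -> Kronos
TokenizerCls = get_model_class("kronos_tokenizer")   # -> KronosTokenizer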

View File

@@ -0,0 +1,676 @@
import numpy as np
import pandas as pd
import torch
from huggingface_hub import PyTorchModelHubMixin
import sys
from tqdm import trange
sys.path.append("../")
from model.module import *
class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
"""
KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
This tokenizer utilizes a combination of encoder and decoder Transformer blocks
along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.
Args:
d_in (int): Input dimension.
d_model (int): Model dimension.
n_heads (int): Number of attention heads.
ff_dim (int): Feed-forward dimension.
n_enc_layers (int): Number of encoder layers.
n_dec_layers (int): Number of decoder layers.
ffn_dropout_p (float): Dropout probability for feed-forward networks.
attn_dropout_p (float): Dropout probability for attention mechanisms.
resid_dropout_p (float): Dropout probability for residual connections.
s1_bits (int): Number of bits for the pre token in BSQuantizer.
s2_bits (int): Number of bits for the post token in BSQuantizer.
beta (float): Beta parameter for BSQuantizer.
gamma0 (float): Gamma0 parameter for BSQuantizer.
gamma (float): Gamma parameter for BSQuantizer.
zeta (float): Zeta parameter for BSQuantizer.
group_size (int): Group size parameter for BSQuantizer.
"""
def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
super().__init__()
self.d_in = d_in
self.d_model = d_model
self.n_heads = n_heads
self.ff_dim = ff_dim
self.enc_layers = n_enc_layers
self.dec_layers = n_dec_layers
self.ffn_dropout_p = ffn_dropout_p
self.attn_dropout_p = attn_dropout_p
self.resid_dropout_p = resid_dropout_p
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization
self.embed = nn.Linear(self.d_in, self.d_model)
self.head = nn.Linear(self.d_model, self.d_in)
# Encoder Transformer Blocks
self.encoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.enc_layers - 1)
])
# Decoder Transformer Blocks
self.decoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.dec_layers - 1)
])
self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization
self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits)
self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook)
self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module
def forward(self, x):
"""
Forward pass of the KronosTokenizer.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
Returns:
tuple: A tuple containing:
- tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively,
both of shape (batch_size, seq_len, d_in).
- torch.Tensor: bsq_loss - Loss from the BSQuantizer.
- torch.Tensor: quantized - Quantized representation from BSQuantizer.
- torch.Tensor: z_indices - Indices from the BSQuantizer.
"""
z = self.embed(x)
for layer in self.encoder:
z = layer(z)
z = self.quant_embed(z) # (B, T, codebook)
bsq_loss, quantized, z_indices = self.tokenizer(z)
quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
z_pre = self.post_quant_embed_pre(quantized_pre)
z = self.post_quant_embed(quantized)
# Decoder layers (for pre part - s1 bits)
for layer in self.decoder:
z_pre = layer(z_pre)
z_pre = self.head(z_pre)
# Decoder layers (for full codebook)
for layer in self.decoder:
z = layer(z)
z = self.head(z)
return (z_pre, z), bsq_loss, quantized, z_indices
def indices_to_bits(self, x, half=False):
"""
Converts indices to bit representations and scales them.
Args:
x (torch.Tensor): Indices tensor.
half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.
Returns:
torch.Tensor: Bit representation tensor.
"""
if half:
x1 = x[0] # Assuming x is a tuple of indices if half is True
x2 = x[1]
mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction
x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half
x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half
x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations
else:
mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction
x = (x.unsqueeze(-1) & mask) != 0 # Extract bits
x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1)
q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor
x = x * q_scale
return x
def encode(self, x, half=False):
"""
Encodes the input data into quantized indices.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.
Returns:
torch.Tensor: Quantized indices from BSQuantizer.
"""
z = self.embed(x)
for layer in self.encoder:
z = layer(z)
z = self.quant_embed(z)
bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False)
return z_indices
def decode(self, x, half=False):
"""
Decodes quantized indices back to the input data space.
Args:
x (torch.Tensor): Quantized indices tensor.
half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.
Returns:
torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
"""
quantized = self.indices_to_bits(x, half)
z = self.post_quant_embed(quantized)
for layer in self.decoder:
z = layer(z)
z = self.head(z)
return z
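# Illustrative round-trip through KronosTokenizer (hedged sketch; the dimensions and
# hyper-parameters below are hypothetical and not taken from this repo's released configs):
#   tok = KronosTokenizer(d_in=6, d_model=256, n_heads=4, ff_dim=512,
#                         n_enc_layers=4, n_dec_layers=4,
#                         ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0,
#                         s1_bits=8, s2_bits=8, beta=1.0, gamma0=1.0, gamma=1.0,
#                         zeta=1.0, group_size=4)
#   x = torch.randn(2, 64, 6)              # (batch, seq_len, d_in)
#   indices = tok.encode(x, half=True)     # [s1_indices, s2_indices]
#   recon = tok.decode(indices, half=True) # (2, 64, 6) reconstruction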
class Kronos(nn.Module, PyTorchModelHubMixin):
"""
Kronos Model.
Args:
s1_bits (int): Number of bits for pre tokens.
s2_bits (int): Number of bits for post tokens.
n_layers (int): Number of Transformer blocks.
d_model (int): Dimension of the model's embeddings and hidden states.
n_heads (int): Number of attention heads in the MultiheadAttention layers.
ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
ffn_dropout_p (float): Dropout probability for the feedforward network.
attn_dropout_p (float): Dropout probability for the attention layers.
resid_dropout_p (float): Dropout probability for residual connections.
token_dropout_p (float): Dropout probability for token embeddings.
learn_te (bool): Whether to use learnable temporal embeddings.
"""
def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None):
super().__init__()
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.n_layers = n_layers
self.d_model = d_model
self.n_heads = n_heads
self.learn_te = learn_te
self.ff_dim = ff_dim
self.ffn_dropout_p = ffn_dropout_p
self.attn_dropout_p = attn_dropout_p
self.resid_dropout_p = resid_dropout_p
self.token_dropout_p = token_dropout_p
self.news_dim = news_dim
self.s1_vocab_size = 2 ** self.s1_bits
self.token_drop = nn.Dropout(self.token_dropout_p)
self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
self.transformer = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.n_layers)
])
self.norm = RMSNorm(self.d_model)
self.dep_layer = DependencyAwareLayer(self.d_model)
self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)
if self.news_dim is not None:
self.news_proj = nn.Linear(self.news_dim, self.d_model)
else:
self.news_proj = None
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, RMSNorm):
nn.init.ones_(module.weight)
def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None):
"""
Args:
s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
"""
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
time_embedding = self.time_emb(stamp)
x = x + time_embedding
x = self.token_drop(x)
for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
x = self.norm(x)
if news_emb is not None and self.news_proj is not None:
news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
x = x + news_bias
s1_logits = self.head(x)
if use_teacher_forcing:
sibling_embed = self.embedding.emb_s1(s1_targets)
else:
s1_probs = F.softmax(s1_logits.detach(), dim=-1)
sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
sibling_embed = self.embedding.emb_s1(sample_s1_ids)
x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
s2_logits = self.head.cond_forward(x2)
return s1_logits, s2_logits
def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None):
"""
Decodes only the s1 tokens.
This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
and the context representation from the Transformer, which can be used for subsequent s2 decoding.
Args:
s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
"""
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
time_embedding = self.time_emb(stamp)
x = x + time_embedding
x = self.token_drop(x)
for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
x = self.norm(x)
if news_emb is not None and self.news_proj is not None:
news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
x = x + news_bias
s1_logits = self.head(x)
return s1_logits, x
def decode_s2(self, context, s1_ids, padding_mask=None):
"""
Decodes the s2 tokens, conditioned on the context and s1 tokens.
This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
Args:
context (torch.Tensor): Context representation from the transformer (output of decode_s1).
Shape: [batch_size, seq_len, d_model]
s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
Returns:
torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
"""
sibling_embed = self.embedding.emb_s1(s1_ids)
x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
return self.head.cond_forward(x2)
def top_k_top_p_filtering(
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
return logits
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
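# Illustrative call (hedged sketch): keep only the nucleus covering 90% of the probability
# mass. Note that in this copy the top-k branch returns early, so pass top_k=0 whenever
# top-p filtering should run.
#   logits = torch.randn(4, 1024)          # (batch, vocab)
#   filtered = top_k_top_p_filtering(logits, top_k=0, top_p=0.9)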
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
logits = logits / temperature
if top_k is not None or top_p is not None:
if top_k > 0 or top_p < 1.0:
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
if not sample_logits:
_, x = torch.topk(probs, k=1, dim=-1)  # greedy decoding: take the most likely token
else:
x = torch.multinomial(probs, num_samples=1)
return x
def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None):
with torch.no_grad():
x = torch.clip(x, -clip, clip)
device = x.device
x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
x_token = tokenizer.encode(x, half=True)
initial_seq_len = x.size(1)
batch_size = x_token[0].size(0)
total_seq_len = initial_seq_len + pred_len
full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
generated_pre = x_token[0].new_empty(batch_size, pred_len)
generated_post = x_token[1].new_empty(batch_size, pred_len)
pre_buffer = x_token[0].new_zeros(batch_size, max_context)
post_buffer = x_token[1].new_zeros(batch_size, max_context)
buffer_len = min(initial_seq_len, max_context)
if buffer_len > 0:
start_idx = max(0, initial_seq_len - max_context)
pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
if verbose:
ran = trange
else:
ran = range
for i in ran(pred_len):
current_seq_len = initial_seq_len + i
window_len = min(current_seq_len, max_context)
if current_seq_len <= max_context:
input_tokens = [
pre_buffer[:, :window_len],
post_buffer[:, :window_len]
]
else:
input_tokens = [pre_buffer, post_buffer]
context_end = current_seq_len
context_start = max(0, context_end - max_context)
current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb)
s1_logits = s1_logits[:, -1, :]
sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
s2_logits = model.decode_s2(context, sample_pre)
s2_logits = s2_logits[:, -1, :]
sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
generated_pre[:, i] = sample_pre.squeeze(-1)
generated_post[:, i] = sample_post.squeeze(-1)
if current_seq_len < max_context:
pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
else:
pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
pre_buffer[:, -1] = sample_pre.squeeze(-1)
post_buffer[:, -1] = sample_post.squeeze(-1)
full_pre = torch.cat([x_token[0], generated_pre], dim=1)
full_post = torch.cat([x_token[1], generated_post], dim=1)
context_start = max(0, total_seq_len - max_context)
input_tokens = [
full_pre[:, context_start:total_seq_len].contiguous(),
full_post[:, context_start:total_seq_len].contiguous()
]
z = tokenizer.decode(input_tokens, half=True)
z = z.reshape(-1, sample_count, z.size(1), z.size(2))
preds = z.cpu().numpy()
preds = np.mean(preds, axis=1)
return preds
def calc_time_stamps(x_timestamp):
time_df = pd.DataFrame()
time_df['minute'] = x_timestamp.dt.minute
time_df['hour'] = x_timestamp.dt.hour
time_df['weekday'] = x_timestamp.dt.weekday
time_df['day'] = x_timestamp.dt.day
time_df['month'] = x_timestamp.dt.month
return time_df
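# Illustrative output (hedged sketch): for a daily datetime Series the resulting frame has
# columns ['minute', 'hour', 'weekday', 'day', 'month'] in that order, which is the feature
# order expected by the model's temporal embedding.
#   ts = pd.Series(pd.date_range("2024-01-02", periods=3, freq="D"))
#   calc_time_stamps(ts)   # minute=0, hour=0, weekday/day/month per calendar date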
class KronosPredictor:
def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
self.tokenizer = tokenizer
self.model = model
self.max_context = max_context
self.clip = clip
self.price_cols = ['open', 'high', 'low', 'close']
self.vol_col = 'volume'
self.amt_vol = 'amount'
self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
self.device = device
self.tokenizer = self.tokenizer.to(self.device)
self.model = self.model.to(self.device)
def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None):
x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)
preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb)
preds = preds[:, -pred_len:, :]
return preds
def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None):
if not isinstance(df, pd.DataFrame):
raise ValueError("Input must be a pandas DataFrame.")
if not all(col in df.columns for col in self.price_cols):
raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
df = df.copy()
if self.vol_col not in df.columns:
df[self.vol_col] = 0.0 # Fill missing volume with zeros
df[self.amt_vol] = 0.0 # Fill missing amount with zeros
if self.amt_vol not in df.columns and self.vol_col in df.columns:
df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
raise ValueError("Input DataFrame contains NaN values in price or volume columns.")
x_time_df = calc_time_stamps(x_timestamp)
y_time_df = calc_time_stamps(y_timestamp)
x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
x_stamp = x_time_df.values.astype(np.float32)
y_stamp = y_time_df.values.astype(np.float32)
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.clip, self.clip)
x = x[np.newaxis, :]
x_stamp = x_stamp[np.newaxis, :]
y_stamp = y_stamp[np.newaxis, :]
if news_emb is not None:
news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device)
# Ensure batch dimension for news_emb if only one sample
if news_emb_tensor.ndim == 1:
news_emb_tensor = news_emb_tensor.unsqueeze(0)
else:
news_emb_tensor = None
preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor)
preds = preds.squeeze(0)
preds = preds * (x_std + 1e-5) + x_mean
pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
return pred_df
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
"""
Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len).
Args:
df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns.
x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame.
y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len.
pred_len (int): Number of prediction steps.
T (float): Sampling temperature.
top_k (int): Top-k filtering threshold.
top_p (float): Top-p (nucleus sampling) threshold.
sample_count (int): Number of parallel samples per series, automatically averaged internally.
verbose (bool): Whether to display autoregressive progress.
Returns:
List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains
`open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`.
"""
# Basic validation
if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)):
raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.")
if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.")
num_series = len(df_list)
x_list = []
x_stamp_list = []
y_stamp_list = []
means = []
stds = []
seq_lens = []
y_lens = []
for i in range(num_series):
df = df_list[i]
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
if not all(col in df.columns for col in self.price_cols):
raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.")
df = df.copy()
if self.vol_col not in df.columns:
df[self.vol_col] = 0.0
df[self.amt_vol] = 0.0
if self.amt_vol not in df.columns and self.vol_col in df.columns:
df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.")
x_timestamp = x_timestamp_list[i]
y_timestamp = y_timestamp_list[i]
x_time_df = calc_time_stamps(x_timestamp)
y_time_df = calc_time_stamps(y_timestamp)
x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
x_stamp = x_time_df.values.astype(np.float32)
y_stamp = y_time_df.values.astype(np.float32)
if x.shape[0] != x_stamp.shape[0]:
raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.")
if y_stamp.shape[0] != pred_len:
raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.")
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x_norm = (x - x_mean) / (x_std + 1e-5)
x_norm = np.clip(x_norm, -self.clip, self.clip)
x_list.append(x_norm)
x_stamp_list.append(x_stamp)
y_stamp_list.append(y_stamp)
means.append(x_mean)
stds.append(x_std)
seq_lens.append(x_norm.shape[0])
y_lens.append(y_stamp.shape[0])
# Require all series to have consistent historical and prediction lengths for batch processing
if len(set(seq_lens)) != 1:
raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}")
if len(set(y_lens)) != 1:
raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}")
x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat)
x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat)
y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat)
preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
# preds: (B, pred_len, feat)
pred_dfs = []
for i in range(num_series):
preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i])
pred_dfs.append(pred_df)
return pred_dfs
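# Illustrative batch call (hedged sketch; df_a / df_b and the timestamp lists are hypothetical
# inputs that share the same history length and pred_len):
#   predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
#   preds = predictor.predict_batch(
#       df_list=[df_a, df_b],
#       x_timestamp_list=[ts_a, ts_b],
#       y_timestamp_list=[future_ts_a, future_ts_b],
#       pred_len=5, T=1.0, top_p=0.9, sample_count=3,
#   )
#   preds[0][["close", "volume"]]   # forecast for the first series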

View File

@@ -0,0 +1,562 @@
import math
from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
class DifferentiableEntropyFunction(Function):
@staticmethod
def forward(ctx, zq, basis, K, eps):
zb = (zq + 1) / 2
zi = ((zb * basis).sum(-1)).to(torch.int64)
cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
0,
zi.flatten(),
torch.ones_like(zi.flatten()).to(zq.dtype),
'sum')
prob = (cnt + eps) / (cnt + eps).sum()
H = -(prob * torch.log(prob)).sum()
ctx.save_for_backward(zq, zi, prob)
ctx.K = K
return H
@staticmethod
def backward(ctx, grad_output):
zq, zi, prob = ctx.saved_tensors
grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
grad_input = reord_grad.unsqueeze(-1) * zq
return grad_input, None, None, None, None
def codebook_entropy(zq, basis, K, eps=1e-4):
return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
class BinarySphericalQuantizer(nn.Module):
def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
input_format='bchw',
soft_entropy=True, group_size=9,
persample_entropy_compute='analytical',
cb_entropy_compute='group',
l2_norm=True,
inv_temperature=1):
"""
Paper link: https://arxiv.org/pdf/2406.07548.pdf
Here we use the official implementation of the BinarySphericalQuantizer.
"""
super().__init__()
self.embed_dim = embed_dim
self.beta = beta # loss weight for commit loss
self.gamma0 = gamma0 # loss weight for entropy penalty
self.gamma = gamma # loss weight for entropy penalty
self.zeta = zeta # loss weight for entire entropy penalty
self.input_format = input_format
assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
self.num_groups = self.embed_dim // group_size
self.group_size = group_size
assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
self.persample_entropy_compute = persample_entropy_compute
self.cb_entropy_compute = cb_entropy_compute
self.l2_norm = l2_norm
self.inv_temperature = inv_temperature
self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
self.num_dimensions = 2 ** embed_dim
self.bits_per_index = embed_dim
# we only need to keep the codebook portion up to the group size
# because we approximate the H loss with this subcode
group_codes = torch.arange(2 ** self.group_size)
group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
self.register_buffer('group_codebook', group_codebook, persistent=False)
self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
def quantize(self, z):
assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
zhat = torch.where(z > 0,
torch.tensor(1, dtype=z.dtype, device=z.device),
torch.tensor(-1, dtype=z.dtype, device=z.device))
return z + (zhat - z).detach()
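# Note: quantize() above is a straight-through estimator - the forward value is the
# binarized code zhat, while gradients flow through z because (zhat - z) is detached.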
def forward(self, z, collect_metrics=True):
# if self.input_format == 'bchw':
# z = rearrange(z, 'b c h w -> b h w c')
zq = self.quantize(z)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
zq = zq * q_scale
if not collect_metrics:
return zq, zq.new_zeros(()), {}
indices = self.codes_to_indexes(zq.detach())
group_indices = self.codes_to_group_indexes(zq.detach())
if not self.training:
used_codes = torch.unique(indices, return_counts=False)
else:
used_codes = None
if self.soft_entropy:
persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
else:
zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
avg_prob = None  # not computed on the hard-entropy path; referenced in the metrics dict below
# commit loss
commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
# if self.input_format == 'bchw':
# zq = rearrange(zq, 'b h w c -> b c h w')
return (
zq,
commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
{"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
"avg_prob": avg_prob}
)
def soft_entropy_loss(self, z):
# if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
# the sub-code is the last group_size bits of the full code
group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
# we calculate the distance between the divided_z and the codebook for each subgroup
distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
prob = (-distance * self.inv_temperature).softmax(dim=-1)
if self.persample_entropy_compute == 'analytical':
if self.l2_norm:
p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
else:
p = torch.sigmoid(-4 * z * self.inv_temperature)
prob = torch.stack([p, 1 - p], dim=-1)
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
else:
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
# macro average of the probability of each subgroup
avg_prob = reduce(prob, '... g d ->g d', 'mean')
codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
# the approximation of the entropy is the sum of the entropy of each subgroup
return per_sample_entropy, codebook_entropy.sum(), avg_prob
def get_hard_per_sample_entropy(self, zb_by_sample):
probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
persample_entropy = persample_entropy.sum(-1)
return persample_entropy.mean()
def codes_to_indexes(self, zhat):
"""Converts a `code` to an index in the codebook.
Args:
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
"""
assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
def codes_to_group_indexes(self, zhat):
"""Converts a `code` to a list of indexes (in groups) in the codebook.
Args:
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
"""
zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
def indexes_to_codes(self, indices):
"""Inverse of `indexes_to_codes`."""
indices = indices.unsqueeze(-1)
codes_non_centered = torch.remainder(
torch.floor_divide(indices, self.basis), 2
)
return codes_non_centered * 2 - 1
def group_indexes_to_codes(self, group_indices):
"""Inverse of `group_indexes_to_codes`."""
group_indices = group_indices.unsqueeze(-1)
codes_non_centered = torch.remainder(
torch.floor_divide(group_indices, self.group_basis), 2
)
codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
return codes_non_centered * 2 - 1
def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
if normalize:
probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
else:
probs = count
H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
return H
def get_group_codebook_entry(self, group_indices):
z_q = self.group_indexes_to_codes(group_indices)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
z_q = z_q * q_scale
if self.input_format == 'bchw':
h = w = int(z_q.shape[1] ** 0.5)
assert h * w == z_q.shape[1], 'Invalid sequence length'
z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
return z_q
def get_codebook_entry(self, indices):
z_q = self.indexes_to_codes(indices)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
z_q = z_q * q_scale
if self.input_format == 'bchw':
h = w = int(z_q.shape[1] ** 0.5)
assert h * w == z_q.shape[1], 'Invalid sequence length'
z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
return z_q
class BSQuantizer(nn.Module):
def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
super().__init__()
self.codebook_dim = s1_bits + s2_bits
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
def bits_to_indices(self, bits):
bits = (bits >= 0).to(torch.long)
indices = 2 ** torch.arange(
0,
bits.shape[-1],
1,
dtype=torch.long,
device=bits.device,
)
return (bits * indices).sum(-1)
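# Worked example (hedged sketch): with 3-bit codes the weights are [1, 2, 4] (LSB first),
# so bits [+1, -1, +1] -> binary (1, 0, 1) -> index 1*1 + 0*2 + 1*4 = 5.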
def forward(self, z, half=False, collect_metrics=True):
z = F.normalize(z, dim=-1)
quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics)
if half:
q_pre = quantized[:, :, :self.s1_bits]
q_post = quantized[:, :, self.s1_bits:]
z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
else:
z_indices = self.bits_to_indices(quantized)
return bsq_loss, quantized, z_indices
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class FeedForward(nn.Module):
def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
super().__init__()
self.w1 = nn.Linear(d_model, ff_dim, bias=False)
self.w3 = nn.Linear(d_model, ff_dim, bias=False)
self.w2 = nn.Linear(ff_dim, d_model, bias=False)
self.ffn_dropout = nn.Dropout(ffn_dropout_p)
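# forward() below computes the SwiGLU-style gate: w2(SiLU(w1(x)) * w3(x)), then dropout.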
def forward(self, x):
return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def _update_cos_sin_cache(self, x, seq_len):
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
return self.cos_cached, self.sin_cached
def forward(self, q, k):
cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
return (
(q * cos) + (self._rotate_half(q) * sin),
(k * cos) + (self._rotate_half(k) * sin),
)
def _rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
class MultiHeadAttentionWithRoPE(nn.Module):
def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.rotary = RotaryPositionalEmbedding(self.head_dim)
self.attn_dropout_p = attn_dropout_p
self.resid_dropout = nn.Dropout(resid_dropout_p)
def forward(self, x, key_padding_mask=None):
batch_size, seq_len, _ = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
q, k = self.rotary(q, k)
if key_padding_mask is not None:
attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len]
else:
attn_mask = None
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout_p if self.training else 0.0,
is_causal=True
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.resid_dropout(self.out_proj(attn_output))
class MultiHeadCrossAttentionWithRoPE(nn.Module):
def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.rotary = RotaryPositionalEmbedding(self.head_dim)
self.attn_dropout_p = attn_dropout_p
self.resid_dropout = nn.Dropout(resid_dropout)
def forward(self, query, key, value, key_padding_mask=None):
batch_size, q_len, _ = query.shape
_, seq_len, _ = key.shape
q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
q, k = self.rotary(q, k)
if key_padding_mask is not None:
attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
else:
attn_mask = None
is_causal_flag = self.training
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout_p if self.training else 0.0,
is_causal=is_causal_flag
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
return self.resid_dropout(self.out_proj(attn_output))
class HierarchicalEmbedding(nn.Module):
def __init__(self, s1_bits, s2_bits, d_model=256):
super().__init__()
self.s1_bits = s1_bits
self.s2_bits = s2_bits
vocab_s1 = 2 ** s1_bits
vocab_s2 = 2 ** s2_bits
self.emb_s1 = nn.Embedding(vocab_s1, d_model)
self.emb_s2 = nn.Embedding(vocab_s2, d_model)
self.d_model = d_model
self.fusion_proj = nn.Linear(d_model * 2, d_model)
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
def split_token(self, token_ids: torch.Tensor, s2_bits: int):
"""Inputs:
token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
s2_bits (int): Number of low bits used for the fine token (s2).
"""
assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
t = token_ids.long()
mask = (1 << s2_bits) - 1
s2_ids = t & mask # extract low bits
s1_ids = t >> s2_bits # extract high bits
return s1_ids, s2_ids
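# Worked example (hedged sketch): with s2_bits=2, composite token 13 (binary 1101)
# splits into s2_ids = 13 & 0b11 = 1 (low bits) and s1_ids = 13 >> 2 = 3 (high bits).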
def forward(self, token_ids):
"""Inputs:
token_ids:
- tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
- torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
Output: [batch_size, seq_len, d_model]
"""
if isinstance(token_ids, tuple) or isinstance(token_ids, list):
s1_ids, s2_ids = token_ids
else:
s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
class DependencyAwareLayer(nn.Module):
def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
super().__init__()
self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
self.norm = RMSNorm(d_model)
def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
"""hidden_states: [batch, seq_len, d_model]
sibling_embed: Embedding from another subtoken
"""
attn_out = self.cross_attn(
query=sibling_embed,
key=hidden_states,
value=hidden_states,
key_padding_mask=key_padding_mask
)
return self.norm(hidden_states + attn_out)
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
super().__init__()
self.norm1 = RMSNorm(d_model)
self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
self.norm2 = RMSNorm(d_model)
self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
def forward(self, x, key_padding_mask=None):
residual = x
x = self.norm1(x)
attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
x = residual + attn_out
residual = x
x = self.norm2(x)
ffn_out = self.ffn(x)
x = residual + ffn_out
return x
class DualHead(nn.Module):
def __init__(self, s1_bits, s2_bits, d_model):
super().__init__()
self.vocab_s1 = 2 ** s1_bits
self.vocab_s2 = 2 ** s2_bits
self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
if padding_mask is not None:
valid_mask = (padding_mask == 0)
s1_logits = s1_logits[valid_mask]
s2_logits = s2_logits[valid_mask]
s1_targets = s1_targets[valid_mask]
s2_targets = s2_targets[valid_mask]
ce_s1 = F.cross_entropy(s1_logits, s1_targets)
ce_s2 = F.cross_entropy(s2_logits, s2_targets)
else:
ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
ce_loss = (ce_s1 + ce_s2) / 2
return ce_loss, ce_s1, ce_s2
def forward(self, x):
return self.proj_s1(x)
def cond_forward(self, x2):
return self.proj_s2(x2)
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
w = torch.zeros(c_in, d_model).float()
w.requires_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
def forward(self, x):
return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
def __init__(self, d_model, learn_pe):
super(TemporalEmbedding, self).__init__()
minute_size = 60
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13
Embed = FixedEmbedding if not learn_pe else nn.Embedding
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)
def forward(self, x):
x = x.long()
minute_x = self.minute_embed(x[:, :, 0])
hour_x = self.hour_embed(x[:, :, 1])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 3])
month_x = self.month_embed(x[:, :, 4])
return hour_x + weekday_x + day_x + month_x + minute_x

View File

@@ -0,0 +1,539 @@
import os
import sys
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import random
from loguru import logger
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
sys.path.insert(0, SRC_DIR)
from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor
from ..database_manager import DatabaseManager
from ..stock_tools import StockTools
from ..search_tools import SearchTools
from ..llm.factory import get_model
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
from agno.agent import Agent
class AutoSynthesisTrainer:
def __init__(self, news_dim=384):
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.db = DatabaseManager()
self.tools = StockTools(self.db)
self.searcher = SearchTools(self.db)
# Try loading from local cache first to avoid network timeouts
model_name = os.getenv(
"EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
try:
logger.info(f"🔄 Attempting to load {model_name} from local cache...")
self.embedder = SentenceTransformer(
model_name, device=self.device, local_files_only=True
)
logger.success("✅ Model loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local cache not found or incomplete. Attempting to download..."
)
self.embedder = SentenceTransformer(model_name, device=self.device)
self.news_dim = news_dim
# Try loading from local cache first to avoid network timeouts
try:
logger.info(
"🔄 Attempting to load Kronos and Tokenizer from local cache..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
).to(self.device)
base_model = Kronos.from_pretrained(
"NeoQuasar/Kronos-base", local_files_only=True
)
logger.success("✅ Kronos and Tokenizer loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base"
).to(self.device)
base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
self.model = Kronos(
base_model.s1_bits,
base_model.s2_bits,
base_model.n_layers,
base_model.d_model,
base_model.n_heads,
base_model.ff_dim,
base_model.ffn_dropout_p,
base_model.attn_dropout_p,
base_model.resid_dropout_p,
base_model.token_dropout_p,
base_model.learn_te,
news_dim=self.news_dim,
).to(self.device)
self.model.load_state_dict(base_model.state_dict(), strict=False)
# LLM for causality verification
provider = os.getenv("LLM_PROVIDER", "ust")
model_id = os.getenv("LLM_MODEL", "Qwen")
self.llm_agent = Agent(model=get_model(provider, model_id))
def discover_shocks(
self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5
):
"""1. Find days with significant price movements (Look back 1 year)"""
shocks = []
end_date = datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
for ticker in ticker_list:
df = self.tools.get_stock_price(
ticker, start_date=start_date, end_date=end_date
)
if df.empty or len(df) < 60:
continue
# Look for big moves
moves = df[df["change_pct"].abs() > threshold].copy()
if moves.empty:
continue
count = 0
for idx, row in moves.iterrows():
# Ensure we have history before this day AND enough future days for eval
date_idx = df.index.get_loc(idx)
if date_idx < 50 or date_idx + pred_len > len(df):
continue
shocks.append(
{
"ticker": ticker,
"date": row["date"],
"change": row["change_pct"],
"history": df.iloc[date_idx - 50 : date_idx],
"target": df.iloc[
date_idx : date_idx + pred_len
], # Now capturing pred_len days
}
)
count += 1
if count >= limit_per_stock:
break
logger.info(
f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days."
)
return shocks
def find_reason_and_verify(self, shock):
"""2. Search for reasons and verify causality using LLM"""
ticker_info = self.db.get_stock_by_code(shock["ticker"])
name = ticker_info["name"] if ticker_info else shock["ticker"]
date_str = shock["date"]
# Try multiple query variations and engines
queries = [
f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因",
f"{name} {date_str} 异动 原因",
f"{shock['ticker']} {date_str} 新闻",
]
search_results = []
for query in queries:
logger.info(f"🔍 Searching for reason: {query}")
# Try alternate engines
for engine in ["baidu"]:
try:
results = self.searcher.search_list(
query, engine=engine, max_results=3, enrich=False
)
if results:
search_results = results
break
except Exception as e:
logger.warning(f"Search failed for {query} on {engine}: {e}")
if search_results:
break
time.sleep(random.uniform(1.0, 2.0))
if not search_results:
logger.warning(
f"⚠️ No search results found for {name} on {date_str} after multiple attempts."
)
return None
context = "\n".join(
[f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results]
)
prompt = f"""
任务:判断以下新闻是否解释了该股票在 {date_str}{shock["change"]:.2f}% 价格变动。
股票:{name}
日期:{date_str}
变动:{shock["change"]:.2f}%
搜索结果:
{context}
要求:
1. 该新闻是否在该日期左右发生?
2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)?
3. 如果是,请总结一段 100 字以内的“核心推动原因”。
4. 返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}}
"""
try:
res = self.llm_agent.run(prompt)
data = json.loads(
res.content.replace("```json", "").replace("```", "").strip()
)
if data.get("is_causal"):
logger.success(
f"✅ Verified cause for {name} on {date_str}: {data['summary']}"
)
return data["summary"]
else:
logger.warning(
f"❌ No causal link verified for {name} on {date_str}: {data.get('summary', '')}"
)
return None
except Exception as e:
logger.warning(f"Verification failed: {e}")
return None
def save_model(self, path=None):
"""Save the news_proj weights"""
if path is None:
save_dir = os.path.join(SRC_DIR, "exports/models")
os.makedirs(save_dir, exist_ok=True)
path = os.path.join(
save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt"
)
# We only really need to save the news_proj part as it's the only one we train
torch.save(
{
"news_proj_state_dict": self.model.news_proj.state_dict(),
"news_dim": self.news_dim,
"d_model": self.model.d_model,
},
path,
)
logger.success(f"💾 Model weights saved to {path}")
return path
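# Illustrative reload of the saved checkpoint (hedged sketch; only the trained news_proj
# head is stored here, so the base Kronos weights must be loaded separately):
#   ckpt = torch.load(path, map_location=self.device)
#   self.model.news_proj.load_state_dict(ckpt["news_proj_state_dict"])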
def run_synthesis_and_train(self, tickers, pred_len=5):
# 1. Discovery
shocks = self.discover_shocks(tickers, pred_len=pred_len)
print(f"find {len(shocks)} shocks")
# 2. News Association & Verification
dataset = []
max_news_items = 200 # Limit to 200 news items per session to avoid search bans
logger.info(
f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})"
)
for i, shock in enumerate(shocks):
if len(dataset) >= max_news_items:
logger.info("Reached maximum news items limit for this session.")
break
summary = self.find_reason_and_verify(shock)
if summary:
# 3. Embedding news
emb = self.embedder.encode(summary)
dataset.append(
{
"history": shock["history"],
"target": shock["target"],
"news_emb": emb,
"summary": summary,
}
)
# Add delay after search with randomness to avoid being blocked
if i < len(shocks) - 1:
delay = random.uniform(2.0, 4.0)
time.sleep(delay)
if not dataset:
logger.error(
"❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period."
)
return
# 4. Train/Val Split
random.seed(42)
random.shuffle(dataset)
if len(dataset) < 2:
train_set = dataset
val_set = []
logger.warning(
f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation."
)
else:
split_idx = max(1, int(len(dataset) * 0.8))
if split_idx >= len(dataset):
split_idx = len(dataset) - 1
train_set = dataset[:split_idx]
val_set = dataset[split_idx:]
logger.info(
f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation."
)
if not train_set:
logger.error("❌ No samples for training.")
return
# 5. Training (Few-shot)
optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
self.model.train()
loss_history = []
logger.info(f"🚀 Training for 30 epochs...")
for epoch in range(30):
total_loss = 0
for item in train_set:
optimizer.zero_grad()
# Prep Data
hist_df = item["history"]
# For training, we still focus on the immediate next point (teacher forcing)
target_df = item["target"].iloc[:1]
hist_raw = hist_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]])
mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5
hist_norm = (
torch.from_numpy((hist_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
target_raw = target_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
target_raw = np.column_stack(
[target_raw, target_raw[:, 3] * target_raw[:, 4]]
)
target_norm = (
torch.from_numpy((target_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
with torch.no_grad():
z_indices = self.tokenizer.encode(hist_norm, half=True)
t_indices = self.tokenizer.encode(target_norm, half=True)
s1_ids, s2_ids = z_indices[0], z_indices[1]
t_s1, t_s2 = t_indices[0], t_indices[1]
news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device)
s1_logits, s2_logits = self.model(
s1_ids,
s2_ids,
news_emb=news_t,
use_teacher_forcing=True,
s1_targets=t_s1,
)
loss = (
criterion(s1_logits[:, -1, :], t_s1[:, 0])
+ criterion(s2_logits[:, -1, :], t_s2[:, 0])
) / 2
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_epoch_loss = total_loss / max(1, len(train_set))
loss_history.append(avg_epoch_loss)
if (epoch + 1) % 10 == 0:
logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}")
# 5.1 Visualize Loss Curve
loss_chart = VisualizerTools.generate_loss_chart(loss_history)
VisualizerTools.render_chart_to_file(
loss_chart,
os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"),
)
# 5.2 Save final model
self.save_model()
# 6. Final Evaluation on Validation Set
if not val_set:
logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.")
return
logger.info(
f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)"
)
self.model.eval()
predictor = KronosPredictor(self.model, self.tokenizer, device=self.device)
base_maes = []
news_maes = []
print("\n" + "=" * 90)
print(
f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}"
)
print("-" * 90)
for item in val_set:
h = item["history"]
t = item["target"]
actuals = t["close"].values[:pred_len]
x_ts = pd.to_datetime(h["date"])
# Future timestamps: handle business days if possible, or just simple offset
future_dates = pd.date_range(
start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B"
)
y_ts = pd.Series(future_dates)
# A. Base Prediction
p_base = predictor.predict(
h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False
)
b_preds = p_base["close"].values[: len(actuals)]
# B. News-Aware Prediction
p_news = predictor.predict(
h,
x_ts,
y_ts,
pred_len=pred_len,
news_emb=item["news_emb"],
verbose=False,
)
n_preds = p_news["close"].values[: len(actuals)]
# Calculate MAE over the window
b_mae = np.mean(np.abs(b_preds - actuals))
n_mae = np.mean(np.abs(n_preds - actuals))
base_maes.append(b_mae)
news_maes.append(n_mae)
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
date_str = str(t["date"].values[0])[:10]
ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock"
print(
f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%"
)
# C. Generate Visualization for this case
try:
# Helper to convert DF to KLinePoints
def to_kp_list(preds_df):
points = []
for idx, row in preds_df.iterrows():
points.append(
KLinePoint(
date=str(idx)[:10],
open=row["open"],
high=row["high"],
low=row["low"],
close=row["close"],
volume=row["volume"] if "volume" in row else 0,
)
)
return points
forecast_obj = ForecastResult(
ticker=ticker,
base_forecast=to_kp_list(p_base),
adjusted_forecast=to_kp_list(p_news),
rationale=item["summary"],
)
# Ground truth for visualizer expects a DataFrame with 'date' and 'close'
gt_df = t[["date", "open", "high", "low", "close", "volume"]]
chart = VisualizerTools.generate_stock_chart(
df=h,
ticker=ticker,
title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%",
forecast=forecast_obj,
ground_truth=gt_df,
)
safe_date = date_str.replace("-", "")
filename = f"eval_{ticker}_{safe_date}.html"
VisualizerTools.render_chart_to_file(
chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}")
)
except Exception as e:
logger.error(f"Failed to generate eval chart for {ticker}: {e}")
# Summary Statistics
avg_base_err = sum(base_maes) / max(1, len(base_maes))
avg_news_err = sum(news_maes) / max(1, len(news_maes))
overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100
print("-" * 90)
print(
f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%"
)
print("=" * 90 + "\n")
logger.success(
f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%"
)
logger.info(
f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}"
)
if __name__ == "__main__":
trainer = AutoSynthesisTrainer()
logger.info("📂 Fetching all stock codes from database...")
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
if not all_tickers:
logger.warning("⚠️ No tickers found in stock_list table. Trying to sync...")
trainer.tools._check_and_update_stock_list(force=True)
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
logger.info(f"🚀 Starting training on potential stocks (1-year scan)...")
# For demonstration, scan the first 100 stocks for shock points over the past year
trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1)

View File

@@ -0,0 +1,611 @@
import os
import hashlib
import json
import re
import requests
import time
import threading
from typing import List, Dict, Optional, Any
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.tools.baidusearch import BaiduSearchTools
from agno.agent import Agent
from loguru import logger
from datetime import datetime
from .database_manager import DatabaseManager
from .content_extractor import ContentExtractor
from .llm.factory import get_model
from .hybrid_search import LocalNewsSearch
# Default search-cache TTL; can be overridden via an environment variable
DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600")) # defaults to 1 hour
class JinaSearchEngine:
"""Jina Search API 封装 - 使用 s.jina.ai 进行网络搜索"""
JINA_SEARCH_URL = "https://s.jina.ai/"
# Rate-limit configuration
_rate_limit_no_key = 10 # max requests per minute without an API key
_rate_window = 60.0
_min_interval = 2.0
_request_times = []
_last_request_time = 0.0
_lock = threading.Lock()
def __init__(self):
self.api_key = os.getenv("JINA_API_KEY", "").strip()
self.has_api_key = bool(self.api_key)
if self.has_api_key:
logger.info("✅ Jina Search API key configured")
@classmethod
def _wait_for_rate_limit(cls, has_api_key: bool) -> None:
"""等待以满足速率限制"""
if has_api_key:
time.sleep(0.3)
return
with cls._lock:
current_time = time.time()
cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
if len(cls._request_times) >= cls._rate_limit_no_key:
oldest = cls._request_times[0]
wait_time = cls._rate_window - (current_time - oldest) + 1.0
if wait_time > 0:
logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...")
time.sleep(wait_time)
current_time = time.time()
cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
time_since_last = current_time - cls._last_request_time
if time_since_last < cls._min_interval:
time.sleep(cls._min_interval - time_since_last)
cls._request_times.append(time.time())
cls._last_request_time = time.time()
def search(self, query: str, max_results: int = 5) -> List[Dict]:
"""
Execute a search via the Jina Search API.
Args:
query: search keywords
max_results: number of results to return
Returns:
A list of results, each containing title, url and content.
"""
if not query:
return []
logger.info(f"🔍 Jina Search: {query}")
# Respect the rate limit
self._wait_for_rate_limit(self.has_api_key)
headers = {
"Accept": "application/json",
"X-Retain-Images": "none",
}
if self.has_api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
try:
# Jina Search API: https://s.jina.ai/{query}
import urllib.parse
encoded_query = urllib.parse.quote(query)
url = f"{self.JINA_SEARCH_URL}{encoded_query}"
response = requests.get(url, headers=headers, timeout=30)
if response.status_code == 429:
logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...")
time.sleep(30)
return self.search(query, max_results)
if response.status_code != 200:
logger.warning(f"Jina Search failed (Status {response.status_code})")
return []
# Parse the response
try:
data = response.json()
except json.JSONDecodeError:
# Plain-text response: wrap it as a single pseudo result
data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]}
results = []
# Jina may return {"data": [...]} or a bare list
items = data.get("data", []) if isinstance(data, dict) else data
if not isinstance(items, list):
items = [items] if items else []
for i, item in enumerate(items[:max_results]):
if isinstance(item, dict):
results.append({
"title": item.get("title", f"Result {i+1}"),
"url": item.get("url", ""),
"href": item.get("url", ""), # 兼容性
"content": item.get("content", item.get("description", "")),
"body": item.get("content", item.get("description", "")), # 兼容性
})
elif isinstance(item, str):
results.append({
"title": f"Result {i+1}",
"url": "",
"content": item
})
logger.info(f"✅ Jina Search returned {len(results)} results")
return results
except requests.exceptions.Timeout:
logger.error("Jina Search timeout")
return []
except requests.exceptions.RequestException as e:
logger.error(f"Jina Search request error: {e}")
return []
except Exception as e:
logger.error(f"Jina Search unexpected error: {e}")
return []
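# Illustrative usage sketch (not part of the original module): the query below is a made-up
# example. Without JINA_API_KEY the class-level limiter throttles to roughly 10 requests/minute.
def _jina_search_usage_sketch():
    engine = JinaSearchEngine()
    for hit in engine.search("NVIDIA Q3 earnings", max_results=3):
        logger.info(f"{hit['title']} -> {hit['url']}")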
class SearchTools:
"""扩展性搜索工具库 - 支持多引擎聚合与内容缓存"""
def __init__(self, db: DatabaseManager):
self.db = db
# Check whether a Jina API key is configured
jina_api_key = os.getenv("JINA_API_KEY", "").strip()
self._jina_enabled = bool(jina_api_key)
self._engines = {
"ddg": DuckDuckGoTools(),
"baidu": BaiduSearchTools(),
"local": LocalNewsSearch(db)
}
# If a Jina API key is configured, register the Jina engine
if self._jina_enabled:
self._engines["jina"] = JinaSearchEngine()
logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)")
# Determine the default search engine
self._default_engine = "jina" if self._jina_enabled else "ddg"
def _generate_hash(self, query: str, engine: str, max_results: int) -> str:
return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest()
def search(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None) -> str:
"""
Run a web search with the specified engine; results are cached for efficiency.
Args:
query: search keywords, e.g. "英伟达财报" (NVIDIA earnings) or "光伏行业政策" (solar industry policy)
engine: which search engine to use. Options:
"jina" (Jina Search, requires JINA_API_KEY, LLM-friendly output),
"ddg" (DuckDuckGo, recommended for English/international queries),
"baidu" (Baidu, recommended for Chinese/domestic queries),
"local" (local historical news search, vector + BM25).
Default: "jina" if JINA_API_KEY is configured, otherwise "ddg".
max_results: desired number of results, default 5.
ttl: cache lifetime in seconds; entries older than this trigger a fresh search.
Defaults to the SEARCH_CACHE_TTL environment variable or 3600 seconds.
Set to 0 to force a refresh.
Returns:
A textual summary of the results, including titles, snippets and links.
"""
# Use the default engine (Jina takes priority when configured)
if engine is None:
engine = self._default_engine
if engine not in self._engines:
return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}"
query_hash = self._generate_hash(query, engine, max_results)
effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL
# 1. Try the cache first (the local engine is not cached since it already queries the database)
if engine != "local":
cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
if cache and effective_ttl != 0:
logger.info(f" Found search results in cache for: {query} ({engine})")
return cache['results']
# 2. Execute the actual search
logger.info(f"📡 Searching {engine} for: {query}")
try:
tool = self._engines[engine]
if engine == "jina":
# Jina Search returns List[Dict]
jina_results = tool.search(query, max_results=max_results)
results = []
for r in jina_results:
results.append({
"title": r.get("title", ""),
"href": r.get("url", ""),
"body": r.get("content", "")
})
elif engine == "ddg":
results = tool.duckduckgo_search(query, max_results=max_results)
elif engine == "baidu":
results = tool.baidu_search(query, max_results=max_results)
elif engine == "local":
# LocalNewsSearch returns List[Dict]
local_results = tool.search(query, top_n=max_results)
results = []
for r in local_results:
results.append({
"title": r.get("title"),
"href": r.get("url", "local"),
"body": r.get("content", "")
})
else:
results = "Search not implemented for this engine."
results_str = str(results)
if engine != "local":
self.db.save_search_cache(query_hash, query, engine, results_str)
return results_str
except Exception as e:
# Fallback strategy when the search fails
if engine == "jina":
logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})")
try:
return self.search(query, engine="ddg", max_results=max_results, ttl=ttl)
except Exception as e2:
logger.error(f"❌ DDG fallback also failed for {query}: {e2}")
elif engine == "ddg":
logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})")
try:
return self.search(query, engine="baidu", max_results=max_results, ttl=ttl)
except Exception as e2:
logger.error(f"❌ Baidu fallback also failed for {query}: {e2}")
logger.error(f"❌ Search failed for {query}: {e}")
return f"Error occurred during search: {str(e)}"
def search_list(self, query: str, engine: str = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]:
"""
Run a search and return a structured list (List[Dict]).
Each dict contains: title, href (or url), body (or snippet).
Args:
engine: search engine; defaults to the configured default engine (Jina preferred)
enrich: whether to fetch full article content (default True)
"""
# Use the default engine
if engine is None:
engine = self._default_engine
if engine not in self._engines:
logger.error(f"Unsupported engine {engine}")
return []
# Use a distinct hash to separate enriched from non-enriched results
enrich_suffix = ":enriched" if enrich else ""
query_hash = self._generate_hash(query, engine + enrich_suffix, max_results)
effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL
# 1. Try the cache first
cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
if cache and effective_ttl != 0:
try:
cached_data = json.loads(cache['results'])
if isinstance(cached_data, list):
logger.info(f" Found structured search cache for: {query}")
return cached_data
except Exception:
pass
# 1.5 Smart Cache (Fuzzy + LLM)
if effective_ttl != 0:
try:
# 1. Similar cached queries
similar_queries = self.db.find_similar_queries(query, limit=3)
# Filter by TTL
valid_candidates = []
for q in similar_queries:
if q['query'] == query: continue
q_time = datetime.fromisoformat(q['timestamp'])
if effective_ttl and (datetime.now() - q_time).total_seconds() > effective_ttl:
continue
q['type'] = 'cached_search'
valid_candidates.append(q)
# 2. Relevant local news (as search results)
local_news = self.db.search_local_news(query, limit=3)
if local_news:
# Bundle the strictly relevant local news into a single "local_news" candidate source.
valid_candidates.append({
'type': 'local_news',
'query': 'Local Database News',
'items': local_news,
'timestamp': datetime.now().isoformat()
})
if valid_candidates:
logger.info(f"🤔 Found {len(valid_candidates)} smart cache candidates (Queries/News). Asking LLM...")
evaluation = self._evaluate_cache_relevance(query, valid_candidates)
if evaluation and evaluation.get('reuse', False):
idx = evaluation.get('index', -1)
if 0 <= idx < len(valid_candidates):
chosen = valid_candidates[idx]
logger.info(f"🤖 LLM suggested reusing: '{chosen.get('query')}' ({chosen['type']})")
if chosen['type'] == 'cached_search':
# Load the chosen cache
cache = self.db.get_search_cache(chosen['query_hash'])
if cache:
try:
cached_data = json.loads(cache['results'])
if isinstance(cached_data, list):
return cached_data
except Exception:
pass
elif chosen['type'] == 'local_news':
# Convert local news items to search result format
news_results = []
for i, news in enumerate(chosen['items'], 1):
news_results.append({
"id": news.get('id'),
"rank": i,
"title": news.get('title'),
"url": news.get('url'),
"content": news.get('content'),
"original_snippet": news.get('content')[:200] if news.get('content') else '',
"source": f"Local News ({news.get('source')})",
"publish_time": news.get('publish_time'),
"crawl_time": news.get('crawl_time'),
"sentiment_score": news.get('sentiment_score', 0),
"meta_data": {"origin": "local_db"}
})
return news_results
except Exception as e:
logger.warning(f"Smart cache check failed: {e}")
# 2. Execute the search
logger.info(f"📡 Searching {engine} (structured) for: {query}")
try:
tool = self._engines[engine]
results = []
if engine == "jina":
# Jina Search already returns structured data
jina_results = tool.search(query, max_results=max_results)
for r in jina_results:
results.append({
"title": r.get("title", ""),
"url": r.get("url", ""),
"href": r.get("url", ""),
"body": r.get("content", ""),
"content": r.get("content", ""),
"source": "Jina Search"
})
elif engine == "ddg":
results = tool.duckduckgo_search(query, max_results=max_results)
elif engine == "baidu":
results = tool.baidu_search(query, max_results=max_results)
elif engine == "local":
# LocalNewsSearch returns List[Dict]
local_results = tool.search(query, top_n=max_results)
results = []
for r in local_results:
results.append({
"title": r.get("title"),
"url": r.get("url", "local"),
"body": r.get("content", "")[:500],
"source": f"Local ({r.get('source', 'db')})",
"publish_time": r.get("publish_time")
})
# Handle JSON returned as a string (Baidu often returns a JSON string)
if isinstance(results, str) and engine not in ["local", "jina"]:
try:
results = json.loads(results)
except Exception:
pass
# Normalize to a unified format
normalized_results = []
if isinstance(results, list):
for i, r in enumerate(results, 1):
title = r.get('title', '')
url = r.get('href') or r.get('url') or r.get('link', '')
content = r.get('body') or r.get('snippet') or r.get('abstract', '')
if title and url:
normalized_results.append({
"id": self._generate_hash(url + query, "search_item", i),
"rank": i,
"title": title,
"url": url,
"content": content,
"original_snippet": content, # 保留摘要
"source": f"Search ({engine})",
"publish_time": datetime.now().isoformat(), # 暂用当前时间
"crawl_time": datetime.now().isoformat(),
"meta_data": {"query": query, "engine": engine}
})
# Fallback if still string and failed to parse
elif isinstance(results, str) and results:
normalized_results.append({"title": query, "url": "", "content": results, "source": engine})
# 3. Fetch full text & compute sentiment (enrichment)
# Note: with Jina Search the content is already LLM-friendly, so extra fetching can be skipped
skip_content_enrichment = (engine == "jina")
if enrich and normalized_results:
logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...")
extractor = ContentExtractor()
# Lazy load sentiment tool
if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None:
from ..sentiment_tools import SentimentTools
self.sentiment_tool = SentimentTools(self.db)
for item in normalized_results:
if item.get("url"):
try:
# For Jina Search the content is already good enough; skip the extra fetch
if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100:
full_content = item["content"]
else:
# Use Jina Reader to get full content
full_content = extractor.extract_with_jina(item["url"], timeout=60)
if full_content and len(full_content) > 100:
item["content"] = full_content
# Calculate sentiment
# Use title + snippet of content for efficiency
text_to_analyze = f"{item['title']} {full_content[:500]}"
sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze) # Using self.sentiment_tool
score = sent_result.get('score', 0.0)
item["sentiment_score"] = float(score)
logger.info(f" ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})")
else:
# Fallback: Use snippet for sentiment
logger.info(f" ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.")
text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here
sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
score = sent_result.get('score', 0.0)
item["sentiment_score"] = float(score)
except Exception as e:
# Fallback: Use snippet for sentiment on error
logger.warning(f"Failed to enrich {item['url']}: {e}. Using snippet.")
text_to_analyze = f"{item['title']} {item['content']}"
sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
score = sent_result.get('score', 0.0)
item["sentiment_score"] = float(score)
# Cache the result list
if normalized_results:
# Pass list directly, DB manager will handle JSON dump for main cache and populate search_details
# Only cache if NOT from local news reuse (though this logic path is for fresh search)
self.db.save_search_cache(query_hash, query, engine, normalized_results)
return normalized_results
except Exception as e:
# Fallback strategy when the search fails
if engine == "jina":
logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})")
try:
return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich)
except Exception as e2:
logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}")
elif engine == "ddg":
logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})")
try:
return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich)
except Exception as e2:
logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}")
logger.error(f"❌ Structured search failed for {query}: {e}")
return []
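# Illustrative sketch (not part of the original module): structured search with enrichment.
# With enrich=True each returned dict also carries a sentiment_score derived from the fetched text.
#     results = tools.search_list("光伏 政策", max_results=3, enrich=True)
#     for r in results:
#         print(r["title"], r["url"], r.get("sentiment_score", 0.0))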
def _evaluate_cache_relevance(self, current_query: str, candidates: List[Dict]) -> Dict:
"""
Use an LLM to judge whether the cached candidates are sufficient to answer the current query.
"""
try:
# Prepare candidates text
candidates_desc = []
for i, c in enumerate(candidates):
if c['type'] == 'cached_search':
# Build a short preview by peeking at the first cached result's title.
preview = ""
try:
# Attempt to peek first result title from JSON string
# Note: c.get('results') might be a stringified JSON list
res_list = json.loads(c.get('results', '[]'))
if res_list and isinstance(res_list, list) and len(res_list) > 0:
first_item = res_list[0]
if isinstance(first_item, dict) and 'title' in first_item:
preview = f" (Contains: {first_item.get('title', '')[:50]}...)"
except Exception:
pass
candidates_desc.append(f"[{i}] Old Search Query: '{c['query']}' {preview} (Time: {c['timestamp']})")
elif c['type'] == 'local_news':
# List titles of local news
titles = [item['title'] for item in c['items'][:3]]
candidates_desc.append(f"[{i}] Local Database News: {', '.join(titles)}... (Time: {c['timestamp']})")
prompt = f"""
Task: Decide if existing information is sufficient for the new search query.
New Query: "{current_query}"
Available Information Candidates:
{chr(10).join(candidates_desc)}
Instructions:
1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query".
2. If yes, choose the best one.
3. If the query implies needing LATEST real-time info and candidates are old, choose none.
4. Return strictly JSON: {{"reuse": true/false, "index": <candidate_index_int>, "reason": "short explanation"}}
"""
# Initialize the LLM
provider = os.getenv("LLM_PROVIDER", "ust")
model_id = os.getenv("LLM_MODEL", "Qwen")
host = os.getenv("LLM_HOST")
if host:
model = get_model(provider, model_id, host=host)
else:
model = get_model(provider, model_id)
agent = Agent(model=model, markdown=True)
response = agent.run(prompt)
content = response.content
# Parse JSON
json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
if json_match:
return json.loads(json_match.group(1))
elif '{' in content:
# Fallback for cases where LLM doesn't wrap in ```json
return json.loads(content[content.find('{'):content.rfind('}')+1])
return {"reuse": False}
except Exception as e:
logger.warning(f"LLM evaluation failed: {e}")
return {"reuse": False}
def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str:
"""
Search with multiple engines at once and aggregate the results for broader coverage.
Args:
query: search keywords.
engines: list of engines to use. Options: ["ddg", "baidu"].
Defaults to using both ddg and baidu.
max_results: desired number of results per engine.
Returns:
Aggregated search results, grouped by engine.
"""
engines = engines or ["ddg", "baidu"]
aggregated_results = []
for engine in engines:
res = self.search(query, engine=engine, max_results=max_results)
aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}")
return "\n\n".join(aggregated_results)

View File

@@ -0,0 +1,231 @@
import os
from typing import Dict, List, Union, Optional
import json
from loguru import logger
from agno.agent import Agent
from .llm.factory import get_model
from .database_manager import DatabaseManager
# Read the default sentiment-analysis mode from environment variables
DEFAULT_SENTIMENT_MODE = os.getenv("SENTIMENT_MODE", "auto") # auto, bert, llm
class SentimentTools:
"""
Sentiment analysis tool - supports both LLM and BERT modes.
Modes:
- "auto": pick automatically; prefer BERT (fast) and fall back to LLM when unavailable
- "bert": force the BERT model (requires the transformers library)
- "llm": force the LLM (more accurate but slower)
The default mode can be set via the SENTIMENT_MODE environment variable.
"""
def __init__(self, db: DatabaseManager, mode: Optional[str] = None,
model_provider: str = "openai", model_id: str = "gpt-4o"):
"""
Initialize the sentiment analysis tool.
Args:
db: DatabaseManager instance
mode: analysis mode, one of "auto", "bert", "llm". None uses the environment default.
model_provider: LLM provider, e.g. "openai", "ust", "deepseek"
model_id: model identifier
"""
self.db = db
self.mode = mode or DEFAULT_SENTIMENT_MODE
self.llm_model = None
self.bert_pipeline = None
# Initialize LLM
try:
provider = "ust" if os.getenv("UST_KEY_API") else model_provider
m_id = "Qwen" if provider == "ust" else model_id
self.llm_model = get_model(provider, m_id)
except Exception as e:
logger.warning(f"LLM initialization skipped: {e}")
# Initialize BERT if needed
if self.mode in ["bert", "auto"]:
try:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from transformers.utils import logging as transformers_logging
transformers_logging.set_verbosity_error() # reduce noisy logging
bert_model = os.getenv("BERT_SENTIMENT_MODEL", "uer/roberta-base-finetuned-chinanews-chinese")
# Prefer the locally cached model
try:
tokenizer = AutoTokenizer.from_pretrained(bert_model, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(bert_model, local_files_only=True)
self.bert_pipeline = pipeline(
"sentiment-analysis",
model=model,
tokenizer=tokenizer,
device=-1
)
logger.info(f"✅ BERT pipeline loaded from local cache: {bert_model}")
except (OSError, ValueError, ImportError):
# Not cached locally; download from the network
logger.info(f"📡 Downloading BERT model: {bert_model}...")
tokenizer = AutoTokenizer.from_pretrained(bert_model)
model = AutoModelForSequenceClassification.from_pretrained(bert_model)
self.bert_pipeline = pipeline(
"sentiment-analysis",
model=model,
tokenizer=tokenizer,
device=-1
)
logger.info(f"✅ BERT Sentiment pipeline ({bert_model}) initialized.")
except ImportError:
logger.warning("Transformers library not installed. BERT sentiment analysis disabled.")
except Exception as e:
if self.mode == "bert":
logger.error(f"BERT mode requested but failed: {e}")
else:
logger.warning(f"BERT unavailable, using LLM only. Error: {e}")
self.bert_pipeline = None
def analyze_sentiment(self, text: str) -> Dict[str, Union[float, str]]:
"""
Analyze the sentiment polarity of a text. The method is chosen automatically from the mode set at initialization.
Args:
text: text to analyze, e.g. a news headline or summary.
Returns:
A dict with the following fields:
- score: sentiment score from -1.0 (extremely negative) to 1.0 (extremely positive); 0.0 is neutral
- label: sentiment label, "positive"/"negative"/"neutral"
- reason: analysis rationale (a detailed rationale is provided only in LLM mode)
"""
if self.mode == "bert" and self.bert_pipeline:
results = self.analyze_sentiment_bert([text])
return results[0] if results else {"score": 0.0, "label": "error"}
elif self.mode == "llm" or (self.mode == "auto" and not self.bert_pipeline):
return self.analyze_sentiment_llm(text)
else:
# auto mode with BERT available
results = self.analyze_sentiment_bert([text])
return results[0] if results else {"score": 0.0, "label": "error"}
def analyze_sentiment_llm(self, text: str) -> Dict[str, Union[float, str]]:
"""
Run an in-depth sentiment analysis with an LLM, including a detailed rationale.
Args:
text: text to analyze; only the first 1000 characters are processed.
Returns:
A dict containing score, label and reason.
"""
if not self.llm_model:
return {"score": 0.0, "label": "neutral", "error": "LLM not initialized"}
analyzer = Agent(model=self.llm_model, markdown=True)
prompt = f"""请分析以下金融/新闻文本的情绪极性。
返回严格的 JSON 格式:
{{"score": <float: -1.0到1.0>, "label": "<positive/negative/neutral>", "reason": "<简短理由>"}}
文本: {text[:1000]}"""
try:
response = analyzer.run(prompt)
content = response.content
if "```json" in content:
content = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
content = content.split("```")[1].split("```")[0].strip()
return json.loads(content)
except Exception as e:
logger.error(f"LLM sentiment failed: {e}")
return {"score": 0.0, "label": "error", "reason": str(e)}
def analyze_sentiment_bert(self, texts: List[str]) -> List[Dict]:
"""
Run fast batch sentiment analysis with BERT.
Args:
texts: list of texts to analyze.
Returns:
A list of results with the same length as the input list.
"""
if not self.bert_pipeline:
return [{"score": 0.0, "label": "error", "reason": "BERT not available"}] * len(texts)
try:
results = self.bert_pipeline(texts, truncation=True, max_length=512)
processed = []
for r in results:
label = r['label'].lower()
score = r['score']
# Normalize label formats across different models
if 'negative' in label or 'neg' in label:
score = -score
elif 'neutral' in label or 'neu' in label:
score = 0.0
processed.append({
"score": float(round(score, 3)),
"label": "positive" if score > 0.1 else ("negative" if score < -0.1 else "neutral"),
"reason": "BERT automated analysis"
})
return processed
except Exception as e:
logger.error(f"BERT analysis failed: {e}")
return [{"score": 0.0, "label": "error", "reason": str(e)}] * len(texts)
def batch_update_news_sentiment(self, source: Optional[str] = None, limit: int = 50, use_bert: Optional[bool] = None):
"""
Batch-update sentiment scores for news items in the database.
Args:
source: restrict to a specific news source, e.g. "wallstreetcn". None processes all sources.
limit: maximum number of news items to process.
use_bert: whether to use BERT. None decides automatically based on the initialization mode.
Returns:
Number of news items successfully updated.
"""
news_items = self.db.get_daily_news(source=source, limit=limit)
to_analyze = [item for item in news_items if not item.get('sentiment_score')]
if not to_analyze:
return 0
# Decide which method to use
should_use_bert = use_bert if use_bert is not None else (self.bert_pipeline is not None and self.mode != "llm")
updated_count = 0
cursor = self.db.conn.cursor()
if should_use_bert and self.bert_pipeline:
logger.info(f"🚀 Using BERT for batch analysis of {len(to_analyze)} items...")
titles = [item['title'] for item in to_analyze]
results = self.analyze_sentiment_bert(titles)
for item, analysis in zip(to_analyze, results):
cursor.execute("""
UPDATE daily_news
SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
WHERE id = ?
""", (analysis['score'], analysis['reason'], item['id']))
updated_count += 1
else:
logger.info(f"🚶 Using LLM for analysis of {len(to_analyze)} items...")
for item in to_analyze:
analysis = self.analyze_sentiment_llm(item['title'])
cursor.execute("""
UPDATE daily_news
SET sentiment_score = ?, meta_data = json_set(COALESCE(meta_data, '{}'), '$.sentiment_reason', ?)
WHERE id = ?
""", (analysis.get('score', 0.0), analysis.get('reason', ''), item['id']))
updated_count += 1
self.db.conn.commit()
return updated_count
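# Illustrative sketch (not part of the original module): refreshing sentiment for recent headlines.
# `db` is assumed to be an already-initialized DatabaseManager; BERT is preferred when available.
#     st = SentimentTools(db, mode="auto")
#     updated = st.batch_update_news_sentiment(source="wallstreetcn", limit=50)
#     logger.info(f"Updated sentiment for {updated} news items")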

View File

@@ -0,0 +1,257 @@
from datetime import datetime, timedelta
from typing import List, Dict, Optional
import akshare as ak
import pandas as pd
import re
import sqlite3
from requests.exceptions import RequestException
from loguru import logger
from .database_manager import DatabaseManager
import os
from contextlib import contextmanager
@contextmanager
def temporary_no_proxy():
"""Context manager to temporarily unset proxy environment variables."""
proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']}
for k in proxies:
if k in os.environ:
del os.environ[k]
try:
yield
finally:
for k, v in proxies.items():
if v is not None:
os.environ[k] = v
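# Illustrative usage sketch (not part of the original module): akshare calls can fail behind a
# corporate proxy, so the fetch is retried inside temporary_no_proxy(), which unsets the proxy
# variables for the duration of the call and restores them afterwards.
def _no_proxy_fetch_sketch():
    with temporary_no_proxy():
        return ak.stock_zh_a_spot_em()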
class StockTools:
"""金融分析股票工具 - 结合高性能数据库缓存与增量更新"""
def __init__(self, db: DatabaseManager, auto_update: bool = True):
"""
Initialize the stock tools.
Args:
db: database manager
auto_update: whether to auto-refresh when the stock list is empty, default True
"""
self.db = db
if auto_update:
self._check_and_update_stock_list()
def _check_and_update_stock_list(self, force: bool = False):
"""检查并更新股票列表。仅在列表为空或 force=True 时从网络拉取。"""
# 直接查询表中记录数
cursor = self.db.conn.cursor()
cursor.execute("SELECT COUNT(*) FROM stock_list")
count = cursor.fetchone()[0]
if count > 0 and not force:
logger.info(f" Stock list already cached ({count} stocks)")
return
logger.info("📡 Updating A-share and HK-share stock list from akshare...")
def fetch_data():
# A-share
df_a = ak.stock_zh_a_spot_em()
df_a = df_a[['代码', '名称']].copy()
df_a.columns = ['code', 'name']
# HK-share
df_hk = ak.stock_hk_spot_em()
df_hk = df_hk[['代码', '名称']].copy()
df_hk.columns = ['code', 'name']
# Combine
return pd.concat([df_a, df_hk], ignore_index=True)
try:
try:
df_combined = fetch_data()
except (RequestException, Exception) as e:
if "Proxy" in str(e) or "proxy" in str(e):
logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...")
with temporary_no_proxy():
df_combined = fetch_data()
else:
raise e
self.db.save_stock_list(df_combined)
logger.info(f"✅ Cached {len(df_combined)} stocks (A-share + HK) to database.")
except Exception as e:
logger.error(f"❌ Failed to sync stock list: {e}")
def search_ticker(self, query: str, limit: int = 5) -> List[Dict]:
"""
Fuzzy-search stock codes or names (A-share and HK); common abbreviations and aliases are supported.
"""
# Strip exchange suffixes (e.g. CATL.SZ -> CATL, 000001.SZ -> 000001)
clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE)
# Common abbreviation/alias mapping
aliases = {
"CATL": "宁德时代",
"BYD": "比亚迪",
"TSLA": "特斯拉",
"Moutai": "贵州茅台",
"Tencent": "腾讯",
"Alibaba": "阿里巴巴",
"Meituan": "美团",
}
search_query = aliases.get(clean_query.upper(), clean_query)
# Robustness: if a numeric ticker code is embedded in the query (e.g. "300364 中文在线"), extract it
if not search_query.isdigit():
# Extract explicit 5-6 digit codes
match = re.search(r'\b(\d{5,6})\b', clean_query)
if match:
search_query = match.group(1)
return self.db.search_stock(search_query, limit)
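# Worked example of the resolution above: search_ticker("CATL.SZ") strips the ".SZ" suffix,
# maps "CATL" to "宁德时代" via the alias table, and matches that name against the cached stock_list.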
def get_stock_price(
self,
ticker: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
force_sync: bool = False,
) -> pd.DataFrame:
"""
Fetch historical price data for a stock. Reads from the local cache first and backfills from the network when data is missing.
Args:
ticker: stock code, e.g. "600519" (Kweichow Moutai) or "000001" (Ping An Bank).
start_date: start date, "YYYY-MM-DD" format. Defaults to 90 days ago.
end_date: end date, "YYYY-MM-DD" format. Defaults to today.
force_sync: force a refresh from the network even if cached data looks current.
Returns:
DataFrame with date, open, close, high, low, volume and change_pct columns.
"""
now = datetime.now()
if not end_date:
end_date = now.strftime('%Y-%m-%d')
if not start_date:
start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d')
df_db = self.db.get_stock_prices(ticker, start_date, end_date)
need_update = False
if df_db.empty:
need_update = True
else:
db_latest = pd.to_datetime(df_db['date'].max())
req_latest = pd.to_datetime(end_date)
if (req_latest - db_latest).days > 2:
need_update = True
if force_sync:
need_update = True
if need_update:
logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...")
# Sanitize the ticker to digits only (Akshare A-share endpoints expect a numeric code)
clean_ticker = "".join(filter(str.isdigit, ticker))
if not clean_ticker:
# Non A/H numeric tickers are not supported by the current data source.
logger.warning(f"⚠️ Unsupported ticker format (A/H only): {ticker}")
return df_db
try:
s_fmt = start_date.replace("-", "")
e_fmt = end_date.replace("-", "")
df_remote = None
def fetch_data():
if len(clean_ticker) == 5:
# HK Stock
return ak.stock_hk_hist(
symbol=clean_ticker, period="daily",
start_date=s_fmt, end_date=e_fmt,
adjust="qfq"
)
else:
# A-share Stock
return ak.stock_zh_a_hist(
symbol=clean_ticker, period="daily",
start_date=s_fmt, end_date=e_fmt,
adjust="qfq"
)
try:
df_remote = fetch_data()
except (RequestException, Exception) as e:
if "Proxy" in str(e) or "proxy" in str(e):
logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...")
with temporary_no_proxy():
df_remote = fetch_data()
else:
raise e
if df_remote is not None and not df_remote.empty:
df_remote = df_remote.rename(columns={
'日期': 'date', '开盘': 'open', '收盘': 'close',
'最高': 'high', '最低': 'low', '成交量': 'volume',
'涨跌幅': 'change_pct'
})
# Ensure dates are formatted consistently
df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d')
# Only persist when meaningful data was actually fetched
self.db.save_stock_prices(clean_ticker, df_remote) # save under the sanitized clean_ticker
# Re-query the database for the return value to guarantee consistency
return self.db.get_stock_prices(clean_ticker, start_date, end_date)
else:
logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}")
except KeyError as e:
# Akshare sometimes raises KeyError when a stock has no data
logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}")
except (RequestException, ConnectionError) as e:
logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}")
except sqlite3.Error as e:
logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}")
except Exception as e:
logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}")
return df_db
def get_stock_analysis(ticker: str, db: DatabaseManager) -> str:
"""
Generate a summary analysis report for the given stock.
Args:
ticker: stock code
db: DatabaseManager instance
Returns:
Markdown-formatted report covering price movement and key metrics.
"""
tools = StockTools(db)
df = tools.get_stock_price(ticker)
if df.empty:
return f"❌ 未能获取 {ticker} 的股价数据。"
latest = df.iloc[-1]
change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100
report = [
f"## 📊 {ticker} 分析报告",
f"- **查询时段**: {df.iloc[0]['date']} -> {latest['date']}",
f"- **当前价**: ¥{latest['close']:.2f}",
f"- **时段涨跌**: {change:+.2f}%",
f"- **最高/最低**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}",
"\n### 最近交易概览",
"```",
df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False),
"```"
]
return "\n".join(report)