Import 9 alphaear finance skills

- alphaear-deepear-lite: DeepEar Lite API integration
- alphaear-logic-visualizer: Draw.io XML finance diagrams
- alphaear-news: Real-time finance news (10+ sources)
- alphaear-predictor: Kronos time-series forecasting
- alphaear-reporter: Professional financial reports
- alphaear-search: Web search + local RAG
- alphaear-sentiment: FinBERT/LLM sentiment analysis
- alphaear-signal-tracker: Signal evolution tracking
- alphaear-stock: A-Share/HK/US stock data

Updates:
- All scripts updated to use universal .env path
- Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example
- Updated load_dotenv() to use ~/.config/opencode/.env
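
The shared loading pattern now looks roughly like this (a minimal sketch; the JINA_API_KEY lookup is just an illustration using one of the keys listed above):

import os
from dotenv import load_dotenv

load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
jina_api_key = os.getenv("JINA_API_KEY")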
Kunthawat Greethong
2026-03-27 10:11:37 +07:00
parent 7edf5bc4d0
commit 58f9380ec4
149 changed files with 26867 additions and 0 deletions

View File

@@ -0,0 +1,137 @@
import os
import sys
import torch
import pandas as pd
import numpy as np
import glob
from loguru import logger
from datetime import datetime, timedelta
# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
sys.path.insert(0, SRC_DIR)
from ..kronos.auto_synthesis_training import AutoSynthesisTrainer
from ..kronos.model import KronosPredictor
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
class NewsModelEvaluator:
def __init__(self, model_path=None):
self.trainer = AutoSynthesisTrainer()
self.device = self.trainer.device
if model_path is None:
# Try to find the latest model in exports/models
model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt"))
if not model_files:
logger.warning("⚠️ No trained models found in exports/models/. Using base model (zero-init proj).")
else:
model_path = max(model_files, key=os.path.getctime)
if model_path:
self.load_weights(model_path)
def load_weights(self, path):
logger.info(f"🔄 Loading model weights from {path}...")
checkpoint = torch.load(path, map_location=self.device)
self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict'])
logger.success("✅ News projection layer loaded.")
def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5):
# 1. Fetch Tickers
res = self.trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row['code'] for row in res]
test_tickers = all_tickers[start_idx:end_idx]
if not test_tickers:
logger.error(f"No tickers found in range {start_idx}-{end_idx}")
return
logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...")
# 2. Discover Shocks
shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len)
# 3. Associate News & Predict
self.trainer.model.eval()
predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device)
save_dir = os.path.join(SRC_DIR, "exports/evaluation_results")
os.makedirs(save_dir, exist_ok=True)
count = 0
for shock in shocks:
summary = self.trainer.find_reason_and_verify(shock)
if not summary:
continue
logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}")
# Embedding news
news_emb = self.trainer.embedder.encode(summary)
# Prediction
h = shock['history']
t = shock['target']
actuals = t['close'].values[:pred_len]
x_ts = pd.to_datetime(h['date'])
future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B')
y_ts = pd.Series(future_dates)
# A. Base Prediction (No news)
p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False)
# B. News-Aware Prediction
p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False)
# Calculate Improvement
b_preds = p_base['close'].values[:len(actuals)]
n_preds = p_news['close'].values[:len(actuals)]
b_mae = np.mean(np.abs(b_preds - actuals))
n_mae = np.mean(np.abs(n_preds - actuals))
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
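# e.g. b_mae = 0.50, n_mae = 0.40 -> improvement = (0.50 - 0.40) / 0.50 * 100 ≈ 20.0 (%)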
# C. Visualize
try:
def to_kp_list(preds_df):
points = []
for idx, row in preds_df.iterrows():
points.append(KLinePoint(
date=str(idx)[:10], open=row['open'], high=row['high'],
low=row['low'], close=row['close'], volume=row.get('volume', 0)
))
return points
forecast_obj = ForecastResult(
ticker=shock['ticker'],
base_forecast=to_kp_list(p_base),
adjusted_forecast=to_kp_list(p_news),
rationale=summary
)
chart = VisualizerTools.generate_stock_chart(
df=h, ticker=shock['ticker'],
title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%",
forecast=forecast_obj,
ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']]
)
safe_date = shock['date'].replace("-", "")
filename = f"test_{shock['ticker']}_{safe_date}.html"
VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename))
logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}")
count += 1
except Exception as e:
logger.error(f"Visualization failed: {e}")
logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}")
if __name__ == "__main__":
# If you have a specific model, pass the path here. Otherwise it picks the latest.
evaluator = NewsModelEvaluator()
evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1)

View File

@@ -0,0 +1,196 @@
# Ref: https://github.com/shiyu-coder/Kronos
from model import Kronos, KronosTokenizer, KronosPredictor
import pandas as pd
import sqlite3
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pandas.tseries.offsets import BusinessDay
import numpy as np
def get_device():
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
return device
def load_predictor():
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
device = get_device()
tokenizer = tokenizer.to(device)
model = model.to(device)
return KronosPredictor(model, tokenizer, device=device, max_context=512)
def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"):
with sqlite3.connect(db_path) as conn:
df = pd.read_sql_query("SELECT * FROM stock_prices WHERE ticker = ?", conn, params=(ticker,))  # parameterized query avoids SQL injection
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values('date').reset_index(drop=True)
return df
def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False):
"""
Plot candlesticks and volume bars on the given axes.
"""
# X axis mapping to integers for consistent spacing
x = np.arange(len(dates))
# K-line data
opens = df['open'].values
closes = df['close'].values
highs = df['high'].values
lows = df['low'].values
volumes = df['volume'].values
# Width of the candlestick
width = 0.6
for i in range(len(x)):
color = color_up if closes[i] >= opens[i] else color_down
linestyle = '--' if is_prediction else '-'
# Wick
ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle)
# Body
rect_bottom = min(opens[i], closes[i])
rect_height = abs(opens[i] - closes[i])
if rect_height == 0: rect_height = 0.001 # hairline body so doji candles stay visible
ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height,
edgecolor=color, facecolor=color if not is_prediction else 'none',
alpha=alpha, linewidth=1, linestyle=linestyle))
# Volume
ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width)
def render_comparison_chart(history_df, actual_df, pred_df, title):
"""
Render a composite chart: historical candles + ground-truth candles + predicted candles.
"""
# Combine all dates for X axis
all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique()
all_dates = sorted(all_dates)
date_to_idx = {date: i for i, date in enumerate(all_dates)}
fig = plt.figure(figsize=(14, 8), facecolor='white')
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1)
ax_main = fig.add_subplot(gs[0])
ax_vol = fig.add_subplot(gs[1], sharex=ax_main)
# 1. Plot History
hist_indices = [date_to_idx[d] for d in history_df['date']]
# We use a custom x for plotting to ensure continuity
plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8)
offset = len(history_df)
# 2. Plot Actual if exists
if actual_df is not None:
# Shift indices
actual_x = np.arange(len(actual_df)) + offset
# Plotting manually to handle offset
for i in range(len(actual_df)):
idx = actual_x[i]
row = actual_df.iloc[i]
color = '#ef4444' if row['close'] >= row['open'] else '#22c55e'
ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9)
ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']),
edgecolor=color, facecolor=color, alpha=0.9))
ax_vol.bar(idx, row['volume'], color=color, alpha=0.4)
# 3. Plot Prediction
pred_x = np.arange(len(pred_df)) + offset
for i in range(len(pred_df)):
idx = pred_x[i]
row = pred_df.iloc[i]
color = '#ff8c00' # Orange for prediction to distinguish
ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--')
ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']),
edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--'))
# Plot secondary prediction line for close
if i == 0:
# Connect to history
ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6)
else:
ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6)
# Styling
ax_main.set_title(title, fontsize=14, fontweight='bold')
ax_main.grid(True, linestyle=':', alpha=0.6)
ax_vol.grid(True, linestyle=':', alpha=0.6)
ax_vol.set_ylabel('Volume')
ax_main.set_ylabel('Price')
# Set X ticks
step = max(1, len(all_dates) // 10)
ax_vol.set_xticks(np.arange(0, len(all_dates), step))
ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45)
plt.tight_layout()
plt.show()
plt.close()
def run_backtest(df, predictor, lookback, pred_len, start_index=0):
total_len = len(df)
history_start = start_index
history_end = start_index + lookback
pred_start = history_end
available_pred_len = total_len - pred_start
if available_pred_len <= 0: return
actual_pred_len = min(pred_len, available_pred_len)
pred_end = pred_start + actual_pred_len
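# Example: start_index=0, lookback=20, pred_len=10 on a 40-row df ->
# history = rows 0..19, prediction window = rows 20..29 (compared against actuals).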
x_df = df.iloc[history_start : history_end].copy()
y_true_df = df.iloc[pred_start : pred_end].copy()
y_timestamp = y_true_df['date']
print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}")
pred_df = predictor.predict(
df=x_df[['open', 'high', 'low', 'close', 'volume']],
x_timestamp=x_df['date'],
y_timestamp=y_timestamp,
pred_len=actual_pred_len,
T=1.0, top_p=0.9, sample_count=1
)
render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison")
def run_forecast(df, predictor, lookback, pred_len):
if len(df) < lookback: return
x_df = df.iloc[-lookback:].copy()
last_date = x_df['date'].iloc[-1]
future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B')
future_dates = pd.Series(future_dates)
print(f"Forecasting: Starting from {future_dates.iloc[0].date()}")
pred_df = predictor.predict(
df=x_df[['open', 'high', 'low', 'close', 'volume']],
x_timestamp=x_df['date'],
y_timestamp=future_dates,
pred_len=pred_len,
T=1.0, top_p=0.9, sample_count=1
)
render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line")
if __name__ == "__main__":
LOOKBACK = 20
PRED_LEN = 10
TICKER = '002111'
pred_model = load_predictor()
stock_data = load_data(TICKER)
total_rows = len(stock_data)
backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend
print("\n--- Running Backtest ---")
run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start)
print("\n--- Running Forecast ---")
run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN)

View File

@@ -0,0 +1,16 @@
from .kronos import KronosTokenizer, Kronos, KronosPredictor
model_dict = {
'kronos_tokenizer': KronosTokenizer,
'kronos': Kronos,
'kronos_predictor': KronosPredictor
}
def get_model_class(model_name):
if model_name in model_dict:
return model_dict[model_name]
else:
print(f"Model {model_name} not found in model_dict")
raise NotImplementedError
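# Usage sketch: get_model_class('kronos')  # -> the Kronos class registered above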

View File

@@ -0,0 +1,676 @@
import numpy as np
import pandas as pd
import torch
from huggingface_hub import PyTorchModelHubMixin
import sys
from tqdm import trange
sys.path.append("../")
from model.module import *
class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
"""
KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
This tokenizer utilizes a combination of encoder and decoder Transformer blocks
along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.
Args:
d_in (int): Input dimension.
d_model (int): Model dimension.
n_heads (int): Number of attention heads.
ff_dim (int): Feed-forward dimension.
n_enc_layers (int): Number of encoder layers.
n_dec_layers (int): Number of decoder layers.
ffn_dropout_p (float): Dropout probability for feed-forward networks.
attn_dropout_p (float): Dropout probability for attention mechanisms.
resid_dropout_p (float): Dropout probability for residual connections.
s1_bits (int): Number of bits for the pre token in BSQuantizer.
s2_bits (int): Number of bits for the post token in BSQuantizer.
beta (float): Beta parameter for BSQuantizer.
gamma0 (float): Gamma0 parameter for BSQuantizer.
gamma (float): Gamma parameter for BSQuantizer.
zeta (float): Zeta parameter for BSQuantizer.
group_size (int): Group size parameter for BSQuantizer.
"""
def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
super().__init__()
self.d_in = d_in
self.d_model = d_model
self.n_heads = n_heads
self.ff_dim = ff_dim
self.enc_layers = n_enc_layers
self.dec_layers = n_dec_layers
self.ffn_dropout_p = ffn_dropout_p
self.attn_dropout_p = attn_dropout_p
self.resid_dropout_p = resid_dropout_p
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization
self.embed = nn.Linear(self.d_in, self.d_model)
self.head = nn.Linear(self.d_model, self.d_in)
# Encoder Transformer Blocks
self.encoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.enc_layers - 1)
])
# Decoder Transformer Blocks
self.decoder = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.dec_layers - 1)
])
self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization
self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits)
self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook)
self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module
def forward(self, x):
"""
Forward pass of the KronosTokenizer.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
Returns:
tuple: A tuple containing:
- tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively,
both of shape (batch_size, seq_len, d_in).
- torch.Tensor: bsq_loss - Loss from the BSQuantizer.
- torch.Tensor: quantized - Quantized representation from BSQuantizer.
- torch.Tensor: z_indices - Indices from the BSQuantizer.
"""
z = self.embed(x)
for layer in self.encoder:
z = layer(z)
z = self.quant_embed(z) # (B, T, codebook)
bsq_loss, quantized, z_indices = self.tokenizer(z)
quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
z_pre = self.post_quant_embed_pre(quantized_pre)
z = self.post_quant_embed(quantized)
# Decoder layers (for pre part - s1 bits)
for layer in self.decoder:
z_pre = layer(z_pre)
z_pre = self.head(z_pre)
# Decoder layers (for full codebook)
for layer in self.decoder:
z = layer(z)
z = self.head(z)
return (z_pre, z), bsq_loss, quantized, z_indices
def indices_to_bits(self, x, half=False):
"""
Converts indices to bit representations and scales them.
Args:
x (torch.Tensor): Indices tensor.
half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.
Returns:
torch.Tensor: Bit representation tensor.
"""
if half:
x1 = x[0] # Assuming x is a tuple of indices if half is True
x2 = x[1]
mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction
x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half
x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half
x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations
else:
mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction
x = (x.unsqueeze(-1) & mask) != 0 # Extract bits
x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1)
q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor
x = x * q_scale
return x
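# Worked example: codebook_dim=4, index 5 -> bits 0101 -> bipolar (+1, -1, +1, -1) (LSB first),
# then scaled by q_scale = 1/sqrt(4) = 0.5.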
def encode(self, x, half=False):
"""
Encodes the input data into quantized indices.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.
Returns:
torch.Tensor: Quantized indices from BSQuantizer.
"""
z = self.embed(x)
for layer in self.encoder:
z = layer(z)
z = self.quant_embed(z)
bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False)
return z_indices
def decode(self, x, half=False):
"""
Decodes quantized indices back to the input data space.
Args:
x (torch.Tensor): Quantized indices tensor.
half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.
Returns:
torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
"""
quantized = self.indices_to_bits(x, half)
z = self.post_quant_embed(quantized)
for layer in self.decoder:
z = layer(z)
z = self.head(z)
return z
class Kronos(nn.Module, PyTorchModelHubMixin):
"""
Kronos Model.
Args:
s1_bits (int): Number of bits for pre tokens.
s2_bits (int): Number of bits for post tokens.
n_layers (int): Number of Transformer blocks.
d_model (int): Dimension of the model's embeddings and hidden states.
n_heads (int): Number of attention heads in the MultiheadAttention layers.
ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
ffn_dropout_p (float): Dropout probability for the feedforward network.
attn_dropout_p (float): Dropout probability for the attention layers.
resid_dropout_p (float): Dropout probability for residual connections.
token_dropout_p (float): Dropout probability for token embeddings.
learn_te (bool): Whether to use learnable temporal embeddings.
"""
def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None):
super().__init__()
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.n_layers = n_layers
self.d_model = d_model
self.n_heads = n_heads
self.learn_te = learn_te
self.ff_dim = ff_dim
self.ffn_dropout_p = ffn_dropout_p
self.attn_dropout_p = attn_dropout_p
self.resid_dropout_p = resid_dropout_p
self.token_dropout_p = token_dropout_p
self.news_dim = news_dim
self.s1_vocab_size = 2 ** self.s1_bits
self.token_drop = nn.Dropout(self.token_dropout_p)
self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
self.transformer = nn.ModuleList([
TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
for _ in range(self.n_layers)
])
self.norm = RMSNorm(self.d_model)
self.dep_layer = DependencyAwareLayer(self.d_model)
self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)
if self.news_dim is not None:
self.news_proj = nn.Linear(self.news_dim, self.d_model)
else:
self.news_proj = None
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, RMSNorm):
nn.init.ones_(module.weight)
def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None):
"""
Args:
s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
"""
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
time_embedding = self.time_emb(stamp)
x = x + time_embedding
x = self.token_drop(x)
for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
x = self.norm(x)
if news_emb is not None and self.news_proj is not None:
news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
x = x + news_bias
s1_logits = self.head(x)
if use_teacher_forcing:
sibling_embed = self.embedding.emb_s1(s1_targets)
else:
s1_probs = F.softmax(s1_logits.detach(), dim=-1)
sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
sibling_embed = self.embedding.emb_s1(sample_s1_ids)
x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
s2_logits = self.head.cond_forward(x2)
return s1_logits, s2_logits
def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None):
"""
Decodes only the s1 tokens.
This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
and the context representation from the Transformer, which can be used for subsequent s2 decoding.
Args:
s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
"""
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
time_embedding = self.time_emb(stamp)
x = x + time_embedding
x = self.token_drop(x)
for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
x = self.norm(x)
if news_emb is not None and self.news_proj is not None:
news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
x = x + news_bias
s1_logits = self.head(x)
return s1_logits, x
def decode_s2(self, context, s1_ids, padding_mask=None):
"""
Decodes the s2 tokens, conditioned on the context and s1 tokens.
This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
Args:
context (torch.Tensor): Context representation from the transformer (output of decode_s1).
Shape: [batch_size, seq_len, d_model]
s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
Returns:
torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
"""
sibling_embed = self.embedding.emb_s1(s1_ids)
x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
return self.head.cond_forward(x2)
def top_k_top_p_filtering(
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
return logits  # note: when top_k > 0, top-k filtering takes precedence and top-p is skipped
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (the first token above it is kept via the shift below)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
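# Worked example: top_p=0.9 with sorted probs (0.5, 0.3, 0.15, 0.05) -> cumulative (0.5, 0.8, 0.95, 1.0);
# after the right-shift, only the 0.05 token is masked, so tokens with probs 0.5, 0.3, 0.15 survive.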
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
logits = logits / temperature
if top_k is not None or top_p is not None:
if (top_k or 0) > 0 or (top_p if top_p is not None else 1.0) < 1.0:  # treat None as disabled
logits = top_k_top_p_filtering(logits, top_k=top_k or 0, top_p=top_p if top_p is not None else 1.0)
probs = F.softmax(logits, dim=-1)
if not sample_logits:
_, x = torch.topk(probs, k=1, dim=-1)  # greedy: take the highest-probability token
else:
x = torch.multinomial(probs, num_samples=1)
return x
def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None):
with torch.no_grad():
x = torch.clip(x, -clip, clip)
device = x.device
x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
x_token = tokenizer.encode(x, half=True)
initial_seq_len = x.size(1)
batch_size = x_token[0].size(0)
total_seq_len = initial_seq_len + pred_len
full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
generated_pre = x_token[0].new_empty(batch_size, pred_len)
generated_post = x_token[1].new_empty(batch_size, pred_len)
pre_buffer = x_token[0].new_zeros(batch_size, max_context)
post_buffer = x_token[1].new_zeros(batch_size, max_context)
buffer_len = min(initial_seq_len, max_context)
if buffer_len > 0:
start_idx = max(0, initial_seq_len - max_context)
pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
if verbose:
ran = trange
else:
ran = range
for i in ran(pred_len):
current_seq_len = initial_seq_len + i
window_len = min(current_seq_len, max_context)
if current_seq_len <= max_context:
input_tokens = [
pre_buffer[:, :window_len],
post_buffer[:, :window_len]
]
else:
input_tokens = [pre_buffer, post_buffer]
context_end = current_seq_len
context_start = max(0, context_end - max_context)
current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb)
s1_logits = s1_logits[:, -1, :]
sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
s2_logits = model.decode_s2(context, sample_pre)
s2_logits = s2_logits[:, -1, :]
sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
generated_pre[:, i] = sample_pre.squeeze(-1)
generated_post[:, i] = sample_post.squeeze(-1)
if current_seq_len < max_context:
pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
else:
pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
pre_buffer[:, -1] = sample_pre.squeeze(-1)
post_buffer[:, -1] = sample_post.squeeze(-1)
full_pre = torch.cat([x_token[0], generated_pre], dim=1)
full_post = torch.cat([x_token[1], generated_post], dim=1)
context_start = max(0, total_seq_len - max_context)
input_tokens = [
full_pre[:, context_start:total_seq_len].contiguous(),
full_post[:, context_start:total_seq_len].contiguous()
]
z = tokenizer.decode(input_tokens, half=True)
z = z.reshape(-1, sample_count, z.size(1), z.size(2))
preds = z.cpu().numpy()
preds = np.mean(preds, axis=1)
return preds
def calc_time_stamps(x_timestamp):
time_df = pd.DataFrame()
time_df['minute'] = x_timestamp.dt.minute
time_df['hour'] = x_timestamp.dt.hour
time_df['weekday'] = x_timestamp.dt.weekday
time_df['day'] = x_timestamp.dt.day
time_df['month'] = x_timestamp.dt.month
return time_df
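# e.g. Timestamp('2024-01-05 00:00') -> minute=0, hour=0, weekday=4 (Friday), day=5, month=1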
class KronosPredictor:
def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
self.tokenizer = tokenizer
self.model = model
self.max_context = max_context
self.clip = clip
self.price_cols = ['open', 'high', 'low', 'close']
self.vol_col = 'volume'
self.amt_vol = 'amount'
self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
self.device = device
self.tokenizer = self.tokenizer.to(self.device)
self.model = self.model.to(self.device)
def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None):
x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)
preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb)
preds = preds[:, -pred_len:, :]
return preds
def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None):
if not isinstance(df, pd.DataFrame):
raise ValueError("Input must be a pandas DataFrame.")
if not all(col in df.columns for col in self.price_cols):
raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
df = df.copy()
if self.vol_col not in df.columns:
df[self.vol_col] = 0.0 # Fill missing volume with zeros
df[self.amt_vol] = 0.0 # Fill missing amount with zeros
if self.amt_vol not in df.columns and self.vol_col in df.columns:
df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
raise ValueError("Input DataFrame contains NaN values in price or volume columns.")
x_time_df = calc_time_stamps(x_timestamp)
y_time_df = calc_time_stamps(y_timestamp)
x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
x_stamp = x_time_df.values.astype(np.float32)
y_stamp = y_time_df.values.astype(np.float32)
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.clip, self.clip)
x = x[np.newaxis, :]
x_stamp = x_stamp[np.newaxis, :]
y_stamp = y_stamp[np.newaxis, :]
if news_emb is not None:
news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device)
# Ensure batch dimension for news_emb if only one sample
if news_emb_tensor.ndim == 1:
news_emb_tensor = news_emb_tensor.unsqueeze(0)
else:
news_emb_tensor = None
preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor)
preds = preds.squeeze(0)
preds = preds * (x_std + 1e-5) + x_mean
pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
return pred_df
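# Minimal usage sketch (dates and variable names are illustrative):
#   future = pd.Series(pd.date_range('2024-01-08', periods=10, freq='B'))
#   pred_df = predictor.predict(df, x_timestamp=df['date'], y_timestamp=future, pred_len=10)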
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
"""
Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len).
Args:
df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns.
x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame.
y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len.
pred_len (int): Number of prediction steps.
T (float): Sampling temperature.
top_k (int): Top-k filtering threshold.
top_p (float): Top-p (nucleus sampling) threshold.
sample_count (int): Number of parallel samples per series, automatically averaged internally.
verbose (bool): Whether to display autoregressive progress.
Returns:
List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains
`open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`.
"""
# Basic validation
if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)):
raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.")
if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.")
num_series = len(df_list)
x_list = []
x_stamp_list = []
y_stamp_list = []
means = []
stds = []
seq_lens = []
y_lens = []
for i in range(num_series):
df = df_list[i]
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
if not all(col in df.columns for col in self.price_cols):
raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.")
df = df.copy()
if self.vol_col not in df.columns:
df[self.vol_col] = 0.0
df[self.amt_vol] = 0.0
if self.amt_vol not in df.columns and self.vol_col in df.columns:
df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.")
x_timestamp = x_timestamp_list[i]
y_timestamp = y_timestamp_list[i]
x_time_df = calc_time_stamps(x_timestamp)
y_time_df = calc_time_stamps(y_timestamp)
x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
x_stamp = x_time_df.values.astype(np.float32)
y_stamp = y_time_df.values.astype(np.float32)
if x.shape[0] != x_stamp.shape[0]:
raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.")
if y_stamp.shape[0] != pred_len:
raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.")
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
x_norm = (x - x_mean) / (x_std + 1e-5)
x_norm = np.clip(x_norm, -self.clip, self.clip)
x_list.append(x_norm)
x_stamp_list.append(x_stamp)
y_stamp_list.append(y_stamp)
means.append(x_mean)
stds.append(x_std)
seq_lens.append(x_norm.shape[0])
y_lens.append(y_stamp.shape[0])
# Require all series to have consistent historical and prediction lengths for batch processing
if len(set(seq_lens)) != 1:
raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}")
if len(set(y_lens)) != 1:
raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}")
x_batch = np.stack(x_list, axis=0).astype(np.float32) # (B, seq_len, feat)
x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat)
y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat)
preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
# preds: (B, pred_len, feat)
pred_dfs = []
for i in range(num_series):
preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i])
pred_dfs.append(pred_df)
return pred_dfs
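# Minimal batch usage sketch (names are illustrative; all series need equal history lengths):
#   preds = predictor.predict_batch([df_a, df_b], [ts_a, ts_b], [fut_a, fut_b], pred_len=5)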

View File

@@ -0,0 +1,562 @@
import math
from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
class DifferentiableEntropyFunction(Function):
@staticmethod
def forward(ctx, zq, basis, K, eps):
zb = (zq + 1) / 2
zi = ((zb * basis).sum(-1)).to(torch.int64)
cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
0,
zi.flatten(),
torch.ones_like(zi.flatten()).to(zq.dtype),
'sum')
prob = (cnt + eps) / (cnt + eps).sum()
H = -(prob * torch.log(prob)).sum()
ctx.save_for_backward(zq, zi, prob)
ctx.K = K
return H
@staticmethod
def backward(ctx, grad_output):
zq, zi, prob = ctx.saved_tensors
grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
grad_input = reord_grad.unsqueeze(-1) * zq
return grad_input, None, None, None, None
def codebook_entropy(zq, basis, K, eps=1e-4):
return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
class BinarySphericalQuantizer(nn.Module):
def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
input_format='bchw',
soft_entropy=True, group_size=9,
persample_entropy_compute='analytical',
cb_entropy_compute='group',
l2_norm=True,
inv_temperature=1):
"""
Paper link: https://arxiv.org/pdf/2406.07548.pdf
Here we use the official implementation of the BinarySphericalQuantizer.
"""
super().__init__()
self.embed_dim = embed_dim
self.beta = beta # loss weight for commit loss
self.gamma0 = gamma0 # loss weight for entropy penalty
self.gamma = gamma # loss weight for entropy penalty
self.zeta = zeta # loss weight for entire entropy penalty
self.input_format = input_format
assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
self.num_groups = self.embed_dim // group_size
self.group_size = group_size
assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
self.persample_entropy_compute = persample_entropy_compute
self.cb_entropy_compute = cb_entropy_compute
self.l2_norm = l2_norm
self.inv_temperature = inv_temperature
self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
self.num_dimensions = 2 ** embed_dim
self.bits_per_index = embed_dim
# we only need to keep the codebook portion up to the group size
# because we approximate the H loss with this subcode
group_codes = torch.arange(2 ** self.group_size)
group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
self.register_buffer('group_codebook', group_codebook, persistent=False)
self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
def quantize(self, z):
assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
zhat = torch.where(z > 0,
torch.tensor(1, dtype=z.dtype, device=z.device),
torch.tensor(-1, dtype=z.dtype, device=z.device))
return z + (zhat - z).detach()
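# Straight-through estimator: the forward pass emits the sign values in {-1, +1},
# while gradients flow back to z as if the quantization were the identity.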
def forward(self, z, collect_metrics=True):
# if self.input_format == 'bchw':
# z = rearrange(z, 'b c h w -> b h w c')
zq = self.quantize(z)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
zq = zq * q_scale
if not collect_metrics:
return zq, zq.new_zeros(()), {}
indices = self.codes_to_indexes(zq.detach())
group_indices = self.codes_to_group_indexes(zq.detach())
if not self.training:
used_codes = torch.unique(indices, return_counts=False)
else:
used_codes = None
if self.soft_entropy:
persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
else:
zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
# commit loss
commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
# if self.input_format == 'bchw':
# zq = rearrange(zq, 'b h w c -> b c h w')
return (
zq,
commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
{"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
"avg_prob": avg_prob}
)
def soft_entropy_loss(self, z):
# if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
# the sub-code is the last group_size bits of the full code
group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
# we calculate the distance between the divided_z and the codebook for each subgroup
distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
prob = (-distance * self.inv_temperature).softmax(dim=-1)
if self.persample_entropy_compute == 'analytical':
if self.l2_norm:
p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
else:
p = torch.sigmoid(-4 * z * self.inv_temperature)
prob = torch.stack([p, 1 - p], dim=-1)
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
else:
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
# macro average of the probability of each subgroup
avg_prob = reduce(prob, '... g d ->g d', 'mean')
codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
# the approximation of the entropy is the sum of the entropy of each subgroup
return per_sample_entropy, codebook_entropy.sum(), avg_prob
def get_hard_per_sample_entropy(self, zb_by_sample):
probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
persample_entropy = persample_entropy.sum(-1)
return persample_entropy.mean()
def codes_to_indexes(self, zhat):
"""Converts a `code` to an index in the codebook.
Args:
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
"""
assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
def codes_to_group_indexes(self, zhat):
"""Converts a `code` to a list of indexes (in groups) in the codebook.
Args:
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
"""
zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
def indexes_to_codes(self, indices):
"""Inverse of `indexes_to_codes`."""
indices = indices.unsqueeze(-1)
codes_non_centered = torch.remainder(
torch.floor_divide(indices, self.basis), 2
)
return codes_non_centered * 2 - 1
def group_indexes_to_codes(self, group_indices):
"""Inverse of `group_indexes_to_codes`."""
group_indices = group_indices.unsqueeze(-1)
codes_non_centered = torch.remainder(
torch.floor_divide(group_indices, self.group_basis), 2
)
codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
return codes_non_centered * 2 - 1
def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
if normalize:
probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
else:
probs = count
H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
return H
def get_group_codebook_entry(self, group_indices):
z_q = self.group_indexes_to_codes(group_indices)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
z_q = z_q * q_scale
if self.input_format == 'bchw':
h = w = int(z_q.shape[1] ** 0.5)
assert h * w == z_q.shape[1], 'Invalid sequence length'
z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
return z_q
def get_codebook_entry(self, indices):
z_q = self.indexes_to_codes(indices)
q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
z_q = z_q * q_scale
if self.input_format == 'bchw':
h = w = int(z_q.shape[1] ** 0.5)
assert h * w == z_q.shape[1], 'Invalid sequence length'
z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
return z_q
class BSQuantizer(nn.Module):
def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
super().__init__()
self.codebook_dim = s1_bits + s2_bits
self.s1_bits = s1_bits
self.s2_bits = s2_bits
self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
def bits_to_indices(self, bits):
bits = (bits >= 0).to(torch.long)
indices = 2 ** torch.arange(
0,
bits.shape[-1],
1,
dtype=torch.long,
device=bits.device,
)
return (bits * indices).sum(-1)
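# Worked example: bits (+1, -1, +1) -> binary (1, 0, 1) with LSB-first weights (1, 2, 4) -> index 5.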
def forward(self, z, half=False, collect_metrics=True):
z = F.normalize(z, dim=-1)
quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics)
if half:
q_pre = quantized[:, :, :self.s1_bits]
q_post = quantized[:, :, self.s1_bits:]
z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
else:
z_indices = self.bits_to_indices(quantized)
return bsq_loss, quantized, z_indices
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class FeedForward(nn.Module):
def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
super().__init__()
self.w1 = nn.Linear(d_model, ff_dim, bias=False)
self.w3 = nn.Linear(d_model, ff_dim, bias=False)
self.w2 = nn.Linear(ff_dim, d_model, bias=False)
self.ffn_dropout = nn.Dropout(ffn_dropout_p)
def forward(self, x):
return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def _update_cos_sin_cache(self, x, seq_len):
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
return self.cos_cached, self.sin_cached
def forward(self, q, k):
cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
return (
(q * cos) + (self._rotate_half(q) * sin),
(k * cos) + (self._rotate_half(k) * sin),
)
def _rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
class MultiHeadAttentionWithRoPE(nn.Module):
def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.rotary = RotaryPositionalEmbedding(self.head_dim)
self.attn_dropout_p = attn_dropout_p
self.resid_dropout = nn.Dropout(resid_dropout_p)
def forward(self, x, key_padding_mask=None):
batch_size, seq_len, _ = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
q, k = self.rotary(q, k)
if key_padding_mask is not None:
attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len]
else:
attn_mask = None
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout_p if self.training else 0.0,
is_causal=True
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.resid_dropout(self.out_proj(attn_output))
class MultiHeadCrossAttentionWithRoPE(nn.Module):
def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.rotary = RotaryPositionalEmbedding(self.head_dim)
self.attn_dropout_p = attn_dropout_p
self.resid_dropout = nn.Dropout(resid_dropout)
def forward(self, query, key, value, key_padding_mask=None):
batch_size, q_len, _ = query.shape
_, seq_len, _ = key.shape
q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
q, k = self.rotary(q, k)
if key_padding_mask is not None:
attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
else:
attn_mask = None
is_causal_flag = self.training
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout_p if self.training else 0.0,
is_causal=is_causal_flag
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
return self.resid_dropout(self.out_proj(attn_output))
class HierarchicalEmbedding(nn.Module):
def __init__(self, s1_bits, s2_bits, d_model=256):
super().__init__()
self.s1_bits = s1_bits
self.s2_bits = s2_bits
vocab_s1 = 2 ** s1_bits
vocab_s2 = 2 ** s2_bits
self.emb_s1 = nn.Embedding(vocab_s1, d_model)
self.emb_s2 = nn.Embedding(vocab_s2, d_model)
self.d_model = d_model
self.fusion_proj = nn.Linear(d_model * 2, d_model)
nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
def split_token(self, token_ids: torch.Tensor, s2_bits: int):
"""Inputs:
token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
s2_bits (int): Number of low bits used for the fine token (s2).
"""
assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"
t = token_ids.long()
mask = (1 << s2_bits) - 1
s2_ids = t & mask # extract low bits
s1_ids = t >> s2_bits # extract high bits
return s1_ids, s2_ids
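# Worked example: s2_bits=2, composite token 13 (0b1101) -> s1_ids = 13 >> 2 = 3, s2_ids = 13 & 0b11 = 1.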
def forward(self, token_ids):
"""Inputs:
token_ids:
- tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
- torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
Output: [batch_size, seq_len, d_model]
"""
if isinstance(token_ids, tuple) or isinstance(token_ids, list):
s1_ids, s2_ids = token_ids
else:
s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
class DependencyAwareLayer(nn.Module):
def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
super().__init__()
self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
self.norm = RMSNorm(d_model)
def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
"""hidden_states: [batch, seq_len, d_model]
sibling_embed: Embedding from another subtoken
"""
attn_out = self.cross_attn(
query=sibling_embed,
key=hidden_states,
value=hidden_states,
key_padding_mask=key_padding_mask
)
return self.norm(hidden_states + attn_out)
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
super().__init__()
self.norm1 = RMSNorm(d_model)
self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
self.norm2 = RMSNorm(d_model)
self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
def forward(self, x, key_padding_mask=None):
residual = x
x = self.norm1(x)
attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
x = residual + attn_out
residual = x
x = self.norm2(x)
ffn_out = self.ffn(x)
x = residual + ffn_out
return x
class DualHead(nn.Module):
def __init__(self, s1_bits, s2_bits, d_model):
super().__init__()
self.vocab_s1 = 2 ** s1_bits
self.vocab_s2 = 2 ** s2_bits
self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
if padding_mask is not None:
valid_mask = (padding_mask == 0)
s1_logits = s1_logits[valid_mask]
s2_logits = s2_logits[valid_mask]
s1_targets = s1_targets[valid_mask]
s2_targets = s2_targets[valid_mask]
ce_s1 = F.cross_entropy(s1_logits, s1_targets)
ce_s2 = F.cross_entropy(s2_logits, s2_targets)
else:
ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
ce_loss = (ce_s1 + ce_s2) / 2
return ce_loss, ce_s1, ce_s2
def forward(self, x):
return self.proj_s1(x)
def cond_forward(self, x2):
return self.proj_s2(x2)
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
w = torch.zeros(c_in, d_model).float()
w.requires_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
def forward(self, x):
return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
def __init__(self, d_model, learn_pe):
super(TemporalEmbedding, self).__init__()
minute_size = 60
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13
Embed = FixedEmbedding if not learn_pe else nn.Embedding
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)
def forward(self, x):
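        # Expected feature layout on the last axis:
        # x[..., 0]=minute, x[..., 1]=hour, x[..., 2]=weekday, x[..., 3]=day, x[..., 4]=month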
x = x.long()
minute_x = self.minute_embed(x[:, :, 0])
hour_x = self.hour_embed(x[:, :, 1])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 3])
month_x = self.month_embed(x[:, :, 4])
return hour_x + weekday_x + day_x + month_x + minute_x

View File

@@ -0,0 +1,539 @@
import os
import sys
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import random
from loguru import logger
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
sys.path.insert(0, SRC_DIR)
from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor
from ..database_manager import DatabaseManager
from ..stock_tools import StockTools
from ..search_tools import SearchTools
from ..llm.factory import get_model
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
from agno.agent import Agent
class AutoSynthesisTrainer:
def __init__(self, news_dim=384):
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.db = DatabaseManager()
self.tools = StockTools(self.db)
self.searcher = SearchTools(self.db)
# Try loading from local cache first to avoid network timeouts
model_name = os.getenv(
"EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
try:
logger.info(f"🔄 Attempting to load {model_name} from local cache...")
self.embedder = SentenceTransformer(
model_name, device=self.device, local_files_only=True
)
logger.success("✅ Model loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local cache not found or incomplete. Attempting to download..."
)
self.embedder = SentenceTransformer(model_name, device=self.device)
self.news_dim = news_dim
# Try loading from local cache first to avoid network timeouts
try:
logger.info(
"🔄 Attempting to load Kronos and Tokenizer from local cache..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
).to(self.device)
base_model = Kronos.from_pretrained(
"NeoQuasar/Kronos-base", local_files_only=True
)
logger.success("✅ Kronos and Tokenizer loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base"
).to(self.device)
base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
self.model = Kronos(
base_model.s1_bits,
base_model.s2_bits,
base_model.n_layers,
base_model.d_model,
base_model.n_heads,
base_model.ff_dim,
base_model.ffn_dropout_p,
base_model.attn_dropout_p,
base_model.resid_dropout_p,
base_model.token_dropout_p,
base_model.learn_te,
news_dim=self.news_dim,
).to(self.device)
self.model.load_state_dict(base_model.state_dict(), strict=False)
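        # strict=False: the pretrained checkpoint carries no weights for the new
        # news_proj layer, so news_proj keeps its fresh initialization while all
        # shared parameters are copied over.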
# LLM for causality verification
provider = os.getenv("LLM_PROVIDER", "ust")
model_id = os.getenv("LLM_MODEL", "Qwen")
self.llm_agent = Agent(model=get_model(provider, model_id))
def discover_shocks(
self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5
):
"""1. Find days with significant price movements (Look back 1 year)"""
shocks = []
end_date = datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
for ticker in ticker_list:
df = self.tools.get_stock_price(
ticker, start_date=start_date, end_date=end_date
)
if df.empty or len(df) < 60:
continue
# Look for big moves
moves = df[df["change_pct"].abs() > threshold].copy()
if moves.empty:
continue
count = 0
for idx, row in moves.iterrows():
# Ensure we have history before this day AND enough future days for eval
date_idx = df.index.get_loc(idx)
if date_idx < 50 or date_idx + pred_len > len(df):
continue
shocks.append(
{
"ticker": ticker,
"date": row["date"],
"change": row["change_pct"],
"history": df.iloc[date_idx - 50 : date_idx],
"target": df.iloc[
date_idx : date_idx + pred_len
], # Now capturing pred_len days
}
)
count += 1
if count >= limit_per_stock:
break
logger.info(
f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days."
)
return shocks
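    # Each shock dict carries: ticker, date, change (the % move that tripped the
    # threshold), history (the 50 trading days before the move), and target (the
    # move day plus the following pred_len - 1 days).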
def find_reason_and_verify(self, shock):
"""2. Search for reasons and verify causality using LLM"""
ticker_info = self.db.get_stock_by_code(shock["ticker"])
name = ticker_info["name"] if ticker_info else shock["ticker"]
date_str = shock["date"]
# Try multiple query variations and engines
        # Query variations are in Chinese to match Baidu's index; roughly:
        # "why did it rise/fall + reason", "unusual move + reason", and "news".
        queries = [
f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因",
f"{name} {date_str} 异动 原因",
f"{shock['ticker']} {date_str} 新闻",
]
search_results = []
for query in queries:
logger.info(f"🔍 Searching for reason: {query}")
            # Engines to try in order (currently only Baidu)
            for engine in ["baidu"]:
try:
results = self.searcher.search_list(
query, engine=engine, max_results=3, enrich=False
)
if results:
search_results = results
break
except Exception as e:
logger.warning(f"Search failed for {query} on {engine}: {e}")
if search_results:
break
time.sleep(random.uniform(1.0, 2.0))
if not search_results:
logger.warning(
f"⚠️ No search results found for {name} on {date_str} after multiple attempts."
)
return None
context = "\n".join(
[f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results]
)
prompt = f"""
任务:判断以下新闻是否解释了该股票在 {date_str}{shock["change"]:.2f}% 价格变动。
股票:{name}
日期:{date_str}
变动:{shock["change"]:.2f}%
搜索结果:
{context}
要求:
1. 该新闻是否在该日期左右发生?
2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)?
3. 如果是,请总结一段 100 字以内的“核心推动原因”。
4. 返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}}
"""
try:
res = self.llm_agent.run(prompt)
data = json.loads(
res.content.replace("```json", "").replace("```", "").strip()
)
if data.get("is_causal"):
logger.success(
f"✅ Verified cause for {name} on {date_str}: {data['summary']}"
)
return data["summary"]
else:
                logger.warning(
                    f"❌ Causality rejected for {name} on {date_str}: {data.get('summary', '')}"
                )
return None
except Exception as e:
logger.warning(f"Verification failed: {e}")
return None
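    # Illustrative accepted response (shape only, not real data):
    #   {"is_causal": true, "summary": "Earnings beat plus a buyback announcement"}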
def save_model(self, path=None):
"""Save the news_proj weights"""
if path is None:
save_dir = os.path.join(SRC_DIR, "exports/models")
os.makedirs(save_dir, exist_ok=True)
path = os.path.join(
save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt"
)
# We only really need to save the news_proj part as it's the only one we train
torch.save(
{
"news_proj_state_dict": self.model.news_proj.state_dict(),
"news_dim": self.news_dim,
"d_model": self.model.d_model,
},
path,
)
logger.success(f"💾 Model weights saved to {path}")
return path
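    # Reload sketch: the checkpoint stores only the adapter, so a later session
    # can restore it with
    #   ckpt = torch.load(path, map_location=device)
    #   model.news_proj.load_state_dict(ckpt["news_proj_state_dict"])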
def run_synthesis_and_train(self, tickers, pred_len=5):
# 1. Discovery
shocks = self.discover_shocks(tickers, pred_len=pred_len)
print(f"find {len(shocks)} shocks")
# 2. News Association & Verification
dataset = []
max_news_items = 200 # Limit to 200 news items per session to avoid search bans
logger.info(
f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})"
)
for i, shock in enumerate(shocks):
if len(dataset) >= max_news_items:
logger.info("Reached maximum news items limit for this session.")
break
summary = self.find_reason_and_verify(shock)
if summary:
# 3. Embedding news
emb = self.embedder.encode(summary)
dataset.append(
{
"history": shock["history"],
"target": shock["target"],
"news_emb": emb,
"summary": summary,
}
)
# Add delay after search with randomness to avoid being blocked
if i < len(shocks) - 1:
delay = random.uniform(2.0, 4.0)
time.sleep(delay)
if not dataset:
logger.error(
"❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period."
)
return
# 4. Train/Val Split
random.seed(42)
random.shuffle(dataset)
if len(dataset) < 2:
train_set = dataset
val_set = []
logger.warning(
f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation."
)
else:
split_idx = max(1, int(len(dataset) * 0.8))
if split_idx >= len(dataset):
split_idx = len(dataset) - 1
train_set = dataset[:split_idx]
val_set = dataset[split_idx:]
logger.info(
f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation."
)
if not train_set:
logger.error("❌ No samples for training.")
return
# 5. Training (Few-shot)
optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
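        # Few-shot adapter training: only news_proj is stepped by the optimizer;
        # the pretrained backbone still receives gradients but is never updated.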
self.model.train()
loss_history = []
logger.info(f"🚀 Training for 30 epochs...")
for epoch in range(30):
total_loss = 0
for item in train_set:
optimizer.zero_grad()
# Prep Data
hist_df = item["history"]
# For training, we still focus on the immediate next point (teacher forcing)
target_df = item["target"].iloc[:1]
hist_raw = hist_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
                # Append close * volume as a 6th feature (a turnover-amount proxy)
                hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]])
mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5
hist_norm = (
torch.from_numpy((hist_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
target_raw = target_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
target_raw = np.column_stack(
[target_raw, target_raw[:, 3] * target_raw[:, 4]]
)
target_norm = (
torch.from_numpy((target_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
with torch.no_grad():
z_indices = self.tokenizer.encode(hist_norm, half=True)
t_indices = self.tokenizer.encode(target_norm, half=True)
s1_ids, s2_ids = z_indices[0], z_indices[1]
t_s1, t_s2 = t_indices[0], t_indices[1]
news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device)
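                # Teacher forcing: the ground-truth s1 tokens are supplied so the
                # model only has to score the next step; the loss below reads the
                # logits at the final position against the first target token.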
s1_logits, s2_logits = self.model(
s1_ids,
s2_ids,
news_emb=news_t,
use_teacher_forcing=True,
s1_targets=t_s1,
)
loss = (
criterion(s1_logits[:, -1, :], t_s1[:, 0])
+ criterion(s2_logits[:, -1, :], t_s2[:, 0])
) / 2
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_epoch_loss = total_loss / max(1, len(train_set))
loss_history.append(avg_epoch_loss)
if (epoch + 1) % 10 == 0:
logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}")
# 5.1 Visualize Loss Curve
loss_chart = VisualizerTools.generate_loss_chart(loss_history)
VisualizerTools.render_chart_to_file(
loss_chart,
os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"),
)
# 5.2 Save final model
self.save_model()
# 6. Final Evaluation on Validation Set
if not val_set:
logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.")
return
logger.info(
f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)"
)
self.model.eval()
predictor = KronosPredictor(self.model, self.tokenizer, device=self.device)
base_maes = []
news_maes = []
print("\n" + "=" * 90)
print(
f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}"
)
print("-" * 90)
for item in val_set:
h = item["history"]
t = item["target"]
actuals = t["close"].values[:pred_len]
x_ts = pd.to_datetime(h["date"])
            # Future timestamps: use business-day frequency for the forecast horizon
future_dates = pd.date_range(
start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B"
)
y_ts = pd.Series(future_dates)
# A. Base Prediction
p_base = predictor.predict(
h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False
)
b_preds = p_base["close"].values[: len(actuals)]
# B. News-Aware Prediction
p_news = predictor.predict(
h,
x_ts,
y_ts,
pred_len=pred_len,
news_emb=item["news_emb"],
verbose=False,
)
n_preds = p_news["close"].values[: len(actuals)]
# Calculate MAE over the window
b_mae = np.mean(np.abs(b_preds - actuals))
n_mae = np.mean(np.abs(n_preds - actuals))
base_maes.append(b_mae)
news_maes.append(n_mae)
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
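            # Positive improvement means the news-conditioned forecast had a
            # lower MAE than the base forecast for this shock window.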
date_str = str(t["date"].values[0])[:10]
ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock"
print(
f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%"
)
# C. Generate Visualization for this case
try:
# Helper to convert DF to KLinePoints
def to_kp_list(preds_df):
points = []
for idx, row in preds_df.iterrows():
points.append(
KLinePoint(
date=str(idx)[:10],
open=row["open"],
high=row["high"],
low=row["low"],
close=row["close"],
volume=row["volume"] if "volume" in row else 0,
)
)
return points
forecast_obj = ForecastResult(
ticker=ticker,
base_forecast=to_kp_list(p_base),
adjusted_forecast=to_kp_list(p_news),
rationale=item["summary"],
)
                # Ground truth for the visualizer: date plus the full OHLCV columns
gt_df = t[["date", "open", "high", "low", "close", "volume"]]
chart = VisualizerTools.generate_stock_chart(
df=h,
ticker=ticker,
title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%",
forecast=forecast_obj,
ground_truth=gt_df,
)
safe_date = date_str.replace("-", "")
filename = f"eval_{ticker}_{safe_date}.html"
VisualizerTools.render_chart_to_file(
chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}")
)
except Exception as e:
logger.error(f"Failed to generate eval chart for {ticker}: {e}")
# Summary Statistics
avg_base_err = sum(base_maes) / max(1, len(base_maes))
avg_news_err = sum(news_maes) / max(1, len(news_maes))
overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100
print("-" * 90)
print(
f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%"
)
print("=" * 90 + "\n")
logger.success(
f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%"
)
logger.info(
f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}"
)
if __name__ == "__main__":
trainer = AutoSynthesisTrainer()
logger.info("📂 Fetching all stock codes from database...")
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
if not all_tickers:
logger.warning("⚠️ No tickers found in stock_list table. Trying to sync...")
trainer.tools._check_and_update_stock_list(force=True)
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
logger.info(f"🚀 Starting training on potential stocks (1-year scan)...")
# 为了演示,我们扫描前 100 个股票,寻找最近一年的冲击点
trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1)