# File: opencode-skill/skills/alphaear-reporter/scripts/utils/predictor/training.py
# Commit 17930e9650 (Kunthawat Greethong, 2026-03-27 10:27:00 +07:00):
#   Update alphaear skills to use MiniMax as default LLM
#   - Added 'minimax' provider to the LLM factory
#   - Changed default LLM_PROVIDER from 'ust' to 'minimax'
#   - Changed default LLM_MODEL from 'Qwen' to 'MiniMax-Text-01'
#   - Updated REASONING_MODEL_PROVIDER and TOOL_MODEL_PROVIDER to minimax
#   - Sentiment tools now prefer MINIMAX_API_KEY over UST_KEY_API
#   - .env.example updated with MiniMax defaults
import os
import sys
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import random
from loguru import logger
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
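# Illustrative .env entries this script reads (names taken from this file and
# the commit note above; values are examples, not requirements):
#   LLM_PROVIDER=minimax
#   LLM_MODEL=MiniMax-Text-01
#   EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
#   MINIMAX_API_KEY=...  # per the commit note, preferred over UST_KEY_API by sentiment tools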
# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
sys.path.insert(0, SRC_DIR)
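# Note: the relative imports below require this file to be executed as a module
# (e.g. `python -m <package>.predictor.training` from SRC_DIR; the exact package
# name depends on the project layout), not invoked as a plain script.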
from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor
from ..database_manager import DatabaseManager
from ..stock_tools import StockTools
from ..search_tools import SearchTools
from ..llm.factory import get_model
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
from agno.agent import Agent
class AutoSynthesisTrainer:
def __init__(self, news_dim=384):
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.db = DatabaseManager()
self.tools = StockTools(self.db)
self.searcher = SearchTools(self.db)
# Try loading from local cache first to avoid network timeouts
model_name = os.getenv(
"EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
try:
logger.info(f"🔄 Attempting to load {model_name} from local cache...")
self.embedder = SentenceTransformer(
model_name, device=self.device, local_files_only=True
)
logger.success("✅ Model loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local cache not found or incomplete. Attempting to download..."
)
self.embedder = SentenceTransformer(model_name, device=self.device)
self.news_dim = news_dim
# Try loading from local cache first to avoid network timeouts
try:
logger.info(
"🔄 Attempting to load Kronos and Tokenizer from local cache..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
).to(self.device)
base_model = Kronos.from_pretrained(
"NeoQuasar/Kronos-base", local_files_only=True
)
logger.success("✅ Kronos and Tokenizer loaded from local cache.")
except Exception:
logger.warning(
"⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..."
)
self.tokenizer = KronosTokenizer.from_pretrained(
"NeoQuasar/Kronos-Tokenizer-base"
).to(self.device)
base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
self.model = Kronos(
base_model.s1_bits,
base_model.s2_bits,
base_model.n_layers,
base_model.d_model,
base_model.n_heads,
base_model.ff_dim,
base_model.ffn_dropout_p,
base_model.attn_dropout_p,
base_model.resid_dropout_p,
base_model.token_dropout_p,
base_model.learn_te,
news_dim=self.news_dim,
).to(self.device)
self.model.load_state_dict(base_model.state_dict(), strict=False)
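        # strict=False copies every matching pretrained weight; the newly added
        # news_proj layer has no counterpart in the base checkpoint, so it keeps
        # its random initialization, which is what the training below fits.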
# LLM for causality verification
provider = os.getenv("LLM_PROVIDER", "minimax")
        model_id = os.getenv("LLM_MODEL", "MiniMax-Text-01")
self.llm_agent = Agent(model=get_model(provider, model_id))
def discover_shocks(
self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5
):
"""1. Find days with significant price movements (Look back 1 year)"""
shocks = []
end_date = datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
for ticker in ticker_list:
df = self.tools.get_stock_price(
ticker, start_date=start_date, end_date=end_date
)
if df.empty or len(df) < 60:
continue
# Look for big moves
moves = df[df["change_pct"].abs() > threshold].copy()
if moves.empty:
continue
count = 0
for idx, row in moves.iterrows():
# Ensure we have history before this day AND enough future days for eval
date_idx = df.index.get_loc(idx)
if date_idx < 50 or date_idx + pred_len > len(df):
continue
shocks.append(
{
"ticker": ticker,
"date": row["date"],
"change": row["change_pct"],
"history": df.iloc[date_idx - 50 : date_idx],
"target": df.iloc[
date_idx : date_idx + pred_len
], # Now capturing pred_len days
}
)
count += 1
if count >= limit_per_stock:
break
logger.info(
f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days."
)
return shocks
def find_reason_and_verify(self, shock):
"""2. Search for reasons and verify causality using LLM"""
ticker_info = self.db.get_stock_by_code(shock["ticker"])
name = ticker_info["name"] if ticker_info else shock["ticker"]
date_str = shock["date"]
# Try multiple query variations and engines
queries = [
f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因",
f"{name} {date_str} 异动 原因",
f"{shock['ticker']} {date_str} 新闻",
]
search_results = []
for query in queries:
logger.info(f"🔍 Searching for reason: {query}")
            # Engines to try in order (only Baidu for now; add others here as fallbacks)
            for engine in ["baidu"]:
try:
results = self.searcher.search_list(
query, engine=engine, max_results=3, enrich=False
)
if results:
search_results = results
break
except Exception as e:
logger.warning(f"Search failed for {query} on {engine}: {e}")
if search_results:
break
time.sleep(random.uniform(1.0, 2.0))
if not search_results:
logger.warning(
f"⚠️ No search results found for {name} on {date_str} after multiple attempts."
)
return None
context = "\n".join(
[f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results]
)
prompt = f"""
任务:判断以下新闻是否解释了该股票在 {date_str} 出现的 {shock["change"]:.2f}% 价格变动。
股票:{name}
日期:{date_str}
变动:{shock["change"]:.2f}%
搜索结果:
{context}
要求:
1. 该新闻是否在该日期左右发生?
2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)?
3. 如果是,请总结一段 100 字以内的“核心推动原因”。
4. 返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}}
"""
try:
res = self.llm_agent.run(prompt)
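            # Strip any markdown code fences the LLM wraps around its JSON reply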
data = json.loads(
res.content.replace("```json", "").replace("```", "").strip()
)
if data.get("is_causal"):
logger.success(
f"✅ Verified cause for {name} on {date_str}: {data['summary']}"
)
return data["summary"]
else:
                logger.warning(
                    f"❌ No causal link for {name} on {date_str}: {data.get('summary', '')}"
                )
return None
except Exception as e:
logger.warning(f"Verification failed: {e}")
return None
def save_model(self, path=None):
"""Save the news_proj weights"""
if path is None:
save_dir = os.path.join(SRC_DIR, "exports/models")
os.makedirs(save_dir, exist_ok=True)
path = os.path.join(
save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt"
)
# We only really need to save the news_proj part as it's the only one we train
torch.save(
{
"news_proj_state_dict": self.model.news_proj.state_dict(),
"news_dim": self.news_dim,
"d_model": self.model.d_model,
},
path,
)
logger.success(f"💾 Model weights saved to {path}")
return path
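    # Minimal restore sketch (illustrative, assuming a Kronos model built with the
    # same news_dim/d_model as recorded in the checkpoint):
    #   ckpt = torch.load(path, map_location="cpu")
    #   model.news_proj.load_state_dict(ckpt["news_proj_state_dict"])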
def run_synthesis_and_train(self, tickers, pred_len=5):
# 1. Discovery
shocks = self.discover_shocks(tickers, pred_len=pred_len)
print(f"find {len(shocks)} shocks")
# 2. News Association & Verification
dataset = []
max_news_items = 200 # Limit to 200 news items per session to avoid search bans
logger.info(
f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})"
)
for i, shock in enumerate(shocks):
if len(dataset) >= max_news_items:
logger.info("Reached maximum news items limit for this session.")
break
summary = self.find_reason_and_verify(shock)
if summary:
# 3. Embedding news
emb = self.embedder.encode(summary)
dataset.append(
{
"history": shock["history"],
"target": shock["target"],
"news_emb": emb,
"summary": summary,
}
)
            # Randomized delay between searches to avoid being blocked
if i < len(shocks) - 1:
delay = random.uniform(2.0, 4.0)
time.sleep(delay)
if not dataset:
logger.error(
"❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period."
)
return
# 4. Train/Val Split
random.seed(42)
random.shuffle(dataset)
if len(dataset) < 2:
train_set = dataset
val_set = []
logger.warning(
f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation."
)
else:
split_idx = max(1, int(len(dataset) * 0.8))
if split_idx >= len(dataset):
split_idx = len(dataset) - 1
train_set = dataset[:split_idx]
val_set = dataset[split_idx:]
logger.info(
f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation."
)
if not train_set:
logger.error("❌ No samples for training.")
return
# 5. Training (Few-shot)
optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
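        # Only news_proj parameters are handed to the optimizer, so the pretrained
        # Kronos backbone stays untouched by these updates (few-shot adapter tuning).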
self.model.train()
loss_history = []
logger.info(f"🚀 Training for 30 epochs...")
for epoch in range(30):
total_loss = 0
for item in train_set:
optimizer.zero_grad()
# Prep Data
hist_df = item["history"]
# For training, we still focus on the immediate next point (teacher forcing)
target_df = item["target"].iloc[:1]
hist_raw = hist_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
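                # Append close * volume as a sixth feature (a rough turnover/amount proxy)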
hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]])
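                # Z-score normalize with history statistics (epsilon keeps std nonzero);
                # the target below is scaled with the same mean/std so both share one scale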
mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5
hist_norm = (
torch.from_numpy((hist_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
target_raw = target_df[
["open", "high", "low", "close", "volume"]
].values.astype(np.float32)
target_raw = np.column_stack(
[target_raw, target_raw[:, 3] * target_raw[:, 4]]
)
target_norm = (
torch.from_numpy((target_raw - mean) / std)
.unsqueeze(0)
.to(self.device)
)
with torch.no_grad():
z_indices = self.tokenizer.encode(hist_norm, half=True)
t_indices = self.tokenizer.encode(target_norm, half=True)
s1_ids, s2_ids = z_indices[0], z_indices[1]
t_s1, t_s2 = t_indices[0], t_indices[1]
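                # The tokenizer emits two token streams (matching the model's
                # s1_bits/s2_bits heads); both are predicted and scored below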
news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device)
s1_logits, s2_logits = self.model(
s1_ids,
s2_ids,
news_emb=news_t,
use_teacher_forcing=True,
s1_targets=t_s1,
)
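                # One-step-ahead objective: cross-entropy of the last predicted position
                # against the first ground-truth token, averaged over the two streams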
loss = (
criterion(s1_logits[:, -1, :], t_s1[:, 0])
+ criterion(s2_logits[:, -1, :], t_s2[:, 0])
) / 2
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_epoch_loss = total_loss / max(1, len(train_set))
loss_history.append(avg_epoch_loss)
if (epoch + 1) % 10 == 0:
logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}")
# 5.1 Visualize Loss Curve
loss_chart = VisualizerTools.generate_loss_chart(loss_history)
VisualizerTools.render_chart_to_file(
loss_chart,
os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"),
)
# 5.2 Save final model
self.save_model()
# 6. Final Evaluation on Validation Set
if not val_set:
logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.")
return
logger.info(
f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)"
)
self.model.eval()
predictor = KronosPredictor(self.model, self.tokenizer, device=self.device)
base_maes = []
news_maes = []
print("\n" + "=" * 90)
print(
f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}"
)
print("-" * 90)
for item in val_set:
h = item["history"]
t = item["target"]
actuals = t["close"].values[:pred_len]
x_ts = pd.to_datetime(h["date"])
            # Future timestamps: business-day frequency ("B") approximates trading
            # days, though exchange holidays are not excluded
future_dates = pd.date_range(
start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B"
)
y_ts = pd.Series(future_dates)
# A. Base Prediction
p_base = predictor.predict(
h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False
)
b_preds = p_base["close"].values[: len(actuals)]
# B. News-Aware Prediction
p_news = predictor.predict(
h,
x_ts,
y_ts,
pred_len=pred_len,
news_emb=item["news_emb"],
verbose=False,
)
n_preds = p_news["close"].values[: len(actuals)]
# Calculate MAE over the window
b_mae = np.mean(np.abs(b_preds - actuals))
n_mae = np.mean(np.abs(n_preds - actuals))
base_maes.append(b_mae)
news_maes.append(n_mae)
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
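            # Relative MAE reduction in percent; the epsilon guards against division by zero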
date_str = str(t["date"].values[0])[:10]
ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock"
print(
f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%"
)
# C. Generate Visualization for this case
try:
# Helper to convert DF to KLinePoints
def to_kp_list(preds_df):
points = []
for idx, row in preds_df.iterrows():
points.append(
KLinePoint(
date=str(idx)[:10],
open=row["open"],
high=row["high"],
low=row["low"],
close=row["close"],
volume=row["volume"] if "volume" in row else 0,
)
)
return points
forecast_obj = ForecastResult(
ticker=ticker,
base_forecast=to_kp_list(p_base),
adjusted_forecast=to_kp_list(p_news),
rationale=item["summary"],
)
                # The visualizer's ground-truth input expects a DataFrame with 'date' and OHLCV columns
gt_df = t[["date", "open", "high", "low", "close", "volume"]]
chart = VisualizerTools.generate_stock_chart(
df=h,
ticker=ticker,
title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%",
forecast=forecast_obj,
ground_truth=gt_df,
)
safe_date = date_str.replace("-", "")
filename = f"eval_{ticker}_{safe_date}.html"
VisualizerTools.render_chart_to_file(
chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}")
)
except Exception as e:
logger.error(f"Failed to generate eval chart for {ticker}: {e}")
# Summary Statistics
avg_base_err = sum(base_maes) / max(1, len(base_maes))
avg_news_err = sum(news_maes) / max(1, len(news_maes))
overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100
print("-" * 90)
print(
f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%"
)
print("=" * 90 + "\n")
logger.success(
f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%"
)
logger.info(
f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}"
)
if __name__ == "__main__":
trainer = AutoSynthesisTrainer()
logger.info("📂 Fetching all stock codes from database...")
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
if not all_tickers:
logger.warning("⚠️ No tickers found in stock_list table. Trying to sync...")
trainer.tools._check_and_update_stock_list(force=True)
res = trainer.db.execute_query("SELECT code FROM stock_list")
all_tickers = [row["code"] for row in res]
logger.info(f"🚀 Starting training on potential stocks (1-year scan)...")
    # For the demo, scan the first 100 stocks for shock points over the past year
trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1)