Import 9 alphaear finance skills
- alphaear-deepear-lite: DeepEar Lite API integration - alphaear-logic-visualizer: Draw.io XML finance diagrams - alphaear-news: Real-time finance news (10+ sources) - alphaear-predictor: Kronos time-series forecasting - alphaear-reporter: Professional financial reports - alphaear-search: Web search + local RAG - alphaear-sentiment: FinBERT/LLM sentiment analysis - alphaear-signal-tracker: Signal evolution tracking - alphaear-stock: A-Share/HK/US stock data Updates: - All scripts updated to use universal .env path - Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example - Updated load_dotenv() to use ~/.config/opencode/.env
This commit is contained in:
137
skills/alphaear-predictor/scripts/utils/predictor/evaluation.py
Normal file
137
skills/alphaear-predictor/scripts/utils/predictor/evaluation.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import glob
|
||||
from loguru import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Setup paths
|
||||
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
|
||||
if SRC_DIR not in sys.path:
|
||||
sys.path.insert(0, SRC_DIR)
|
||||
|
||||
from ..kronos.auto_synthesis_training import AutoSynthesisTrainer
|
||||
from ..kronos.model import KronosPredictor
|
||||
from ..visualizer import VisualizerTools
|
||||
from ..schema.models import ForecastResult, KLinePoint
|
||||
|
||||
class NewsModelEvaluator:
|
||||
def __init__(self, model_path=None):
|
||||
self.trainer = AutoSynthesisTrainer()
|
||||
self.device = self.trainer.device
|
||||
|
||||
if model_path is None:
|
||||
# Try to find the latest model in exports/models
|
||||
model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt"))
|
||||
if not model_files:
|
||||
logger.warning("⚠️ No trained models found in exports/models/. Using base model (zero-init proj).")
|
||||
else:
|
||||
model_path = max(model_files, key=os.path.getctime)
|
||||
|
||||
if model_path:
|
||||
self.load_weights(model_path)
|
||||
|
||||
def load_weights(self, path):
|
||||
logger.info(f"🔄 Loading model weights from {path}...")
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict'])
|
||||
logger.success("✅ News projection layer loaded.")
|
||||
|
||||
def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5):
|
||||
# 1. Fetch Tickers
|
||||
res = self.trainer.db.execute_query("SELECT code FROM stock_list")
|
||||
all_tickers = [row['code'] for row in res]
|
||||
test_tickers = all_tickers[start_idx:end_idx]
|
||||
|
||||
if not test_tickers:
|
||||
logger.error(f"No tickers found in range {start_idx}-{end_idx}")
|
||||
return
|
||||
|
||||
logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...")
|
||||
|
||||
# 2. Discover Shocks
|
||||
shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len)
|
||||
|
||||
# 3. Associate News & Predict
|
||||
self.trainer.model.eval()
|
||||
predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device)
|
||||
|
||||
save_dir = os.path.join(SRC_DIR, "exports/evaluation_results")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
count = 0
|
||||
for shock in shocks:
|
||||
summary = self.trainer.find_reason_and_verify(shock)
|
||||
if not summary:
|
||||
continue
|
||||
|
||||
logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}")
|
||||
|
||||
# Embedding news
|
||||
news_emb = self.trainer.embedder.encode(summary)
|
||||
|
||||
# Prediction
|
||||
h = shock['history']
|
||||
t = shock['target']
|
||||
actuals = t['close'].values[:pred_len]
|
||||
|
||||
x_ts = pd.to_datetime(h['date'])
|
||||
future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B')
|
||||
y_ts = pd.Series(future_dates)
|
||||
|
||||
# A. Base Prediction (No news)
|
||||
p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False)
|
||||
|
||||
# B. News-Aware Prediction
|
||||
p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False)
|
||||
|
||||
# Calculate Improvement
|
||||
b_preds = p_base['close'].values[:len(actuals)]
|
||||
n_preds = p_news['close'].values[:len(actuals)]
|
||||
b_mae = np.mean(np.abs(b_preds - actuals))
|
||||
n_mae = np.mean(np.abs(n_preds - actuals))
|
||||
improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
|
||||
|
||||
# C. Visualize
|
||||
try:
|
||||
def to_kp_list(preds_df):
|
||||
points = []
|
||||
for idx, row in preds_df.iterrows():
|
||||
points.append(KLinePoint(
|
||||
date=str(idx)[:10], open=row['open'], high=row['high'],
|
||||
low=row['low'], close=row['close'], volume=row.get('volume', 0)
|
||||
))
|
||||
return points
|
||||
|
||||
forecast_obj = ForecastResult(
|
||||
ticker=shock['ticker'],
|
||||
base_forecast=to_kp_list(p_base),
|
||||
adjusted_forecast=to_kp_list(p_news),
|
||||
rationale=summary
|
||||
)
|
||||
|
||||
chart = VisualizerTools.generate_stock_chart(
|
||||
df=h, ticker=shock['ticker'],
|
||||
title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%",
|
||||
forecast=forecast_obj,
|
||||
ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']]
|
||||
)
|
||||
|
||||
safe_date = shock['date'].replace("-", "")
|
||||
filename = f"test_{shock['ticker']}_{safe_date}.html"
|
||||
VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename))
|
||||
|
||||
logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}")
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Visualization failed: {e}")
|
||||
|
||||
logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# If you have a specific model, pass the path here. Otherwise it picks the latest.
|
||||
evaluator = NewsModelEvaluator()
|
||||
evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1)
|
||||
Reference in New Issue
Block a user