"""Auto-synthesis trainer: mines historical price shocks, verifies news-based
causes with an LLM, and fine-tunes the Kronos news projection on the verified
news-price pairs."""

import os
import sys
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import random
from loguru import logger
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

load_dotenv(os.path.expanduser("~/.config/opencode/.env"))

# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor
from ..database_manager import DatabaseManager
from ..stock_tools import StockTools
from ..search_tools import SearchTools
from ..llm.factory import get_model
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
from agno.agent import Agent


class AutoSynthesisTrainer:
    def __init__(self, news_dim=384):
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.db = DatabaseManager()
        self.tools = StockTools(self.db)
        self.searcher = SearchTools(self.db)

        # Try loading the embedder from local cache first to avoid network timeouts
        model_name = os.getenv(
            "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
        )
        try:
            logger.info(f"🔄 Attempting to load {model_name} from local cache...")
            self.embedder = SentenceTransformer(
                model_name, device=self.device, local_files_only=True
            )
            logger.success("✅ Model loaded from local cache.")
        except Exception:
            logger.warning(
                "⚠️ Local cache not found or incomplete. Attempting to download..."
            )
            self.embedder = SentenceTransformer(model_name, device=self.device)

        self.news_dim = news_dim

        # Try loading Kronos and its tokenizer from local cache first as well
        try:
            logger.info(
                "🔄 Attempting to load Kronos and Tokenizer from local cache..."
            )
            self.tokenizer = KronosTokenizer.from_pretrained(
                "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
            ).to(self.device)
            base_model = Kronos.from_pretrained(
                "NeoQuasar/Kronos-base", local_files_only=True
            )
            logger.success("✅ Kronos and Tokenizer loaded from local cache.")
        except Exception:
            logger.warning(
                "⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download..."
            )
            self.tokenizer = KronosTokenizer.from_pretrained(
                "NeoQuasar/Kronos-Tokenizer-base"
            ).to(self.device)
            base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")

        # Rebuild Kronos with a news-aware projection, then copy over the pretrained weights
        self.model = Kronos(
            base_model.s1_bits,
            base_model.s2_bits,
            base_model.n_layers,
            base_model.d_model,
            base_model.n_heads,
            base_model.ff_dim,
            base_model.ffn_dropout_p,
            base_model.attn_dropout_p,
            base_model.resid_dropout_p,
            base_model.token_dropout_p,
            base_model.learn_te,
            news_dim=self.news_dim,
        ).to(self.device)
        self.model.load_state_dict(base_model.state_dict(), strict=False)

        # LLM for causality verification
        provider = os.getenv("LLM_PROVIDER", "minimax")
        model_id = os.getenv("LLM_MODEL", "Qwen")
        self.llm_agent = Agent(model=get_model(provider, model_id))

    def discover_shocks(
        self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5
    ):
        """1. Find days with significant price movements (look back 1 year)."""
        shocks = []
        end_date = datetime.now().strftime("%Y-%m-%d")
        start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

        for ticker in ticker_list:
            df = self.tools.get_stock_price(
                ticker, start_date=start_date, end_date=end_date
            )
            if df.empty or len(df) < 60:
                continue

            # Look for big moves
            moves = df[df["change_pct"].abs() > threshold].copy()
            if moves.empty:
                continue

            count = 0
            for idx, row in moves.iterrows():
                # Ensure we have history before this day AND enough future days for eval
                date_idx = df.index.get_loc(idx)
                if date_idx < 50 or date_idx + pred_len > len(df):
                    continue

                shocks.append(
                    {
                        "ticker": ticker,
                        "date": row["date"],
                        "change": row["change_pct"],
                        "history": df.iloc[date_idx - 50 : date_idx],
                        "target": df.iloc[
                            date_idx : date_idx + pred_len
                        ],  # Now capturing pred_len days
                    }
                )
                count += 1
                if count >= limit_per_stock:
                    break

        logger.info(
            f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days."
        )
        return shocks

    def find_reason_and_verify(self, shock):
        """2. Search for reasons and verify causality using the LLM."""
        ticker_info = self.db.get_stock_by_code(shock["ticker"])
        name = ticker_info["name"] if ticker_info else shock["ticker"]
        date_str = shock["date"]

        # Try multiple query variations and engines
        queries = [
            f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因",
            f"{name} {date_str} 异动 原因",
            f"{shock['ticker']} {date_str} 新闻",
        ]

        search_results = []
        for query in queries:
            logger.info(f"🔍 Searching for reason: {query}")
            # Try alternate engines
            for engine in ["baidu"]:
                try:
                    results = self.searcher.search_list(
                        query, engine=engine, max_results=3, enrich=False
                    )
                    if results:
                        search_results = results
                        break
                except Exception as e:
                    logger.warning(f"Search failed for {query} on {engine}: {e}")
            if search_results:
                break
            time.sleep(random.uniform(1.0, 2.0))

        if not search_results:
            logger.warning(
                f"⚠️ No search results found for {name} on {date_str} after multiple attempts."
            )
            return None

        context = "\n".join(
            [f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results]
        )
        prompt = f"""
任务:判断以下新闻是否解释了该股票在 {date_str} 的 {shock["change"]:.2f}% 价格变动。
股票:{name}
日期:{date_str}
变动:{shock["change"]:.2f}%

搜索结果:
{context}

要求:
1. 该新闻是否在该日期左右发生?
2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)?
3. 如果是,请总结一段 100 字以内的“核心推动原因”。
4. 返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}}
"""
        try:
            res = self.llm_agent.run(prompt)
            data = json.loads(
                res.content.replace("```json", "").replace("```", "").strip()
            )
            if data.get("is_causal"):
                logger.success(
                    f"✅ Verified cause for {name} on {date_str}: {data['summary']}"
                )
                return data["summary"]
            else:
                logger.warning(
                    f"❌ No causal link for {name} on {date_str}: {data.get('summary', '')}"
                )
                return None
        except Exception as e:
            logger.warning(f"Verification failed: {e}")
            return None

    def save_model(self, path=None):
        """Save the news_proj weights."""
        if path is None:
            save_dir = os.path.join(SRC_DIR, "exports/models")
            os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(
                save_dir,
                f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt",
            )

        # We only really need to save the news_proj part as it's the only one we train
        torch.save(
            {
                "news_proj_state_dict": self.model.news_proj.state_dict(),
                "news_dim": self.news_dim,
                "d_model": self.model.d_model,
            },
            path,
        )
        logger.success(f"💾 Model weights saved to {path}")
        return path

    def run_synthesis_and_train(self, tickers, pred_len=5):
        """Full pipeline: discover shocks, verify news causes, train news_proj, evaluate."""
        # 1. Discovery
        shocks = self.discover_shocks(tickers, pred_len=pred_len)
        print(f"Found {len(shocks)} shocks")

        # 2. News Association & Verification
        dataset = []
        max_news_items = 200  # Limit to 200 news items per session to avoid search bans
        logger.info(
            f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})"
        )

        for i, shock in enumerate(shocks):
            if len(dataset) >= max_news_items:
                logger.info("Reached maximum news items limit for this session.")
                break

            summary = self.find_reason_and_verify(shock)
            if summary:
                # 3. Embedding news
                emb = self.embedder.encode(summary)
                dataset.append(
                    {
                        "history": shock["history"],
                        "target": shock["target"],
                        "news_emb": emb,
                        "summary": summary,
                    }
                )

            # Add delay after search with randomness to avoid being blocked
            if i < len(shocks) - 1:
                delay = random.uniform(2.0, 4.0)
                time.sleep(delay)

        if not dataset:
            logger.error(
                "❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period."
            )
            return

        # 4. Train/Val Split
        random.seed(42)
        random.shuffle(dataset)

        if len(dataset) < 2:
            train_set = dataset
            val_set = []
            logger.warning(
                f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation."
            )
        else:
            split_idx = max(1, int(len(dataset) * 0.8))
            if split_idx >= len(dataset):
                split_idx = len(dataset) - 1
            train_set = dataset[:split_idx]
            val_set = dataset[split_idx:]

        logger.info(
            f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation."
        )

        if not train_set:
            logger.error("❌ No samples for training.")
            return

        # 5. Training (Few-shot)
        optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()
        self.model.train()

        loss_history = []
        logger.info("🚀 Training for 30 epochs...")
        for epoch in range(30):
            total_loss = 0
            for item in train_set:
                optimizer.zero_grad()

                # Prep Data
                hist_df = item["history"]
                # For training, we still focus on the immediate next point (teacher forcing)
                target_df = item["target"].iloc[:1]

                hist_raw = hist_df[
                    ["open", "high", "low", "close", "volume"]
                ].values.astype(np.float32)
                # Append a sixth column (close * volume) as a turnover/amount proxy
                hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]])
                mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5
                hist_norm = (
                    torch.from_numpy((hist_raw - mean) / std)
                    .unsqueeze(0)
                    .to(self.device)
                )

                target_raw = target_df[
                    ["open", "high", "low", "close", "volume"]
                ].values.astype(np.float32)
                target_raw = np.column_stack(
                    [target_raw, target_raw[:, 3] * target_raw[:, 4]]
                )
                target_norm = (
                    torch.from_numpy((target_raw - mean) / std)
                    .unsqueeze(0)
                    .to(self.device)
                )

                with torch.no_grad():
                    z_indices = self.tokenizer.encode(hist_norm, half=True)
                    t_indices = self.tokenizer.encode(target_norm, half=True)

                s1_ids, s2_ids = z_indices[0], z_indices[1]
                t_s1, t_s2 = t_indices[0], t_indices[1]

                news_t = torch.from_numpy(item["news_emb"]).unsqueeze(0).to(self.device)

                s1_logits, s2_logits = self.model(
                    s1_ids,
                    s2_ids,
                    news_emb=news_t,
                    use_teacher_forcing=True,
                    s1_targets=t_s1,
                )

                loss = (
                    criterion(s1_logits[:, -1, :], t_s1[:, 0])
                    + criterion(s2_logits[:, -1, :], t_s2[:, 0])
                ) / 2

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_epoch_loss = total_loss / max(1, len(train_set))
            loss_history.append(avg_epoch_loss)
            if (epoch + 1) % 10 == 0:
                logger.info(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}")

        # 5.1 Visualize Loss Curve
        loss_chart = VisualizerTools.generate_loss_chart(loss_history)
        VisualizerTools.render_chart_to_file(
            loss_chart,
            os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"),
        )

        # 5.2 Save final model
        self.save_model()

        # 6. Final Evaluation on Validation Set
        if not val_set:
            logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.")
            return

        logger.info(
            f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)"
        )
        self.model.eval()
        predictor = KronosPredictor(self.model, self.tokenizer, device=self.device)

        base_maes = []
        news_maes = []

        print("\n" + "=" * 90)
        print(
            f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}"
        )
        print("-" * 90)

        for item in val_set:
            h = item["history"]
            t = item["target"]
            actuals = t["close"].values[:pred_len]

            x_ts = pd.to_datetime(h["date"])
            # Future timestamps: handle business days if possible, or just simple offset
            future_dates = pd.date_range(
                start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq="B"
            )
            y_ts = pd.Series(future_dates)

            # A. Base Prediction
            p_base = predictor.predict(
                h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False
            )
            b_preds = p_base["close"].values[: len(actuals)]

            # B. News-Aware Prediction
            p_news = predictor.predict(
                h,
                x_ts,
                y_ts,
                pred_len=pred_len,
                news_emb=item["news_emb"],
                verbose=False,
            )
            n_preds = p_news["close"].values[: len(actuals)]

            # Calculate MAE over the window
            b_mae = np.mean(np.abs(b_preds - actuals))
            n_mae = np.mean(np.abs(n_preds - actuals))
            base_maes.append(b_mae)
            news_maes.append(n_mae)

            improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
            date_str = str(t["date"].values[0])[:10]
            ticker = h.iloc[-1]["ticker"] if "ticker" in h.columns else "Stock"

            print(
                f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%"
            )

            # C. Generate Visualization for this case
            try:
                # Helper to convert DF to KLinePoints
                def to_kp_list(preds_df):
                    points = []
                    for idx, row in preds_df.iterrows():
                        points.append(
                            KLinePoint(
                                date=str(idx)[:10],
                                open=row["open"],
                                high=row["high"],
                                low=row["low"],
                                close=row["close"],
                                volume=row["volume"] if "volume" in row else 0,
                            )
                        )
                    return points

                forecast_obj = ForecastResult(
                    ticker=ticker,
                    base_forecast=to_kp_list(p_base),
                    adjusted_forecast=to_kp_list(p_news),
                    rationale=item["summary"],
                )

                # Ground truth for visualizer expects a DataFrame with 'date' and 'close'
                gt_df = t[["date", "open", "high", "low", "close", "volume"]]

                chart = VisualizerTools.generate_stock_chart(
                    df=h,
                    ticker=ticker,
                    title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%",
                    forecast=forecast_obj,
                    ground_truth=gt_df,
                )
                safe_date = date_str.replace("-", "")
                filename = f"eval_{ticker}_{safe_date}.html"
                VisualizerTools.render_chart_to_file(
                    chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}")
                )
            except Exception as e:
                logger.error(f"Failed to generate eval chart for {ticker}: {e}")

        # Summary Statistics
        avg_base_err = sum(base_maes) / max(1, len(base_maes))
        avg_news_err = sum(news_maes) / max(1, len(news_maes))
        overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100

        print("-" * 90)
        print(
            f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%"
        )
        print("=" * 90 + "\n")

        logger.success(
            f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%"
        )
        logger.info(
            f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}"
        )


if __name__ == "__main__":
    trainer = AutoSynthesisTrainer()

    logger.info("📂 Fetching all stock codes from database...")
    res = trainer.db.execute_query("SELECT code FROM stock_list")
    all_tickers = [row["code"] for row in res]

    if not all_tickers:
        logger.warning("⚠️ No tickers found in stock_list table. Trying to sync...")
        trainer.tools._check_and_update_stock_list(force=True)
        res = trainer.db.execute_query("SELECT code FROM stock_list")
        all_tickers = [row["code"] for row in res]

    logger.info("🚀 Starting training on candidate stocks (1-year scan)...")
    # For this demo, scan the first 100 stocks for shock points within the last year
    trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1)
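
# ---------------------------------------------------------------------------
# Illustrative sketch (not called anywhere in this pipeline): how the checkpoint
# written by AutoSynthesisTrainer.save_model() could be reloaded elsewhere.
# The helper name is hypothetical; it assumes the same "NeoQuasar/Kronos-base"
# weights and the Kronos constructor signature used in __init__ above.
def load_news_proj_checkpoint(path: str, device: str = "cpu") -> Kronos:
    ckpt = torch.load(path, map_location=device)
    base = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    model = Kronos(
        base.s1_bits,
        base.s2_bits,
        base.n_layers,
        base.d_model,
        base.n_heads,
        base.ff_dim,
        base.ffn_dropout_p,
        base.attn_dropout_p,
        base.resid_dropout_p,
        base.token_dropout_p,
        base.learn_te,
        news_dim=ckpt["news_dim"],
    ).to(device)
    # Restore the pretrained backbone, then overlay the trained news projection.
    model.load_state_dict(base.state_dict(), strict=False)
    model.news_proj.load_state_dict(ckpt["news_proj_state_dict"])
    model.eval()
    return model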