Import 9 alphaear finance skills
- alphaear-deepear-lite: DeepEar Lite API integration - alphaear-logic-visualizer: Draw.io XML finance diagrams - alphaear-news: Real-time finance news (10+ sources) - alphaear-predictor: Kronos time-series forecasting - alphaear-reporter: Professional financial reports - alphaear-search: Web search + local RAG - alphaear-sentiment: FinBERT/LLM sentiment analysis - alphaear-signal-tracker: Signal evolution tracking - alphaear-stock: A-Share/HK/US stock data Updates: - All scripts updated to use universal .env path - Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example - Updated load_dotenv() to use ~/.config/opencode/.env
This commit is contained in:
219
skills/alphaear-predictor/scripts/kronos_predictor.py
Normal file
219
skills/alphaear-predictor/scripts/kronos_predictor.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from typing import List, Optional
from loguru import logger
from pandas.tseries.offsets import BusinessDay
import os
from dotenv import load_dotenv

# Load environment variables from the shared opencode config
# (e.g. EMBEDDING_MODEL used below).
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))

# Fix for Kronos internal imports
import sys

# Directory holding the vendored Kronos predictor package; appended to
# sys.path so Kronos' own internal absolute imports resolve.
KRONOS_DIR = os.path.join(os.path.dirname(__file__), "predictor")
if KRONOS_DIR not in sys.path:
    sys.path.append(KRONOS_DIR)

import glob
from sentence_transformers import SentenceTransformer

# NOTE(review): these relative imports require this file to be imported as
# part of a package; the sys.path hack above suggests script-style use too —
# confirm the intended invocation, as `from .` fails when run directly.
from .predictor.model import Kronos, KronosTokenizer, KronosPredictor
from .schema.models import KLinePoint
|
||||
|
||||
|
||||
class KronosPredictorUtility:
    """
    Kronos time-series forecasting utility class.

    Handles model loading, inference, and data-structure conversion.

    Implemented as a process-wide singleton: ``__new__`` always returns the
    same instance, and ``__init__`` short-circuits once a predictor has been
    loaded, so the (expensive) model load happens at most once.
    """

    # Singleton instance shared by all constructions of this class.
    _instance = None
    # Loaded KronosPredictor; class-level default is None, replaced by an
    # instance attribute once __init__ succeeds (or fails).
    _predictor = None

    def __new__(cls, *args, **kwargs):
        # Standard singleton: create the instance only on first construction.
        if not cls._instance:
            cls._instance = super(KronosPredictorUtility, cls).__new__(cls)
        return cls._instance

    def __init__(self, device: Optional[str] = None):
        """
        Load the embedder, Kronos base model, and (optionally) trained
        news-projection weights.

        Args:
            device: Torch device string ("cuda"/"mps"/"cpu"). When omitted,
                the best available backend is auto-selected.

        Never raises: any failure is logged and leaves ``self._predictor``
        as None, which callers of get_base_forecast() handle gracefully.
        """
        # Singleton guard: skip re-initialization once loaded.
        if self._predictor is not None:
            return

        try:
            if not device:
                # Prefer CUDA, then Apple MPS, falling back to CPU.
                device = (
                    "cuda"
                    if torch.cuda.is_available()
                    else "mps"
                    if torch.backends.mps.is_available()
                    else "cpu"
                )

            logger.info(f"🔮 Loading Kronos Model on {device}...")

            # 1. Load Embedder (SentenceTransformer)
            model_name = os.getenv(
                "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
            )  # Match training
            # Try the local HF cache first; fall back to downloading.
            try:
                self.embedder = SentenceTransformer(
                    model_name, device=device, local_files_only=True
                )
            except Exception:
                logger.warning(
                    f"⚠️ Local embedder {model_name} not found. Downloading..."
                )
                self.embedder = SentenceTransformer(model_name, device=device)

            # 2. Load Kronos Base
            # Same cache-first strategy for tokenizer and base model.
            try:
                tokenizer = KronosTokenizer.from_pretrained(
                    "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
                )
                model = Kronos.from_pretrained(
                    "NeoQuasar/Kronos-base", local_files_only=True
                )
            except Exception:
                logger.warning(
                    "⚠️ Local Kronos cache not found. Attempting to download..."
                )
                tokenizer = KronosTokenizer.from_pretrained(
                    "NeoQuasar/Kronos-Tokenizer-base"
                )
                model = Kronos.from_pretrained("NeoQuasar/Kronos-base")

            # 3. Load Trained News Projector Weights
            # Check predictor/exports/models directory
            models_dir = os.path.join(KRONOS_DIR, "exports/models")
            model_files = glob.glob(os.path.join(models_dir, "*.pt"))

            if model_files:
                # "Latest" by ctime; NOTE(review): getmtime may be the safer
                # notion of "most recently trained" — confirm intent.
                latest_model = max(model_files, key=os.path.getctime)
                logger.info(f"🔄 Loading trained news weights from {latest_model}...")
                try:
                    checkpoint = torch.load(latest_model, map_location=device)
                    # The checkpoint contains 'news_proj_state_dict'
                    if "news_proj_state_dict" in checkpoint:
                        # Create the projection layer if the base model does
                        # not already define it.
                        if not hasattr(model, "news_proj") or model.news_proj is None:
                            import torch.nn as nn

                            # 384 matches the default all-MiniLM-L6-v2
                            # embedding width used when the key is absent.
                            news_dim = checkpoint.get("news_dim", 384)
                            model.news_proj = nn.Linear(news_dim, model.d_model).to(
                                device
                            )

                        model.news_proj.load_state_dict(
                            checkpoint["news_proj_state_dict"]
                        )
                        logger.success("✅ News-Aware Projection Layer loaded!")
                        self.has_news_model = True
                    else:
                        logger.warning(
                            "⚠️ Checkpoint found but missing 'news_proj_state_dict'. Using base model."
                        )
                        self.has_news_model = False
                except Exception as e:
                    # Weight load failure is non-fatal: fall back to base model.
                    logger.error(
                        f"❌ Failed to load trained weights: {e}. Using base model."
                    )
                    self.has_news_model = False
            else:
                logger.info("ℹ️ No trained news models found. Using base model.")
                self.has_news_model = False

            tokenizer = tokenizer.to(device)
            model = model.to(device)

            self._predictor = KronosPredictor(
                model, tokenizer, device=device, max_context=512
            )
            logger.info("✅ Kronos Model loaded successfully.")
        except Exception as e:
            # Boundary handler: leave the singleton in a disabled state
            # rather than propagate; get_base_forecast() checks for this.
            logger.error(f"❌ Failed to load Kronos Model: {e}")
            self._predictor = None
            self.has_news_model = False

    def get_base_forecast(
        self,
        df: pd.DataFrame,
        lookback: int = 20,
        pred_len: int = 5,
        news_text: Optional[str] = None,
    ) -> List[KLinePoint]:
        """
        Generate a raw model forecast.

        Args:
            df: Historical OHLCV data; must contain columns
                "date", "open", "high", "low", "close", "volume".
            lookback: Number of trailing rows used as model context.
            pred_len: Number of business days to forecast.
            news_text: Optional news text; embedded and fed to the model
                only when trained news weights were loaded.

        Returns:
            List of KLinePoint for the forecast horizon, or [] when the
            predictor is unavailable, data is insufficient, or inference fails.
        """
        if self._predictor is None:
            logger.error("Predictor not initialized.")
            return []

        if len(df) < lookback:
            logger.warning(
                f"Insufficient historical data ({len(df)}) for lookback ({lookback})."
            )
            return []

        # Take the last `lookback` rows of history.
        x_df = df.iloc[-lookback:].copy()
        x_timestamp = pd.to_datetime(x_df["date"])  # Ensure datetime
        last_date = x_timestamp.iloc[-1]

        # Generate future timestamps: pred_len business days after last_date.
        future_dates = pd.date_range(
            start=last_date + BusinessDay(1), periods=pred_len, freq="B"
        )
        y_timestamp = pd.Series(future_dates)

        # Embedding News if available
        news_emb = None
        if (
            news_text
            and getattr(self, "has_news_model", False)
            and hasattr(self, "embedder")
        ):
            try:
                # Truncate to avoid too long text
                emb = self.embedder.encode(news_text[:1000])
                news_emb = emb  # KronosPredictor expects numpy array or tensor
            except Exception as e:
                # Best-effort: fall back to a no-news forecast.
                logger.error(f"Failed to encode news: {e}")

        try:
            # Columns required by the predictor.
            cols = ["open", "high", "low", "close", "volume"]
            pred_df = self._predictor.predict(
                df=x_df[cols],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=1.0,
                top_p=0.9,
                sample_count=1,
                verbose=False,
                news_emb=news_emb,
            )

            # Convert the forecast frame to KLinePoint records.
            results = []
            for date, row in pred_df.iterrows():
                results.append(
                    KLinePoint(
                        date=date.strftime("%Y-%m-%d"),
                        open=float(row["open"]),
                        high=float(row["high"]),
                        low=float(row["low"]),
                        close=float(row["close"]),
                        volume=float(row["volume"]),
                    )
                )
            return results
        except Exception as e:
            logger.error(f"Forecast generation failed: {e}")
            return []
|
||||
|
||||
|
||||
# Singleton instance for easy access
|
||||
# Usage: predictor = KronosPredictorUtility()
|
||||
Reference in New Issue
Block a user