- alphaear-deepear-lite: DeepEar Lite API integration
- alphaear-logic-visualizer: Draw.io XML finance diagrams
- alphaear-news: Real-time finance news (10+ sources)
- alphaear-predictor: Kronos time-series forecasting
- alphaear-reporter: Professional financial reports
- alphaear-search: Web search + local RAG
- alphaear-sentiment: FinBERT/LLM sentiment analysis
- alphaear-signal-tracker: Signal evolution tracking
- alphaear-stock: A-Share/HK/US stock data

Updates:
- All scripts updated to use the universal .env path
- Added JINA_API_KEY, LLM_*, DEEPSEEK_API_KEY to .env.example
- Updated load_dotenv() to use ~/.config/opencode/.env
220 lines
7.6 KiB
Python
220 lines
7.6 KiB
Python
import torch
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime
|
||
from typing import List, Optional
|
||
from loguru import logger
|
||
from pandas.tseries.offsets import BusinessDay
|
||
import os
|
||
from dotenv import load_dotenv
|
||
|
||
load_dotenv(os.path.expanduser("~/.config/opencode/.env"))
|
||
|
||
# Fix for Kronos internal imports
|
||
import sys
|
||
|
||
KRONOS_DIR = os.path.join(os.path.dirname(__file__), "predictor")
|
||
if KRONOS_DIR not in sys.path:
|
||
sys.path.append(KRONOS_DIR)
|
||
|
||
import glob
|
||
from sentence_transformers import SentenceTransformer
|
||
|
||
from .predictor.model import Kronos, KronosTokenizer, KronosPredictor
|
||
from .schema.models import KLinePoint
|
||
|
||
|
||
class KronosPredictorUtility:
|
||
"""
|
||
Kronos 时序预测工具类
|
||
负责模型加载、推理以及数据结构转换
|
||
"""
|
||
|
||
_instance = None
|
||
_predictor = None
|
||
|
||
def __new__(cls, *args, **kwargs):
|
||
if not cls._instance:
|
||
cls._instance = super(KronosPredictorUtility, cls).__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self, device: Optional[str] = None):
|
||
if self._predictor is not None:
|
||
return
|
||
|
||
try:
|
||
if not device:
|
||
device = (
|
||
"cuda"
|
||
if torch.cuda.is_available()
|
||
else "mps"
|
||
if torch.backends.mps.is_available()
|
||
else "cpu"
|
||
)
|
||
|
||
logger.info(f"🔮 Loading Kronos Model on {device}...")
|
||
|
||
# 1. Load Embedder (SentenceTransformer)
|
||
model_name = os.getenv(
|
||
"EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
|
||
) # Match training
|
||
try:
|
||
self.embedder = SentenceTransformer(
|
||
model_name, device=device, local_files_only=True
|
||
)
|
||
except Exception:
|
||
logger.warning(
|
||
f"⚠️ Local embedder {model_name} not found. Downloading..."
|
||
)
|
||
self.embedder = SentenceTransformer(model_name, device=device)
|
||
|
||
# 2. Load Kronos Base
|
||
try:
|
||
tokenizer = KronosTokenizer.from_pretrained(
|
||
"NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
|
||
)
|
||
model = Kronos.from_pretrained(
|
||
"NeoQuasar/Kronos-base", local_files_only=True
|
||
)
|
||
except Exception:
|
||
logger.warning(
|
||
"⚠️ Local Kronos cache not found. Attempting to download..."
|
||
)
|
||
tokenizer = KronosTokenizer.from_pretrained(
|
||
"NeoQuasar/Kronos-Tokenizer-base"
|
||
)
|
||
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
|
||
|
||
# 3. Load Trained News Projector Weights
|
||
# Check predictor/exports/models directory
|
||
models_dir = os.path.join(KRONOS_DIR, "exports/models")
|
||
model_files = glob.glob(os.path.join(models_dir, "*.pt"))
|
||
|
||
if model_files:
|
||
latest_model = max(model_files, key=os.path.getctime)
|
||
logger.info(f"🔄 Loading trained news weights from {latest_model}...")
|
||
try:
|
||
checkpoint = torch.load(latest_model, map_location=device)
|
||
# The checkpoint contains 'news_proj_state_dict'
|
||
if "news_proj_state_dict" in checkpoint:
|
||
if not hasattr(model, "news_proj") or model.news_proj is None:
|
||
import torch.nn as nn
|
||
|
||
news_dim = checkpoint.get("news_dim", 384)
|
||
model.news_proj = nn.Linear(news_dim, model.d_model).to(
|
||
device
|
||
)
|
||
|
||
model.news_proj.load_state_dict(
|
||
checkpoint["news_proj_state_dict"]
|
||
)
|
||
logger.success("✅ News-Aware Projection Layer loaded!")
|
||
self.has_news_model = True
|
||
else:
|
||
logger.warning(
|
||
"⚠️ Checkpoint found but missing 'news_proj_state_dict'. Using base model."
|
||
)
|
||
self.has_news_model = False
|
||
except Exception as e:
|
||
logger.error(
|
||
f"❌ Failed to load trained weights: {e}. Using base model."
|
||
)
|
||
self.has_news_model = False
|
||
else:
|
||
logger.info("ℹ️ No trained news models found. Using base model.")
|
||
self.has_news_model = False
|
||
|
||
tokenizer = tokenizer.to(device)
|
||
model = model.to(device)
|
||
|
||
self._predictor = KronosPredictor(
|
||
model, tokenizer, device=device, max_context=512
|
||
)
|
||
logger.info("✅ Kronos Model loaded successfully.")
|
||
except Exception as e:
|
||
logger.error(f"❌ Failed to load Kronos Model: {e}")
|
||
self._predictor = None
|
||
self.has_news_model = False
|
||
|
||
def get_base_forecast(
|
||
self,
|
||
df: pd.DataFrame,
|
||
lookback: int = 20,
|
||
pred_len: int = 5,
|
||
news_text: Optional[str] = None,
|
||
) -> List[KLinePoint]:
|
||
"""
|
||
生成原始模型预测
|
||
"""
|
||
if self._predictor is None:
|
||
logger.error("Predictor not initialized.")
|
||
return []
|
||
|
||
if len(df) < lookback:
|
||
logger.warning(
|
||
f"Insufficient historical data ({len(df)}) for lookback ({lookback})."
|
||
)
|
||
return []
|
||
|
||
# 获取最后 lookback 条数据
|
||
x_df = df.iloc[-lookback:].copy()
|
||
x_timestamp = pd.to_datetime(x_df["date"]) # Ensure datetime
|
||
last_date = x_timestamp.iloc[-1]
|
||
|
||
# 生成未来时间戳
|
||
future_dates = pd.date_range(
|
||
start=last_date + BusinessDay(1), periods=pred_len, freq="B"
|
||
)
|
||
y_timestamp = pd.Series(future_dates)
|
||
|
||
# Embedding News if available
|
||
news_emb = None
|
||
if (
|
||
news_text
|
||
and getattr(self, "has_news_model", False)
|
||
and hasattr(self, "embedder")
|
||
):
|
||
try:
|
||
# Truncate to avoid too long text
|
||
emb = self.embedder.encode(news_text[:1000])
|
||
news_emb = emb # KronosPredictor expects numpy array or tensor
|
||
except Exception as e:
|
||
logger.error(f"Failed to encode news: {e}")
|
||
|
||
try:
|
||
# 预测所需的列
|
||
cols = ["open", "high", "low", "close", "volume"]
|
||
pred_df = self._predictor.predict(
|
||
df=x_df[cols],
|
||
x_timestamp=x_timestamp,
|
||
y_timestamp=y_timestamp,
|
||
pred_len=pred_len,
|
||
T=1.0,
|
||
top_p=0.9,
|
||
sample_count=1,
|
||
verbose=False,
|
||
news_emb=news_emb,
|
||
)
|
||
|
||
# 转换为 KLinePoint
|
||
results = []
|
||
for date, row in pred_df.iterrows():
|
||
results.append(
|
||
KLinePoint(
|
||
date=date.strftime("%Y-%m-%d"),
|
||
open=float(row["open"]),
|
||
high=float(row["high"]),
|
||
low=float(row["low"]),
|
||
close=float(row["close"]),
|
||
volume=float(row["volume"]),
|
||
)
|
||
)
|
||
return results
|
||
except Exception as e:
|
||
logger.error(f"Forecast generation failed: {e}")
|
||
return []
|
||
|
||
|
||
# Singleton instance for easy access
|
||
# Usage: predictor = KronosPredictorUtility()
|