Files
opencode-skill/skills/alphaear-predictor/scripts/kronos_predictor.py
Kunthawat Greethong e4d41e3ae5 Move .env into skills/ for easy install
- Added skills/_env_loader.py - shared env loader for all scripts
- Updated 17 scripts to use load_unified_env()
- Updated install-skills.sh to copy .env into skills/
- Updated README with simpler OpenClaw install instructions
- .env in skills/ is gitignored (credentials stay private)
2026-03-27 17:49:20 +07:00

219 lines
7.6 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from typing import List, Optional
from loguru import logger
from pandas.tseries.offsets import BusinessDay
import os
import sys
# Directory of the vendored Kronos "predictor" package that sits next to this
# file. It is appended to sys.path (only if not already present) so that
# modules inside it can be imported by their top-level names.
KRONOS_DIR = os.path.join(os.path.dirname(__file__), "predictor")
if KRONOS_DIR not in sys.path:
    sys.path.append(KRONOS_DIR)
from skills._env_loader import load_unified_env

# Load environment variables through the shared skills env loader at import
# time (per the commit notes, the shared .env lives in skills/), so values such
# as EMBEDDING_MODEL are already set when KronosPredictorUtility is constructed.
load_unified_env()
import glob
from sentence_transformers import SentenceTransformer
from .predictor.model import Kronos, KronosTokenizer, KronosPredictor
from .schema.models import KLinePoint
class KronosPredictorUtility:
    """Kronos time-series forecasting utility.

    Handles model loading (Kronos tokenizer + base model, an optional
    sentence-embedding model for news text, and optional trained
    news-projection weights), inference, and conversion of the model output
    into ``KLinePoint`` objects.

    The class is a singleton: every ``KronosPredictorUtility()`` call returns
    the same shared instance and the models are loaded only once. If loading
    fails, ``_predictor`` stays ``None`` and the next construction retries.
    """

    _instance = None
    _predictor = None

    def __new__(cls, *args, **kwargs):
        # Always hand back the one shared instance. __init__ still runs on
        # every construction but returns early once the predictor is loaded.
        if cls._instance is None:
            cls._instance = super(KronosPredictorUtility, cls).__new__(cls)
        return cls._instance

    def __init__(self, device: Optional[str] = None):
        """Load all models into the shared instance (no-op once loaded).

        Args:
            device: Torch device string such as ``"cuda"``, ``"mps"`` or
                ``"cpu"``. When falsy, the first available of cuda, mps and
                cpu is used. Ignored once the predictor has been loaded,
                because the shared instance keeps its first successful setup.

        Load errors are logged rather than raised. On failure ``_predictor``
        stays ``None`` and ``has_news_model`` is ``False``.
        """
        if self._predictor is not None:
            return
        # Make sure these attributes exist even if loading fails part-way.
        self.embedder = None
        self.has_news_model = False
        try:
            device = device or self._resolve_device()
            logger.info(f"🔮 Loading Kronos Model on {device}...")
            # 1. Sentence embedder for news text (None if it cannot be loaded).
            self.embedder = self._load_embedder(device)
            # 2. Kronos tokenizer and base model.
            tokenizer, model = self._load_kronos_base()
            # 3. Optional trained news projection layer, attached to `model`.
            self.has_news_model = self._load_news_projector(model, device)
            tokenizer = tokenizer.to(device)
            model = model.to(device)
            self._predictor = KronosPredictor(
                model, tokenizer, device=device, max_context=512
            )
            logger.info("✅ Kronos Model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Failed to load Kronos Model: {e}")
            self._predictor = None
            self.has_news_model = False

    @staticmethod
    def _resolve_device() -> str:
        """Return the first available torch device: cuda, then mps, then cpu."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _load_embedder(device: str):
        """Load the SentenceTransformer used to embed news text.

        Tries the local cache first, then falls back to a download. Returns
        ``None`` if both attempts fail. The embedder is only needed for
        news-conditioned forecasts, so its absence must not prevent the base
        Kronos model from loading.
        """
        # Must match the embedding model used when training the news projector.
        model_name = os.getenv(
            "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
        )
        try:
            return SentenceTransformer(
                model_name, device=device, local_files_only=True
            )
        except Exception:
            logger.warning(
                f"⚠️ Local embedder {model_name} not found. Downloading..."
            )
        try:
            return SentenceTransformer(model_name, device=device)
        except Exception as e:
            logger.error(
                f"❌ Failed to load embedder {model_name}: {e}. "
                "News conditioning disabled."
            )
            return None

    @staticmethod
    def _load_kronos_base():
        """Return ``(tokenizer, model)`` for Kronos-base, preferring the local cache."""
        try:
            tokenizer = KronosTokenizer.from_pretrained(
                "NeoQuasar/Kronos-Tokenizer-base", local_files_only=True
            )
            model = Kronos.from_pretrained(
                "NeoQuasar/Kronos-base", local_files_only=True
            )
        except Exception:
            logger.warning(
                "⚠️ Local Kronos cache not found. Attempting to download..."
            )
            tokenizer = KronosTokenizer.from_pretrained(
                "NeoQuasar/Kronos-Tokenizer-base"
            )
            model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
        return tokenizer, model

    @staticmethod
    def _load_news_projector(model, device: str) -> bool:
        """Load the newest trained news-projection checkpoint into ``model``.

        Looks for ``*.pt`` files in ``<KRONOS_DIR>/exports/models`` and uses
        the most recently modified one. If ``model`` has no ``news_proj``
        layer yet, a ``torch.nn.Linear(news_dim, model.d_model)`` is created
        for it.

        Returns:
            True only if the checkpoint's ``news_proj_state_dict`` was loaded
            into ``model.news_proj``. Otherwise False (with a log message).
        """
        models_dir = os.path.join(KRONOS_DIR, "exports", "models")
        model_files = glob.glob(os.path.join(models_dir, "*.pt"))
        if not model_files:
            logger.info(" No trained news models found. Using base model.")
            return False
        # Use modification time. On POSIX, getctime is the inode-change time
        # (not creation time), so it does not reliably identify the newest file.
        latest_model = max(model_files, key=os.path.getmtime)
        logger.info(f"🔄 Loading trained news weights from {latest_model}...")
        try:
            checkpoint = torch.load(latest_model, map_location=device)
            if "news_proj_state_dict" not in checkpoint:
                logger.warning(
                    "⚠️ Checkpoint found but missing 'news_proj_state_dict'. Using base model."
                )
                return False
            if getattr(model, "news_proj", None) is None:
                import torch.nn as nn

                # 384 is presumably the output width of the default embedder;
                # checkpoints may override it via "news_dim".
                news_dim = checkpoint.get("news_dim", 384)
                model.news_proj = nn.Linear(news_dim, model.d_model).to(device)
            model.news_proj.load_state_dict(checkpoint["news_proj_state_dict"])
        except Exception as e:
            logger.error(f"❌ Failed to load trained weights: {e}. Using base model.")
            return False
        logger.success("✅ News-Aware Projection Layer loaded!")
        return True

    def _encode_news(self, news_text: Optional[str]):
        """Embed ``news_text`` for the news-aware model.

        Returns ``None`` when there is no text, no trained news weights, no
        embedder, or encoding fails (the failure is logged).
        """
        embedder = getattr(self, "embedder", None)
        if (
            not news_text
            or not getattr(self, "has_news_model", False)
            or embedder is None
        ):
            return None
        try:
            # Truncate to avoid overly long inputs to the embedder.
            return embedder.encode(news_text[:1000])
        except Exception as e:
            logger.error(f"Failed to encode news: {e}")
            return None

    def get_base_forecast(
        self,
        df: pd.DataFrame,
        lookback: int = 20,
        pred_len: int = 5,
        news_text: Optional[str] = None,
        *,
        temperature: float = 1.0,
        top_p: float = 0.9,
        sample_count: int = 1,
    ) -> List[KLinePoint]:
        """Generate a raw Kronos forecast for the business days after ``df``.

        Args:
            df: Historical bars. Must contain a ``date`` column plus ``open``,
                ``high``, ``low``, ``close`` and ``volume``.
            lookback: Number of most recent rows used as context (>= 1).
            pred_len: Number of future business days to forecast (>= 1).
            news_text: Optional news used to condition the forecast. It is used
                only when trained news weights and the embedder are loaded,
                and is truncated to its first 1000 characters.
            temperature: Sampling temperature, passed to the predictor as ``T``.
            top_p: Nucleus-sampling threshold, passed to the predictor.
            sample_count: Number of samples, passed to the predictor.

        Returns:
            One ``KLinePoint`` per forecast business day. Returns an empty list
            if the predictor is not loaded, the arguments or data are invalid,
            or inference fails. Problems are logged, not raised.
        """
        if self._predictor is None:
            logger.error("Predictor not initialized.")
            return []
        if lookback < 1 or pred_len < 1:
            # Reject these explicitly: df.iloc[-0:] would silently select the
            # whole frame instead of an empty window.
            logger.warning(
                f"Invalid lookback ({lookback}) or pred_len ({pred_len}); both must be >= 1."
            )
            return []
        if len(df) < lookback:
            logger.warning(
                f"Insufficient historical data ({len(df)}) for lookback ({lookback})."
            )
            return []
        try:
            # Take the last `lookback` rows as model context.
            x_df = df.iloc[-lookback:].copy()
            x_timestamp = pd.to_datetime(x_df["date"])  # ensure datetime dtype
            # Future timestamps: the next `pred_len` business days after the last bar.
            future_dates = pd.date_range(
                start=x_timestamp.iloc[-1] + BusinessDay(1),
                periods=pred_len,
                freq="B",
            )
            y_timestamp = pd.Series(future_dates)
            news_emb = self._encode_news(news_text)
            # Columns the predictor consumes.
            cols = ["open", "high", "low", "close", "volume"]
            pred_df = self._predictor.predict(
                df=x_df[cols],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=temperature,
                top_p=top_p,
                sample_count=sample_count,
                verbose=False,
                news_emb=news_emb,
            )
            # Convert each predicted row (indexed by timestamp) to a KLinePoint.
            return [
                KLinePoint(
                    date=date.strftime("%Y-%m-%d"),
                    open=float(row["open"]),
                    high=float(row["high"]),
                    low=float(row["low"]),
                    close=float(row["close"]),
                    volume=float(row["volume"]),
                )
                for date, row in pred_df.iterrows()
            ]
        except Exception as e:
            logger.error(f"Forecast generation failed: {e}")
            return []
# KronosPredictorUtility is a singleton (see __new__): every call returns the same shared instance.
# Usage: predictor = KronosPredictorUtility()