Vortex_QF / app.py
GoshawkVortexAI's picture
Update app.py
e973ba4 verified
İşte tek dosya `app.py`. Gradio (blank) arayüzü, OKX REST'ten BTC/USDT (spot) candle verisi çekme, önişleme, birkaç basit modelden (LightGBM, XGBoost, küçük PyTorch LSTM ve basit RandomForest) oluşan ensemble ile inference yapacak şekilde hazırlanmıştır. Eksik modeller varsa demo (dummy) modeller üretecek; gerçek eğitim için ek adımlar gerekir. Dosya, Spaces/Gradio üzerinde çalışacak şekilde tasarlandı.
python
# app.py
"""
Gradio (blank) tabanlı Hugging Face Space uygulaması.
- OKX REST API'den BTC/USDT (spot) candle verisi çeker
- Teknik göstergeler üretir
- Ensemble: LightGBM, XGBoost, RandomForest (sklearn) + küçük PyTorch LSTM
- Eğer pretrained model dosyaları yoksa küçük demo modeller oluşturur
- Outputs: tahmin (regresyon: next-close), model katkıları, grafikler
Not:
- requirements.txt'de aşağıdakiler olmalı:
gradio, pandas, numpy, requests, ta, scikit-learn, lightgbm, xgboost, torch, matplotlib
- Kullanıcı OKX API anahtarı gerekli değildir (public candles endpoint kullanılıyor).
- Bu dosya tek başına çalışır; ancak ağır paketler (lightgbm, xgboost, torch) Spaces ortamında kurulmadıysa hata verebilir.
"""
import os
import io
import time
import math
import json
import threading
from typing import Tuple, Dict, Any, List
import numpy as np
import pandas as pd
import requests
from datetime import datetime, timedelta, timezone
# Visualization
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Technical indicators
try:
import ta
except Exception:
# Minimal fallback implementations if ta isn't installed
ta = None
# ML libs
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, RegressorMixin
# Try import optional libs
HAS_LGB = True
HAS_XGB = True
HAS_TORCH = True
try:
import lightgbm as lgb
except Exception:
HAS_LGB = False
try:
import xgboost as xgb
except Exception:
HAS_XGB = False
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
except Exception:
HAS_TORCH = False
# Gradio
import gradio as gr
# -------------------------
# Configuration/Constants
# -------------------------
# Base URL of the OKX REST API; endpoint paths are appended to it.
OKX_BASE = "https://www.okx.com"
# Public candles: GET /api/v5/market/history-candles?instId=BTC-USDT-SWAP&bar=1m&limit=100
# The defaults below target the spot instrument BTC-USDT rather than the swap.
DEFAULT_INSTRUMENT = "BTC-USDT"
DEFAULT_BAR = "1m" # options: 1m, 3m, 5m, 15m, 1H etc.
# NOTE(review): the per-request maximum depends on the endpoint; verify that the
# history-candles endpoint accepts a limit this large.
DEFAULT_LIMIT = 500 # up to 1000 depending on endpoint
# Model filenames (in repo or persisted by training)
MODEL_DIR = "models"
# Import-time side effect: the model directory is created if it does not exist.
os.makedirs(MODEL_DIR, exist_ok=True)
LGB_MODEL_FILE = os.path.join(MODEL_DIR, "lgb_model.txt")
XGB_MODEL_FILE = os.path.join(MODEL_DIR, "xgb_model.json")
RF_MODEL_FILE = os.path.join(MODEL_DIR, "rf_model.pkl")
LSTM_MODEL_FILE = os.path.join(MODEL_DIR, "lstm_model.pt")
SCALER_FILE = os.path.join(MODEL_DIR, "scaler.npy") # save scaler mean/scale
# Process-wide model cache; _MODEL_LOCK guards reads and writes of _MODELS.
_MODEL_LOCK = threading.Lock()
_MODELS = {}
# -------------------------
# Utilities
# -------------------------
def now_iso():
    """Return the current UTC time formatted as an ISO 8601 string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()
def okx_candles(inst_id: str = DEFAULT_INSTRUMENT, bar: str = DEFAULT_BAR, limit: int = DEFAULT_LIMIT) -> pd.DataFrame:
    """Fetch recent candles from the OKX public REST API.

    Args:
        inst_id: Instrument identifier sent as ``instId`` (e.g. "BTC-USDT").
        bar: Candle width sent as ``bar`` (e.g. "1m", "1H").
        limit: Number of candles requested, sent as ``limit``.

    Returns:
        DataFrame sorted by ascending time with columns
        ``ts`` (timezone-aware UTC datetime), ``open``, ``high``, ``low``,
        ``close`` and ``volume`` (floats).

    Raises:
        requests.HTTPError: If the HTTP status indicates failure.
        RuntimeError: If the payload is malformed, OKX reports a non-zero
            ``code``, or no candles are returned.
    """
    url = f"{OKX_BASE}/api/v5/market/history-candles"
    params = {"instId": inst_id, "bar": bar, "limit": str(limit)}
    resp = requests.get(url, params=params, timeout=15)
    resp.raise_for_status()
    payload = resp.json()
    if not isinstance(payload, dict):
        raise RuntimeError("Unexpected response structure from OKX")

    # Surface the API's own error instead of a generic "no data" message.
    code = payload.get("code")
    if code not in (None, "0", 0):
        raise RuntimeError(f"OKX API error {code}: {payload.get('msg', '')}")

    candles = payload.get("data") or []
    if not candles:
        raise RuntimeError("No candle data returned from OKX")

    rows = []
    for c in candles:
        # Each candle is a list: [ts, open, high, low, close, volume, ...].
        raw_ts = int(c[0])
        # Values with more than 10 digits are treated as milliseconds.
        ts = raw_ts // 1000 if raw_ts >= 10**10 else raw_ts
        rows.append({
            "ts": datetime.fromtimestamp(ts, tz=timezone.utc),
            "open": float(c[1]),
            "high": float(c[2]),
            "low": float(c[3]),
            "close": float(c[4]),
            "volume": float(c[5]),
        })
    df = pd.DataFrame(rows)
    return df.sort_values("ts").reset_index(drop=True)
# Minimal TA indicators if `ta` package is not available
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with technical-indicator columns added.

    Adds ``rsi``, ``ema12``, ``ema26``, ``macd``, ``macd_signal``, ``bb_high``,
    ``bb_low`` and ``atr``. Uses the ``ta`` package when it was imported,
    otherwise falls back to pandas-based computations (which additionally
    leave the helper columns ``bb_mid`` and ``bb_std`` in the frame).

    Args:
        df: Candle data with at least ``high``, ``low`` and ``close`` columns.

    Returns:
        New DataFrame; remaining gaps are back-filled, then forward-filled,
        then set to 0.0.
    """
    df = df.copy()
    if ta is not None:
        close = df["close"]
        df["rsi"] = ta.momentum.RSIIndicator(close, window=14, fillna=True).rsi()
        df["ema12"] = ta.trend.EMAIndicator(close, window=12, fillna=True).ema_indicator()
        df["ema26"] = ta.trend.EMAIndicator(close, window=26, fillna=True).ema_indicator()
        macd = ta.trend.MACD(close, window_slow=26, window_fast=12, window_sign=9, fillna=True)
        df["macd"] = macd.macd()
        df["macd_signal"] = macd.macd_signal()
        # Build the Bollinger indicator once and read both bands from it.
        bands = ta.volatility.BollingerBands(close, window=20, fillna=True)
        df["bb_high"] = bands.bollinger_hband()
        df["bb_low"] = bands.bollinger_lband()
        df["atr"] = ta.volatility.AverageTrueRange(
            df["high"], df["low"], close, window=14, fillna=True
        ).average_true_range()
    else:
        # Fallback computations used when `ta` is not installed.
        df["rsi"] = simple_rsi(df["close"], window=14)
        df["ema12"] = df["close"].ewm(span=12, adjust=False).mean()
        df["ema26"] = df["close"].ewm(span=26, adjust=False).mean()
        df["macd"] = df["ema12"] - df["ema26"]
        df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
        df["bb_mid"] = df["close"].rolling(20).mean()
        df["bb_std"] = df["close"].rolling(20).std()
        df["bb_high"] = df["bb_mid"] + 2 * df["bb_std"]
        df["bb_low"] = df["bb_mid"] - 2 * df["bb_std"]
        df["atr"] = simple_atr(df, window=14)
    # .bfill()/.ffill() replace fillna(method=...), which newer pandas deprecates.
    return df.bfill().ffill().fillna(0.0)
def simple_rsi(series: pd.Series, window: int = 14) -> pd.Series:
    """Compute an RSI-style oscillator from exponentially smoothed gains/losses.

    Undefined entries (such as the first element, which has no previous
    value to diff against) are replaced with the neutral value 50.0.
    """
    change = series.diff()
    gains = change.clip(lower=0)
    losses = change.clip(upper=0) * -1
    smoothing = {"alpha": 1 / window, "adjust": False}
    avg_gain = gains.ewm(**smoothing).mean()
    avg_loss = losses.ewm(**smoothing).mean()
    # The small epsilon keeps the ratio finite when there are no losses.
    strength = avg_gain / (avg_loss + 1e-8)
    oscillator = 100 - (100 / (1 + strength))
    return oscillator.fillna(50.0)
def simple_atr(df: pd.DataFrame, window: int = 14) -> pd.Series:
    """Compute an exponentially smoothed true range from high/low/close.

    The true range of each row is the largest of: high - low,
    |high - previous close| and |low - previous close|.
    """
    prev_close = df["close"].shift()
    candidates = pd.concat(
        [
            df["high"] - df["low"],
            (df["high"] - prev_close).abs(),
            (df["low"] - prev_close).abs(),
        ],
        axis=1,
    )
    true_range = candidates.max(axis=1)
    smoothed = true_range.ewm(span=window, adjust=False).mean()
    return smoothed.fillna(0.0)
def create_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` enriched with indicator and derived features.

    On top of the columns added by ``add_technical_indicators`` this adds
    returns, rolling volatility, moving averages, ratio features and
    time-based features (``ts_unix``, ``hour``, ``minute``).

    Args:
        df: Candle data with ``ts`` (datetime), ``open``, ``high``, ``low``,
            ``close`` and ``volume`` columns.

    Returns:
        New DataFrame with remaining gaps back-filled and then set to 0.0.
    """
    df = df.copy()
    df = add_technical_indicators(df)
    close = df["close"]
    # Return and volatility features.
    df["return_1"] = close.pct_change().fillna(0.0)
    df["log_return_1"] = np.log1p(df["return_1"])
    df["vol_5"] = close.rolling(5).std().fillna(0.0)
    df["vol_20"] = close.rolling(20).std().fillna(0.0)
    # Moving averages; the warm-up rows take the first available value.
    df["ma_5"] = close.rolling(5).mean().bfill()
    df["ma_20"] = close.rolling(20).mean().bfill()
    df["ma_50"] = close.rolling(50).mean().bfill()
    # Ratio features; the epsilon avoids division by zero.
    df["ma5_div_ma20"] = df["ma_5"] / (df["ma_20"] + 1e-9)
    df["ema_diff"] = df["ema12"] - df["ema26"]
    # Time features. Epoch seconds are derived by timedelta division so the
    # result does not depend on the datetime column's storage resolution.
    epoch = pd.Timestamp("1970-01-01", tz="UTC")
    ts_utc = pd.to_datetime(df["ts"], utc=True)
    df["ts_unix"] = (ts_utc - epoch) // pd.Timedelta(seconds=1)
    df["hour"] = df["ts"].dt.hour
    df["minute"] = df["ts"].dt.minute
    # .bfill() replaces fillna(method="bfill"), which newer pandas deprecates.
    return df.bfill().fillna(0.0)
# -------------------------
# Model wrappers and helpers
# -------------------------
class DummyRegressor(BaseEstimator, RegressorMixin):
    """Fallback regressor that always predicts the mean of its training targets."""

    def fit(self, X, y):
        """Store the mean of ``y`` (0.0 when ``y`` is empty) and return ``self``."""
        if len(y):
            self._mean = np.mean(y)
        else:
            self._mean = 0.0
        return self

    def predict(self, X):
        """Return the stored mean once per row of ``X`` (0.0 if never fitted)."""
        constant = getattr(self, "_mean", 0.0)
        n_rows = X.shape[0]
        return np.full((n_rows,), constant)
def save_numpy(obj: np.ndarray, path: str):
    """Write ``obj`` to ``path`` using ``np.save``."""
    np.save(file=path, arr=obj)
def load_numpy(path: str) -> np.ndarray:
    """Read and return the array stored at ``path`` using ``np.load``."""
    loaded = np.load(file=path)
    return loaded
def get_feature_columns() -> List[str]:
    """Return the ordered list of column names fed to the models."""
    raw_candle = ["open", "high", "low", "close", "volume"]
    indicators = [
        "rsi", "ema12", "ema26", "macd", "macd_signal",
        "bb_high", "bb_low", "atr",
    ]
    derived = [
        "return_1", "log_return_1", "vol_5", "vol_20",
        "ma_5", "ma_20", "ma_50",
        "ma5_div_ma20", "ema_diff",
    ]
    calendar = ["ts_unix", "hour", "minute"]
    # Order matters: it fixes the column layout of the feature matrix.
    return raw_candle + indicators + derived + calendar
# Model persistence helpers (light, simple)
def load_models() -> Dict[str, Any]:
    """Load (or lazily create) the ensemble members and the feature scaler.

    The result is cached in the module-level ``_MODELS`` dict; later calls
    return the cached dict until it is cleared.

    Returns:
        Dict with keys ``rf``, ``lgb``, ``xgb``, ``lstm`` and ``scaler``.
        ``rf`` falls back to a fresh, unfitted ``RandomForestRegressor`` when
        no saved model can be loaded; ``lgb``/``xgb``/``lstm`` are ``None``
        when the library or file is missing or loading fails; ``scaler``
        falls back to an unfitted ``StandardScaler``.
    """
    import logging

    log = logging.getLogger(__name__)

    with _MODEL_LOCK:
        if _MODELS:
            return _MODELS
        models: Dict[str, Any] = {}

        # Scaler: rebuilt from the saved mean/scale dictionary.
        scaler = None
        if os.path.exists(SCALER_FILE):
            try:
                # NOTE(security): allow_pickle=True unpickles the file contents;
                # only load scaler files from a trusted source.
                sc = np.load(SCALER_FILE, allow_pickle=True).item()
                scaler = StandardScaler()
                scaler.mean_ = sc["mean"]
                scaler.scale_ = sc["scale"]
                scaler.n_features_in_ = sc["n_in"]
            except Exception:
                log.warning("Could not load scaler from %s", SCALER_FILE, exc_info=True)
                scaler = None

        # RandomForest (sklearn), persisted with joblib.
        rf = None
        if os.path.exists(RF_MODEL_FILE):
            try:
                import joblib
                # NOTE(security): joblib.load unpickles; trusted files only.
                rf = joblib.load(RF_MODEL_FILE)
            except Exception:
                log.warning("Could not load RandomForest from %s", RF_MODEL_FILE, exc_info=True)
        if rf is None:
            # Small demo model; it is unfitted until trained elsewhere.
            rf = RandomForestRegressor(n_estimators=10, random_state=42)
        models["rf"] = rf

        # LightGBM
        models["lgb"] = None
        if HAS_LGB and os.path.exists(LGB_MODEL_FILE):
            try:
                models["lgb"] = lgb.Booster(model_file=LGB_MODEL_FILE)
            except Exception:
                log.warning("Could not load LightGBM model from %s", LGB_MODEL_FILE, exc_info=True)

        # XGBoost
        models["xgb"] = None
        if HAS_XGB and os.path.exists(XGB_MODEL_FILE):
            try:
                booster = xgb.Booster()
                booster.load_model(XGB_MODEL_FILE)
                models["xgb"] = booster
            except Exception:
                log.warning("Could not load XGBoost model from %s", XGB_MODEL_FILE, exc_info=True)

        # LSTM / PyTorch
        models["lstm"] = None
        if HAS_TORCH and os.path.exists(LSTM_MODEL_FILE):
            try:
                # NOTE(review): torch.load unpickles the file, and recent
                # PyTorch releases changed its defaults for fully pickled
                # modules -- confirm how lstm_model.pt is saved.
                models["lstm"] = torch.load(LSTM_MODEL_FILE, map_location=torch.device("cpu"))
            except Exception:
                log.warning("Could not load LSTM model from %s", LSTM_MODEL_FILE, exc_info=True)

        # Without a saved scaler, hand out an unfitted one; predict_ensemble
        # fits it on first use.
        if scaler is None:
            scaler = StandardScaler()
        models["scaler"] = scaler

        _MODELS.update(models)
        return _MODELS
def save_scaler(scaler: StandardScaler, path: str = SCALER_FILE):
    """Persist the scaler's mean, scale and input width to ``path`` via ``np.save``."""
    payload = dict(
        mean=scaler.mean_,
        scale=scaler.scale_,
        n_in=scaler.n_features_in_,
    )
    np.save(path, payload)
# -------------------------
# Inference logic
# -------------------------
def prepare_inference_features(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], pd.DataFrame]:
    """Build the model input matrix from raw candles.

    Returns a tuple ``(X, feature_cols, df_ready)``: the 2D feature array,
    the ordered feature names, and the feature-enriched DataFrame. Any
    expected feature column that is absent is added and filled with 0.0.
    """
    frame = create_features(df)
    columns = get_feature_columns()
    absent = [name for name in columns if name not in frame.columns]
    for name in absent:
        frame[name] = 0.0
    matrix = frame[columns].values
    return matrix, columns, frame
def _scalar_or_nan(predict_fn) -> float:
    """Call ``predict_fn`` and coerce its result to float; NaN on any failure."""
    try:
        return float(predict_fn())
    except Exception:
        return float("nan")


def predict_ensemble(X: np.ndarray, models: Dict[str, Any]) -> Dict[str, Any]:
    """Predict the next-step close from the last feature row with every model.

    Args:
        X: Feature matrix (rows are time steps) or a single 1-D feature row.
        models: Dict as returned by ``load_models`` (keys ``rf``, ``lgb``,
            ``xgb``, ``lstm``, ``scaler``); missing entries are tolerated.

    Returns:
        Dict with:
        - ``per_model``: ``{name: float}``, NaN for models that are missing
          or failed.
        - ``ensemble_mean``: mean of the non-NaN predictions, or the last
          row's ``close`` feature when no model produced a value.
        - ``weighted``: currently identical to ``ensemble_mean``.
    """
    scaler = models.get("scaler")
    if scaler is None:
        scaler = StandardScaler()

    # Only the most recent row is used for the prediction.
    X_row = X.reshape(1, -1) if X.ndim == 1 else X[-1:, :]

    try:
        Xs = scaler.transform(X_row)
    except Exception:
        # Scaler not fitted (or incompatible): fit on the supplied data and
        # persist it; if that also fails, use the unscaled row.
        try:
            scaler.fit(X)
            save_scaler(scaler)
            Xs = scaler.transform(X_row)
        except Exception:
            Xs = X_row

    nan = float("nan")
    preds: Dict[str, float] = {}

    # RandomForest
    rf = models.get("rf")
    preds["rf"] = _scalar_or_nan(lambda: rf.predict(Xs)[0]) if rf is not None else nan

    # LightGBM
    lgb_model = models.get("lgb") if HAS_LGB else None
    preds["lgb"] = (
        _scalar_or_nan(lambda: lgb_model.predict(Xs)[0]) if lgb_model is not None else nan
    )

    # XGBoost
    xgb_model = models.get("xgb") if HAS_XGB else None
    preds["xgb"] = (
        _scalar_or_nan(lambda: xgb_model.predict(xgb.DMatrix(Xs))[0])
        if xgb_model is not None
        else nan
    )

    # LSTM (PyTorch)
    lstm_model = models.get("lstm") if HAS_TORCH else None

    def _lstm_forward():
        lstm_model.eval()
        with torch.no_grad():
            # (1, features) -> (1, 1, features): batch of one, sequence of one.
            # NOTE(review): the LSTM receives the unscaled row while the other
            # models receive the scaled one -- confirm this matches training.
            seq = torch.tensor(X_row, dtype=torch.float32).unsqueeze(0)
            out = lstm_model(seq)
        return out.squeeze().cpu().numpy()

    preds["lstm"] = _scalar_or_nan(_lstm_forward) if lstm_model is not None else nan

    valid_preds = [v for v in preds.values() if not math.isnan(v)]
    if valid_preds:
        ensemble_mean = float(np.mean(valid_preds))
    else:
        # Naive fallback: next close equals the last close.
        ensemble_mean = float(X_row[0, get_feature_columns().index("close")])

    return {
        "per_model": preds,
        "ensemble_mean": ensemble_mean,
        # Equal weighting for now, so this equals the plain mean.
        "weighted": ensemble_mean,
    }
# -------------------------
# LSTM simple architecture (for demo)
# -------------------------
if HAS_TORCH:
    class SimpleLSTM(nn.Module):
        """Small LSTM regressor: LSTM layer(s) followed by a linear head."""

        def __init__(self, input_size: int, hidden_size: int = 32, num_layers: int = 1):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
            )
            self.fc = nn.Linear(in_features=hidden_size, out_features=1)

        def forward(self, x):
            """Map ``x`` of shape (batch, seq_len, input_size) to (batch, 1)."""
            sequence_out, _state = self.lstm(x)
            # Only the final time step feeds the output layer.
            final_step = sequence_out[:, -1, :]
            return self.fc(final_step)
# -------------------------
# Visualization helpers
# -------------------------
def plot_price_and_preds(df: pd.DataFrame, preds: Dict[str, Any]) -> bytes:
    """Render the close series plus the ensemble prediction as PNG bytes.

    The prediction is taken from ``preds["weighted"]``, then
    ``preds["ensemble_mean"]``, then the last close, and is drawn one second
    after the last candle's timestamp.
    """
    times = df["ts"]
    closes = df["close"]
    final_ts = times.iloc[-1]
    final_close = closes.iloc[-1]
    if "weighted" in preds:
        forecast = preds["weighted"]
    else:
        forecast = preds.get("ensemble_mean", final_close)

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(times, closes, label="close", color="black", lw=1)
    marker_time = final_ts + pd.Timedelta(seconds=1)
    ax.scatter([marker_time], [forecast], color="red", label="ensemble_pred")
    # Dashed reference line at the last observed close.
    ax.axhline(final_close, linestyle="--", color="gray", alpha=0.6)
    ax.set_title("BTC/USDT close and ensemble prediction")
    ax.set_xlabel("Time (UTC)")
    ax.set_ylabel("Price")
    ax.legend()
    fig.tight_layout()

    png = io.BytesIO()
    fig.savefig(png, format="png")
    plt.close(fig)
    png.seek(0)
    return png.read()
def plot_model_contributions(per_model: Dict[str, float]) -> bytes:
    """Render a bar chart of each model's prediction as PNG bytes.

    NaN predictions are drawn as bars of height 0.0.
    """
    labels = list(per_model.keys())
    heights = []
    for label in labels:
        value = per_model[label]
        heights.append(0.0 if math.isnan(value) else value)

    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.bar(labels, heights, color=palette)
    ax.set_title("Per-model predictions (abs values)")
    ax.set_ylabel("Predicted price")
    fig.tight_layout()

    png = io.BytesIO()
    fig.savefig(png, format="png")
    plt.close(fig)
    png.seek(0)
    return png.read()
# -------------------------
# Gradio app components
# -------------------------
def inference_pipeline(inst_id: str = DEFAULT_INSTRUMENT,
                       bar: str = DEFAULT_BAR,
                       limit: int = DEFAULT_LIMIT,
                       show_plot: bool = True):
    """Fetch candles, build features, run the ensemble and package the result.

    Args:
        inst_id: Instrument identifier passed to ``okx_candles``.
        bar: Candle width passed to ``okx_candles``.
        limit: Number of candles to request (coerced with ``int``).
        show_plot: When False, plots are not rendered and both image entries
            in the result are ``None``.

    Returns:
        ``{"error": message}`` when fetching candles fails; otherwise a dict
        with ``result`` (JSON-friendly summary), ``img_price`` and
        ``img_contrib`` (PNG bytes, or ``None`` when ``show_plot`` is False).
    """
    # Step 1: fetch candles
    try:
        df = okx_candles(inst_id=inst_id, bar=bar, limit=int(limit))
    except Exception as e:
        return {"error": f"Failed to fetch candles: {e}"}

    # Steps 2-4: features, models, prediction
    X, feat_cols, df_ready = prepare_inference_features(df)
    models = load_models()
    preds = predict_ensemble(X, models)

    # Step 5: build result
    last_close = float(df_ready["close"].iloc[-1])
    ensemble = preds.get("weighted", preds.get("ensemble_mean", last_close))
    out = {
        "instrument": inst_id,
        "bar": bar,
        # Report what was actually received, not what was requested.
        "fetched_candles": int(len(df_ready)),
        "last_ts": df_ready["ts"].iloc[-1].isoformat(),
        "last_close": last_close,
        "ensemble_prediction": float(ensemble),
        "per_model": preds.get("per_model", {}),
    }

    img_price = None
    img_contrib = None
    if show_plot:
        img_price = plot_price_and_preds(df_ready, {"weighted": ensemble})
        img_contrib = plot_model_contributions(out["per_model"])
    return {
        "result": out,
        "img_price": img_price,
        "img_contrib": img_contrib,
    }
# Helper to convert bytes to gradio displayable
def bytes_to_pil(b: bytes):
    """Open encoded image bytes as a PIL image backed by an in-memory buffer."""
    from PIL import Image

    return Image.open(io.BytesIO(b))
# -------------------------
# Gradio layout (blank template)
# -------------------------
def build_gradio_app():
    """Build and return the Gradio Blocks UI for the prediction demo.

    The left column holds the inputs (instrument, bar, limit), the run and
    cache-refresh buttons and a text box for the JSON result; the right
    column shows the price and per-model plots.
    """
    title = "BTC/USDT Price Prediction (OKX REST) — Ensemble Demo"
    description = "Fetch recent candles from OKX and predict next close using an ensemble (demo)."
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"## {title}")
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                inst_in = gr.Textbox(label="Instrument", value=DEFAULT_INSTRUMENT)
                bar_in = gr.Dropdown(label="Candle bar", choices=["1m","3m","5m","15m","1H","4H","1D"], value=DEFAULT_BAR)
                limit_in = gr.Slider(label="Limit (number of candles)", minimum=50, maximum=1000, step=50, value=DEFAULT_LIMIT)
                run_btn = gr.Button("Run Inference")
                refresh_btn = gr.Button("Refresh Models (clear cache)")
                info_out = gr.Textbox(label="Info / JSON result", interactive=False)
            with gr.Column(scale=2):
                price_img = gr.Image(label="Price & Prediction", type="pil")
                contrib_img = gr.Image(label="Per-model predictions", type="pil")

        # Callbacks
        def on_run(inst, bar, limit):
            """Run inference; return (price image, contribution image, JSON text)."""
            res = inference_pipeline(inst, bar, limit)
            if "error" in res:
                # Exactly one value per output component, in output order:
                # clear both images and show the error in the text box.
                error_json = json.dumps({"error": res["error"]}, indent=2)
                return gr.update(value=None), gr.update(value=None), error_json
            out = res["result"]
            price_pil = bytes_to_pil(res["img_price"])
            contrib_pil = bytes_to_pil(res["img_contrib"])
            info_json = json.dumps(out, indent=2, default=str)
            return price_pil, contrib_pil, info_json

        def on_refresh():
            """Clear the model cache so the next run loads models again."""
            with _MODEL_LOCK:
                _MODELS.clear()
            return "Model cache cleared."

        run_btn.click(on_run, inputs=[inst_in, bar_in, limit_in], outputs=[price_img, contrib_img, info_out])
        refresh_btn.click(on_refresh, inputs=None, outputs=info_out)
        gr.Markdown("Notes: This demo uses public OKX market endpoints. For production, validate rate limits and handle API keys for private data. Ensemble models here are demo-friendly; train and persist stronger models for real use.")
    return demo
# -------------------------
# If run as app
# -------------------------
if __name__ == "__main__":
app = build_gradio_app()
app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", ave)