Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
783 lines · 27 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
Multi-Model Prediction Visualizer
|
|
==================================
|
|
Visualizes predictions from multiple ML models with interactive charts.
|
|
|
|
Uses lightweight-charts (desktop window) or Plotly (standalone HTML file) for interactive trading charts with:
|
|
- Candlestick price data
|
|
- Range predictions (high/low) from multiple timeframes
|
|
- Movement magnitude predictions
|
|
- AMD phase indicators
|
|
- Technical indicators (RSI, CMF, SAR)
|
|
|
|
Author: ML-Specialist (NEXUS v4.0)
|
|
Date: 2026-01-05
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime, timedelta
|
|
import joblib
|
|
from loguru import logger
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
|
|
try:
|
|
from lightweight_charts import Chart
|
|
HAS_LIGHTWEIGHT_CHARTS = True
|
|
except ImportError:
|
|
HAS_LIGHTWEIGHT_CHARTS = False
|
|
logger.warning("lightweight-charts not installed")
|
|
|
|
try:
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
HAS_PLOTLY = True
|
|
except ImportError:
|
|
HAS_PLOTLY = False
|
|
logger.warning("plotly not installed. Install with: pip install plotly")
|
|
|
|
# ML-Engine imports
|
|
from config.reduced_features import generate_reduced_features, get_feature_columns_without_ohlcv
|
|
|
|
|
|
# ============================================================
|
|
# Configuration
|
|
# ============================================================
|
|
|
|
@dataclass
class VisualizerConfig:
    """Configuration for the visualizer.

    Groups the PostgreSQL connection settings, chart panel sizing and the
    colour palettes cycled through for high/low prediction lines.
    """

    # PostgreSQL connection
    # NOTE(review): these are hard-coded development credentials; consider
    # sourcing them from the environment or a secrets store instead.
    db_host: str = "localhost"
    db_port: int = 5432
    db_name: str = "orbiquant_trading"
    db_user: str = "orbiquant_user"
    db_password: str = "orbiquant_dev_2025"

    # Visualization settings: panel heights as fractions of the chart area,
    # and the width used for prediction lines.
    chart_height: float = 0.6
    indicator_height: float = 0.2
    prediction_line_width: int = 1

    # Colors. None or an empty list falls back to the default palette so the
    # renderers, which cycle with ``i % len(palette)``, always have at least
    # one colour to pick from.
    high_colors: Optional[List[str]] = None
    low_colors: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in default palettes; each instance gets its own fresh list."""
        if not self.high_colors:
            self.high_colors = ["#006400", "#228B22", "#32CD32", "#7FFF00"]  # Greens
        if not self.low_colors:
            self.low_colors = ["#FF0000", "#B22222", "#8B0000", "#CD5C5C"]  # Reds
|
|
|
|
|
|
# ============================================================
|
|
# PostgreSQL Data Loader
|
|
# ============================================================
|
|
|
|
class PostgreSQLDataLoader:
    """Loads OHLCV market data from PostgreSQL.

    Bars are always read from the 5-minute table (``OHLCV_TABLE``, which holds
    all the data); coarser timeframes such as "15m" or "1h" are resampled in
    pandas. The connection is opened lazily by :meth:`connect` and released
    by :meth:`close`; the loader can also be used as a context manager.
    """

    # Source table with 5-minute bars; it is queried for every timeframe.
    OHLCV_TABLE = "market_data.ohlcv_5m"
    # Bar size, in minutes, of OHLCV_TABLE.
    _BASE_MINUTES = 5
    # Minutes per timeframe unit suffix ("5m", "1h", "1d").
    _UNIT_MINUTES = {"m": 1, "h": 60, "d": 1440}

    def __init__(self, config: Optional["VisualizerConfig"] = None):
        self.config = config or VisualizerConfig()
        self.connection = None
        # symbol -> ticker id; only successful lookups are cached
        self._ticker_cache: Dict[str, int] = {}

    def __enter__(self) -> "PostgreSQLDataLoader":
        self.connect()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def connect(self):
        """Open the PostgreSQL connection if it is not already open.

        The connection is switched to autocommit: the loader only issues
        SELECTs, and autocommit keeps one failed query from leaving the
        session in an aborted transaction (which would make every later query
        on this connection fail). It also avoids holding an idle transaction
        open between calls.
        """
        if self.connection is None or self.connection.closed:
            self.connection = psycopg2.connect(
                host=self.config.db_host,
                port=self.config.db_port,
                dbname=self.config.db_name,
                user=self.config.db_user,
                password=self.config.db_password
            )
            self.connection.autocommit = True
            logger.info(f"Connected to PostgreSQL at {self.config.db_host}:{self.config.db_port}")

    def close(self):
        """Close the connection if open; a later call to :meth:`connect` reopens it."""
        if self.connection and not self.connection.closed:
            self.connection.close()

    def get_ticker_id(self, symbol: str) -> Optional[int]:
        """Return the ``market_data.tickers`` id for *symbol*, or None if unknown."""
        if symbol in self._ticker_cache:
            return self._ticker_cache[symbol]

        self.connect()
        with self.connection.cursor() as cur:
            cur.execute(
                "SELECT id FROM market_data.tickers WHERE symbol = %s",
                (symbol,)
            )
            result = cur.fetchone()

        if result:
            self._ticker_cache[symbol] = result[0]
            return result[0]
        return None

    @classmethod
    def _resample_rule(cls, timeframe: str) -> Optional[str]:
        """Return the pandas resample rule for *timeframe*.

        Returns None for the 5-minute base timeframe (no resampling needed).

        Raises:
            ValueError: if *timeframe* is malformed (expected ``<int><m|h|d>``)
                or is not a positive multiple of the 5-minute base bars.
        """
        tf = timeframe.strip().lower()
        count, unit = tf[:-1], tf[-1:]
        if not count.isdecimal() or unit not in cls._UNIT_MINUTES:
            raise ValueError(f"Unsupported timeframe: {timeframe!r}")

        minutes = int(count) * cls._UNIT_MINUTES[unit]
        if minutes == cls._BASE_MINUTES:
            return None
        if minutes <= 0 or minutes % cls._BASE_MINUTES:
            raise ValueError(
                f"Timeframe {timeframe!r} cannot be built from "
                f"{cls._BASE_MINUTES}m bars"
            )
        return f"{minutes}min"

    @staticmethod
    def _end_condition(end_date: str) -> Tuple[str, str]:
        """Return the SQL comparison operator and bound for *end_date*.

        A bare date ("YYYY-MM-DD") must include that whole day, so it becomes
        ``timestamp < <next day>``. Comparing with ``<= 'YYYY-MM-DD'`` would
        keep only the midnight bar. A more precise value is compared with
        ``<=`` unchanged.
        """
        text = end_date.strip()
        if len(text) == 10:
            next_day = pd.Timestamp(text) + pd.Timedelta(days=1)
            return "<", next_day.strftime("%Y-%m-%d")
        return "<=", end_date

    def load_ohlcv(
        self,
        symbol: str,
        timeframe: str,
        start_date: str,
        end_date: str
    ) -> pd.DataFrame:
        """Load OHLCV bars for *symbol* from PostgreSQL.

        Args:
            symbol: Ticker symbol as stored in ``market_data.tickers``.
            timeframe: Bar size. "5m" is returned as stored; larger multiples
                of 5 minutes (e.g. "15m", "1h") are resampled from 5m bars.
            start_date: Inclusive lower bound (e.g. "2025-01-01").
            end_date: Upper bound; a bare date includes that whole day.

        Returns:
            A DataFrame indexed by ``time`` with open/high/low/close/volume
            columns. It is empty, with the error logged, when the timeframe is
            unsupported, the symbol is unknown, or the query fails.
        """
        try:
            rule = self._resample_rule(timeframe)
        except ValueError as e:
            logger.error(str(e))
            return pd.DataFrame()

        self.connect()
        ticker_id = self.get_ticker_id(symbol)

        if ticker_id is None:
            logger.error(f"Symbol not found: {symbol}")
            return pd.DataFrame()

        end_op, end_value = self._end_condition(end_date)

        # Only the class-constant table name and a fixed operator are
        # interpolated; every user-supplied value is a bound parameter.
        query = f"""
            SELECT
                timestamp as time,
                open,
                high,
                low,
                close,
                volume
            FROM {self.OHLCV_TABLE}
            WHERE ticker_id = %s
              AND timestamp >= %s
              AND timestamp {end_op} %s
            ORDER BY timestamp ASC
        """

        try:
            df = pd.read_sql_query(
                query,
                self.connection,
                params=(ticker_id, start_date, end_value),
                parse_dates=['time']
            )

            if not df.empty:
                df.set_index('time', inplace=True)

                if rule is not None:
                    logger.info(f"Resampling {len(df)} 5m records to {timeframe}...")
                    df = df.resample(rule).agg({
                        'open': 'first',
                        'high': 'max',
                        'low': 'min',
                        'close': 'last',
                        'volume': 'sum'
                    }).dropna()

            logger.info(f"Loaded {len(df)} records for {symbol} {timeframe}")
            return df

        except Exception as e:
            logger.error(f"Failed to load data: {e}")
            return pd.DataFrame()
|
|
|
|
|
|
# ============================================================
|
|
# Multi-Model Prediction Generator
|
|
# ============================================================
|
|
|
|
class MultiModelPredictor:
    """Generates range (high/low) predictions from per-symbol/timeframe models.

    Models are loaded from ``*.joblib`` files in *model_dir*; each file's stem
    is its lookup key, ``<symbol>_<timeframe>_<target>_h<horizon>``
    (e.g. ``XAUUSD_5m_high_h3``).
    """

    # Prediction targets; each produces a ``pred_<target>_<timeframe>`` output.
    TARGETS = ("high", "low")

    def __init__(self, model_dir: str = 'models/reduced_features_models'):
        # A relative model_dir is resolved against the current working directory.
        self.model_dir = Path(model_dir)
        self.models: Dict[str, Any] = {}
        self.load_models()

    def load_models(self):
        """Load every ``*.joblib`` model in ``self.model_dir`` except ``metadata.joblib``.

        Files are visited in sorted order so loading (and logging) is
        deterministic. A file that fails to load is logged and skipped rather
        than aborting construction of the whole predictor.

        Security: joblib files are pickles and can execute arbitrary code when
        loaded, so only point this at trusted model directories.
        """
        if not self.model_dir.exists():
            logger.warning(f"Model directory not found: {self.model_dir}")
            return

        for model_file in sorted(self.model_dir.glob("*.joblib")):
            if model_file.name == 'metadata.joblib':
                continue
            key = model_file.stem
            try:
                self.models[key] = joblib.load(model_file)
            except Exception as e:  # best-effort: one bad file must not hide the others
                logger.warning(f"Failed to load model {model_file.name}: {e}")
                continue
            logger.info(f"Loaded model: {key}")

    @staticmethod
    def _model_key(symbol: str, timeframe: str, target: str, horizon: int) -> str:
        """Return the model lookup key, e.g. ``XAUUSD_5m_high_h3``."""
        return f"{symbol}_{timeframe}_{target}_h{horizon}"

    def predict(
        self,
        features: pd.DataFrame,
        symbol: str,
        timeframe: str,
        horizon: int = 3
    ) -> Dict[str, np.ndarray]:
        """Get predictions from the loaded models for one symbol/timeframe.

        Args:
            features: Feature frame; the columns named by
                ``get_feature_columns_without_ohlcv()`` are used, in that order.
            symbol: Trading symbol used in the model key.
            timeframe: Timeframe used in the model key and the output names.
            horizon: Prediction horizon used in the model key.

        Returns:
            A mapping of ``pred_high_<timeframe>`` / ``pred_low_<timeframe>`` to
            one prediction per row of *features*. Models that are not loaded
            are skipped, and the mapping is empty when no feature column is
            available.
        """
        predictions: Dict[str, np.ndarray] = {}
        feature_cols = get_feature_columns_without_ohlcv()
        available_cols = [c for c in feature_cols if c in features.columns]

        if not available_cols:
            return predictions

        missing = len(feature_cols) - len(available_cols)
        if missing:
            # Models were presumably trained on the full column list; a
            # subset is unlikely to match their expected input.
            logger.warning(
                f"{missing} expected feature column(s) missing for "
                f"{symbol} {timeframe}; predictions may fail or be unreliable"
            )

        X = features[available_cols].values

        for target in self.TARGETS:
            model = self.models.get(self._model_key(symbol, timeframe, target, horizon))
            if model is not None:
                predictions[f'pred_{target}_{timeframe}'] = model.predict(X)

        return predictions
|
|
|
|
|
|
# ============================================================
|
|
# Prediction Visualizer
|
|
# ============================================================
|
|
|
|
class MultiModelVisualizer:
    """
    Interactive chart visualizer for multi-model predictions.

    Features:
    - Candlestick chart with price data
    - Range predictions (high/low) from 5m and 15m models
    - Technical indicators (RSI, CMF, SAR)

    ``visualize`` renders with lightweight-charts (desktop window);
    ``visualize_plotly`` writes a standalone Plotly HTML file.
    """

    # Prediction horizon (bars) used when none is given; matches the default
    # of MultiModelPredictor.predict.
    DEFAULT_HORIZON = 3

    # Timestamp format passed to lightweight-charts and used in the equity curve.
    _TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    # Plotly indicator panels in display order:
    # name -> (line colour, ((guide level, guide colour), ...)).
    _PLOTLY_INDICATORS = {
        'RSI': ('#20B2AA', ((70, 'red'), (30, 'green'))),
        'CMF': ('#9370DB', ((0, 'gray'),)),
    }

    def __init__(self, config: Optional["VisualizerConfig"] = None):
        self.config = config or VisualizerConfig()
        self.data_loader = PostgreSQLDataLoader(self.config)
        self.predictor = MultiModelPredictor()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _price_precision(symbol: str) -> int:
        """Return the price-axis precision: 2 decimals for XAU symbols, else 5."""
        return 2 if 'XAU' in symbol else 5

    @staticmethod
    def _is_present(value: Any) -> bool:
        """Return False for None and pandas missing markers (NaN/NaT)."""
        return value is not None and not pd.isna(value)

    @staticmethod
    def _line_data(df_plot: pd.DataFrame, col: str, name: Optional[str] = None) -> pd.DataFrame:
        """Return a ``time``/*name* frame for a lightweight-charts line.

        Rows with a missing value are dropped, for example indicator warm-up
        bars or the first *horizon* bars of a shifted prediction.
        """
        name = name or col
        return df_plot[['time', col]].rename(columns={col: name}).dropna()

    @staticmethod
    def _prediction_price_columns(
        bars: pd.DataFrame,
        features: pd.DataFrame,
        predictions: Dict[str, np.ndarray],
        horizon: int
    ) -> pd.DataFrame:
        """Align raw predictions with *bars* and convert them to price levels.

        Each prediction array has one value per row of ``features`` and is
        indexed by ``features.index``, which is assumed to share *bars*' time
        index. Each value is treated as a move relative to the close of the
        bar it was computed on: ``pred_high_*`` is added to that close and
        ``pred_low_*`` is subtracted from it. The results are then shifted
        *forward* by *horizon* bars on *bars*' own grid. Each level is
        therefore drawn at the end of the window it forecasts and never
        before the bar whose data produced it.

        Returns:
            A DataFrame indexed like *bars* with, per prediction key, the
            shifted raw ``<key>`` column and, for high/low keys, a
            ``<key>_price`` column.
        """
        out = pd.DataFrame(index=bars.index)
        close = bars['close']
        for key, values in predictions.items():
            raw = pd.Series(values, index=features.index).reindex(bars.index)
            out[key] = raw.shift(horizon)
            if 'high' in key:
                out[f'{key}_price'] = (close + raw).shift(horizon)
            elif 'low' in key:
                out[f'{key}_price'] = (close - raw).shift(horizon)
        return out

    @classmethod
    def _equity_curve(
        cls,
        trades: List[Dict],
        start_time: str,
        initial_capital: float = 1000.0
    ) -> pd.DataFrame:
        """Build the cumulative equity series from closed trades.

        Only trades with an ``exit_time`` contribute. They are applied in
        exit-time order, and a missing or None ``pnl`` counts as 0.

        Returns:
            A DataFrame with ``time`` (formatted strings; the first row is
            *start_time*) and ``equity`` columns.
        """
        closed = sorted(
            (t for t in trades if cls._is_present(t.get('exit_time'))),
            key=lambda t: t['exit_time']
        )
        times = [start_time]
        equity = [float(initial_capital)]
        for trade in closed:
            equity.append(equity[-1] + (trade.get('pnl') or 0))
            times.append(trade['exit_time'].strftime(cls._TIME_FORMAT))
        return pd.DataFrame({'time': times, 'equity': equity})

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------

    def prepare_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        horizon: int = DEFAULT_HORIZON
    ) -> pd.DataFrame:
        """Prepare data with predictions for visualization.

        Args:
            symbol: Trading symbol.
            start_date: Start date forwarded to the data loader.
            end_date: End date forwarded to the data loader.
            timeframe: Base timeframe. When it is "5m", the 15m models'
                levels are also overlaid, computed on 15m bars and then
                forward-filled onto the 5m index.
            horizon: Prediction horizon in bars. It selects the models and
                sets how far their output is shifted forward.

        Returns:
            The OHLCV frame plus prediction (``pred_*``) and feature columns,
            or an empty frame if no data could be loaded.
        """
        try:
            df = self.data_loader.load_ohlcv(symbol, timeframe, start_date, end_date)

            if df.empty:
                return df

            features = generate_reduced_features(df)
            predictions = self.predictor.predict(features, symbol, timeframe, horizon=horizon)
            df = df.join(self._prediction_price_columns(df, features, predictions, horizon))

            if timeframe == '5m':
                df_15m = self.data_loader.load_ohlcv(symbol, '15m', start_date, end_date)
                if not df_15m.empty:
                    features_15m = generate_reduced_features(df_15m)
                    predictions_15m = self.predictor.predict(
                        features_15m, symbol, '15m', horizon=horizon
                    )
                    # Levels use 15m closes and a 15m-bar shift; only then are
                    # they spread onto the 5m bars.
                    levels_15m = self._prediction_price_columns(
                        df_15m, features_15m, predictions_15m, horizon
                    )
                    df = df.join(levels_15m.reindex(df.index, method='ffill'))

            ohlcv = {'open', 'high', 'low', 'close', 'volume'}
            for col in features.columns:
                if col not in df.columns and col not in ohlcv:
                    df[col] = features[col]

            return df
        finally:
            # The loader reconnects lazily, so release the connection now
            # instead of holding it while a chart window blocks.
            self.data_loader.close()

    # ------------------------------------------------------------------
    # lightweight-charts backend
    # ------------------------------------------------------------------

    def visualize(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True
    ):
        """
        Create interactive chart with predictions.

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators (SAR overlay plus the
                RSI and CMF panels)

        The chart is shown with ``block=True``.
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed. Cannot visualize.")
            return

        logger.info(f"Preparing data for {symbol} {timeframe} from {start_date} to {end_date}")

        df = self.prepare_data(symbol, start_date, end_date, timeframe)

        if df.empty:
            logger.error("No data to visualize")
            return

        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime(self._TIME_FORMAT)

        chart = Chart(
            toolbox=True,
            inner_height=0.05,
            title=f"{symbol} - Multi-Model Predictions"
        )
        chart.legend(True, font_size=12)
        chart.topbar.textbox(
            name="SYMBOL",
            initial_text=f"{symbol} {timeframe} | Predictions: 5m & 15m"
        )

        # The root chart gets inner_height=0.05; candles live in this subchart.
        price_chart = chart.create_subchart(
            height=self.config.chart_height,
            width=1,
            sync=True
        )
        price_chart.precision(precision=self._price_precision(symbol))
        price_chart.legend(True, font_size=12)
        price_chart.set(df_plot[['time', 'open', 'high', 'low', 'close', 'volume']])

        if show_predictions:
            # High levels (greens) first, then low levels (reds).
            palettes = (
                ('pred_high', self.config.high_colors),
                ('pred_low', self.config.low_colors),
            )
            for prefix, palette in palettes:
                price_cols = [c for c in df_plot.columns if prefix in c and '_price' in c]
                for i, col in enumerate(price_cols):
                    label = col.replace('_price', '')
                    line = price_chart.create_line(
                        label,
                        color=palette[i % len(palette)],
                        width=self.config.prediction_line_width
                    )
                    line.set(self._line_data(df_plot, col, label))

        if show_indicators:
            if 'SAR' in df_plot.columns:
                sar_line = price_chart.create_line('SAR', color='#FF69B4', width=1)
                sar_line.set(self._line_data(df_plot, 'SAR'))

            if 'RSI' in df_plot.columns:
                rsi_chart = chart.create_subchart(height=0.15, width=1, sync=True)
                rsi_chart.legend(True, font_size=10)
                rsi_line = rsi_chart.create_line('RSI', color='#20B2AA', width=1)
                rsi_line.set(self._line_data(df_plot, 'RSI'))

                # Overbought/oversold guide levels
                ob_line = rsi_chart.create_line('overbought', color='#DC143C', width=1)
                ob_line.set(df_plot[['time']].assign(overbought=70))
                os_line = rsi_chart.create_line('oversold', color='#32CD32', width=1)
                os_line.set(df_plot[['time']].assign(oversold=30))

            if 'CMF' in df_plot.columns:
                cmf_chart = chart.create_subchart(height=0.1, width=1, sync=True)
                cmf_chart.legend(True, font_size=10)
                cmf_line = cmf_chart.create_line('CMF', color='#9370DB', width=1)
                cmf_line.set(self._line_data(df_plot, 'CMF'))

                zero_line = cmf_chart.create_line('zero', color='#808080', width=1)
                zero_line.set(df_plot[['time']].assign(zero=0))

        logger.info("Displaying chart...")
        chart.show(block=True)

    # ------------------------------------------------------------------
    # Plotly backend
    # ------------------------------------------------------------------

    def visualize_plotly(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True,
        output_file: Optional[str] = None
    ) -> Optional[str]:
        """
        Create interactive HTML chart with Plotly (fallback for environments without GTK/QT).

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators (SAR overlay plus the
                RSI and CMF panels)
            output_file: Output HTML file path (auto-generated under
                ``reports/charts`` if None)

        Returns:
            Path to the generated HTML file, or None if plotly is missing or
            there is no data.
        """
        if not HAS_PLOTLY:
            logger.error("plotly not installed. Cannot visualize.")
            return None

        logger.info(f"Preparing Plotly chart for {symbol} {timeframe} from {start_date} to {end_date}")

        df = self.prepare_data(symbol, start_date, end_date, timeframe)

        if df.empty:
            logger.error("No data to visualize")
            return None

        # One extra row per available indicator; heights come from the config
        # and are normalized to sum to 1.
        indicator_panels = [
            name for name in self._PLOTLY_INDICATORS
            if show_indicators and name in df.columns
        ]
        row_heights = [self.config.chart_height] + [self.config.indicator_height] * len(indicator_panels)
        total = sum(row_heights)
        row_heights = [h / total for h in row_heights]

        fig = make_subplots(
            rows=1 + len(indicator_panels), cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            row_heights=row_heights,
            subplot_titles=[f"{symbol} {timeframe} - Multi-Model Predictions"] + indicator_panels
        )

        df_plot = df.reset_index()
        fig.add_trace(
            go.Candlestick(
                x=df_plot['time'],
                open=df_plot['open'],
                high=df_plot['high'],
                low=df_plot['low'],
                close=df_plot['close'],
                name='Price'
            ),
            row=1, col=1
        )

        if show_predictions:
            # High levels (green shades) first, then low levels (red shades).
            palettes = (
                ('pred_high', self.config.high_colors),
                ('pred_low', self.config.low_colors),
            )
            for prefix, palette in palettes:
                price_cols = [c for c in df.columns if prefix in c and '_price' in c]
                for i, col in enumerate(price_cols):
                    fig.add_trace(
                        go.Scatter(
                            x=df_plot['time'],
                            y=df_plot[col],
                            mode='lines',
                            name=col.replace('_price', ''),
                            line=dict(color=palette[i % len(palette)], width=1),
                            opacity=0.7
                        ),
                        row=1, col=1
                    )

        if show_indicators and 'SAR' in df.columns:
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot['SAR'],
                    mode='markers',
                    name='SAR',
                    marker=dict(color='#FF69B4', size=3)
                ),
                row=1, col=1
            )

        for row, name in enumerate(indicator_panels, start=2):
            color, guides = self._PLOTLY_INDICATORS[name]
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[name],
                    mode='lines',
                    name=name,
                    line=dict(color=color, width=1)
                ),
                row=row, col=1
            )
            for level, guide_color in guides:
                fig.add_hline(y=level, line_dash="dash", line_color=guide_color,
                              opacity=0.5, row=row, col=1)

        fig.update_layout(
            title=f"{symbol} {timeframe} - Multi-Model Predictions ({start_date} to {end_date})",
            xaxis_title="Time",
            yaxis_title="Price",
            template="plotly_dark",
            height=800,
            showlegend=True,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            ),
            xaxis_rangeslider_visible=False
        )

        if output_file is None:
            output_dir = Path(__file__).parent.parent / 'reports' / 'charts'
            output_dir.mkdir(parents=True, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_file = str(output_dir / f"predictions_{symbol}_{timeframe}_{timestamp}.html")

        fig.write_html(output_file)
        logger.info(f"Chart saved to: {output_file}")

        return output_file

    # ------------------------------------------------------------------
    # Backtest view
    # ------------------------------------------------------------------

    def visualize_backtest_results(
        self,
        df: pd.DataFrame,
        trades: List[Dict],
        symbol: str,
        initial_capital: float = 1000.0
    ):
        """
        Visualize backtest results with trade markers.

        Args:
            df: DataFrame with OHLCV and predictions, indexed by ``time``
            trades: List of trade dictionaries. The keys read are
                ``entry_time``, ``entry_price``, ``direction`` ('LONG' or
                anything else for short), ``exit_time``, ``exit_price`` and
                ``pnl``.
            symbol: Trading symbol
            initial_capital: Starting value of the equity curve
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed")
            return

        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime(self._TIME_FORMAT)

        chart = Chart(toolbox=True, title=f"{symbol} - Backtest Results")
        chart.legend(True)

        price_chart = chart.create_subchart(height=0.7, width=1, sync=True)
        price_chart.precision(precision=self._price_precision(symbol))
        price_chart.set(df_plot[['time', 'open', 'high', 'low', 'close', 'volume']])

        for trade in trades:
            direction = trade.get('direction', 'LONG')
            is_long = direction == 'LONG'

            entry_time = trade.get('entry_time')
            if self._is_present(entry_time) and self._is_present(trade.get('entry_price')):
                price_chart.marker(
                    time=entry_time.strftime(self._TIME_FORMAT),
                    position='below' if is_long else 'above',
                    color='#00FF00' if is_long else '#FF0000',
                    shape='arrow_up' if is_long else 'arrow_down',
                    text=f"{direction} Entry"
                )

            exit_time = trade.get('exit_time')
            pnl = trade.get('pnl') or 0
            if self._is_present(exit_time) and self._is_present(trade.get('exit_price')):
                price_chart.marker(
                    time=exit_time.strftime(self._TIME_FORMAT),
                    position='above' if is_long else 'below',
                    color='#00FF00' if pnl > 0 else '#FF0000',
                    shape='circle',
                    text=f"Exit ${pnl:+.2f}"
                )

        equity_chart = chart.create_subchart(height=0.2, width=1, sync=True)
        equity_chart.legend(True)

        equity_line = equity_chart.create_line('Equity', color='#4169E1', width=2)
        equity_line.set(self._equity_curve(trades, df_plot['time'].iloc[0], initial_capital))

        chart.show(block=True)
|
|
|
|
|
|
# ============================================================
|
|
# Main Execution
|
|
# ============================================================
|
|
|
|
def main():
    """Parse CLI arguments and render a multi-model prediction chart.

    Plotly (HTML output) is the default backend. ``--use-lightweight`` opens
    a lightweight-charts window instead when that package is installed;
    otherwise a warning is logged and Plotly is used.

    Returns:
        Process exit status: 0 on success, 1 if the Plotly chart could not be
        produced (plotly missing or no data).
    """
    import argparse

    parser = argparse.ArgumentParser(description='Visualize multi-model predictions')
    parser.add_argument('--symbol', type=str, default='XAUUSD', help='Trading symbol')
    parser.add_argument('--timeframe', type=str, default='5m', help='Timeframe (5m or 15m)')
    parser.add_argument('--start', type=str, default='2025-01-01', help='Start date')
    parser.add_argument('--end', type=str, default='2025-01-31', help='End date')
    parser.add_argument('--no-predictions', action='store_true', help='Hide predictions')
    parser.add_argument('--no-indicators', action='store_true', help='Hide indicators')
    parser.add_argument('--output', type=str, default=None, help='Output HTML file path')
    parser.add_argument('--use-lightweight', action='store_true', help='Use lightweight-charts (requires GTK/QT)')

    args = parser.parse_args()

    config = VisualizerConfig()
    visualizer = MultiModelVisualizer(config)

    chart_args = dict(
        symbol=args.symbol,
        start_date=args.start,
        end_date=args.end,
        timeframe=args.timeframe,
        show_predictions=not args.no_predictions,
        show_indicators=not args.no_indicators,
    )

    if args.use_lightweight:
        if HAS_LIGHTWEIGHT_CHARTS:
            visualizer.visualize(**chart_args)
            return 0
        # The user explicitly asked for lightweight-charts; say why they get HTML.
        logger.warning("--use-lightweight requested but lightweight-charts is not installed; falling back to Plotly")

    # Use Plotly (default - works in WSL)
    output_file = visualizer.visualize_plotly(**chart_args, output_file=args.output)
    if output_file:
        print(f"Chart saved: {output_file}")
        return 0
    return 1


if __name__ == "__main__":
    sys.exit(main())
|