trading-platform-ml-engine-v2/scripts/visualize_predictions.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

783 lines
27 KiB
Python

#!/usr/bin/env python3
"""
Multi-Model Prediction Visualizer
==================================
Visualizes predictions from multiple ML models with interactive charts.
Renders interactive trading charts with Plotly (default, saved as HTML) or
lightweight-charts (--use-lightweight), showing:
- Candlestick price data
- Range predictions (high/low) from multiple timeframes
- Movement magnitude predictions (not currently rendered)
- AMD phase indicators (not currently rendered)
- Technical indicators (RSI, CMF, SAR)
Author: ML-Specialist (NEXUS v4.0)
Date: 2026-01-05
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
import joblib
from loguru import logger
import psycopg2
from psycopg2.extras import RealDictCursor
try:
from lightweight_charts import Chart
HAS_LIGHTWEIGHT_CHARTS = True
except ImportError:
HAS_LIGHTWEIGHT_CHARTS = False
logger.warning("lightweight-charts not installed")
try:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
HAS_PLOTLY = True
except ImportError:
HAS_PLOTLY = False
logger.warning("plotly not installed. Install with: pip install plotly")
# ML-Engine imports
from config.reduced_features import generate_reduced_features, get_feature_columns_without_ohlcv
# ============================================================
# Configuration
# ============================================================
@dataclass
class VisualizerConfig:
    """Configuration for the prediction visualizer.

    Groups the PostgreSQL connection settings, chart layout settings and the
    colour palettes used for the high/low prediction lines.
    """

    # PostgreSQL connection
    # NOTE(review): development credentials are hard-coded as defaults here;
    # consider supplying them from the environment or a secrets store instead.
    db_host: str = "localhost"
    db_port: int = 5432
    db_name: str = "orbiquant_trading"
    db_user: str = "orbiquant_user"
    db_password: str = "orbiquant_dev_2025"

    # Visualization settings (pane heights are fractions of the full chart)
    chart_height: float = 0.6
    indicator_height: float = 0.2
    prediction_line_width: int = 1

    # Colour palettes for prediction lines; the visualizer cycles through them
    # when there are more lines than colours. ``None`` means "use the default
    # palette", which __post_init__ fills in.
    high_colors: Optional[List[str]] = None
    low_colors: Optional[List[str]] = None

    def __post_init__(self) -> None:
        """Fill in default palettes (a fresh list per instance, never shared)."""
        if self.high_colors is None:
            self.high_colors = ["#006400", "#228B22", "#32CD32", "#7FFF00"]  # Greens
        if self.low_colors is None:
            self.low_colors = ["#FF0000", "#B22222", "#8B0000", "#CD5C5C"]  # Reds
# ============================================================
# PostgreSQL Data Loader
# ============================================================
class PostgreSQLDataLoader:
    """Loads OHLCV market data from PostgreSQL.

    All bars are read from the 5-minute table; coarser timeframes
    (e.g. ``15m``, ``1h``, ``1d``) are produced by resampling in pandas.
    """

    # Every timeframe is derived from this table (it holds the full history).
    _BASE_TABLE = "market_data.ohlcv_5m"
    _BASE_TIMEFRAME = "5m"

    # Timeframe suffix -> pandas offset alias. Minutes are spelled 'min'
    # because a bare 'M' is pandas' month alias.
    _PANDAS_UNITS = {"m": "min", "h": "h", "d": "D"}

    # How 5m bars are combined into one coarser OHLCV bar.
    _OHLCV_AGG = {
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
    }

    def __init__(self, config: Optional["VisualizerConfig"] = None):
        self.config = config or VisualizerConfig()
        self.connection = None
        # symbol -> ticker id; only successful lookups are cached.
        self._ticker_cache: Dict[str, int] = {}

    def connect(self) -> None:
        """Open the PostgreSQL connection if none is currently open."""
        if self.connection is None or self.connection.closed:
            self.connection = psycopg2.connect(
                host=self.config.db_host,
                port=self.config.db_port,
                dbname=self.config.db_name,
                user=self.config.db_user,
                password=self.config.db_password
            )
            logger.info(f"Connected to PostgreSQL at {self.config.db_host}:{self.config.db_port}")

    def close(self) -> None:
        """Close the connection if it is open (safe to call repeatedly)."""
        if self.connection and not self.connection.closed:
            self.connection.close()

    def get_ticker_id(self, symbol: str) -> Optional[int]:
        """Return the ``market_data.tickers`` id for ``symbol``, or None if unknown."""
        if symbol in self._ticker_cache:
            return self._ticker_cache[symbol]
        self.connect()
        with self.connection.cursor() as cur:
            cur.execute(
                "SELECT id FROM market_data.tickers WHERE symbol = %s",
                (symbol,)
            )
            result = cur.fetchone()
        if result:
            self._ticker_cache[symbol] = result[0]
            return result[0]
        return None

    @classmethod
    def _resample_rule(cls, timeframe: str) -> Optional[str]:
        """Translate '15m' / '1h' / '1d' style timeframes to a pandas offset alias.

        Returns None when the string is not ``<digits><m|h|d>``.
        """
        text = timeframe.strip()
        count, unit = text[:-1], text[-1:]
        if not count.isdigit() or unit not in cls._PANDAS_UNITS:
            return None
        return f"{int(count)}{cls._PANDAS_UNITS[unit]}"

    @classmethod
    def _resample_ohlcv(cls, df: pd.DataFrame, rule: str) -> pd.DataFrame:
        """Aggregate time-indexed OHLCV bars into ``rule``-sized bars.

        Bins with no source bars (e.g. market closures) are dropped.
        """
        return df.resample(rule).agg(cls._OHLCV_AGG).dropna()

    @staticmethod
    def _end_bound(end_date: str) -> Tuple[str, str]:
        """Return ``(operator, bound)`` for the upper end of the time filter.

        A plain 'YYYY-MM-DD' date means "through the end of that day". Comparing
        a timestamp with ``<= 'YYYY-MM-DD'`` would stop at midnight and drop the
        rest of the day, so it becomes an exclusive bound on the next day.
        Any other value is used inclusively, unchanged.
        """
        try:
            day = datetime.strptime(end_date, '%Y-%m-%d')
        except (TypeError, ValueError):
            return '<=', end_date
        return '<', (day + timedelta(days=1)).strftime('%Y-%m-%d')

    def _maybe_resample(self, df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
        """Resample 5m bars to ``timeframe`` when it is coarser than 5m."""
        rule = self._resample_rule(timeframe)
        if rule is None:
            logger.warning(
                f"Unrecognised timeframe '{timeframe}'; returning {self._BASE_TIMEFRAME} bars"
            )
            return df
        base_rule = self._resample_rule(self._BASE_TIMEFRAME)
        if pd.Timedelta(rule) <= pd.Timedelta(base_rule):
            return df
        logger.info(f"Resampling {len(df)} 5m records to {timeframe}...")
        return self._resample_ohlcv(df, rule)

    def _rollback(self) -> None:
        """Best-effort rollback after a failed load.

        Keeps the shared connection usable for later queries even if the
        failure left a transaction open or aborted.
        """
        if self.connection is None or self.connection.closed:
            return
        try:
            self.connection.rollback()
        except psycopg2.Error as rollback_error:
            logger.warning(f"Rollback after failed query also failed: {rollback_error}")

    def load_ohlcv(
        self,
        symbol: str,
        timeframe: str,
        start_date: str,
        end_date: str
    ) -> pd.DataFrame:
        """Load OHLCV bars for ``symbol`` between ``start_date`` and ``end_date``.

        Args:
            symbol: Ticker symbol as stored in ``market_data.tickers``.
            timeframe: Bar size such as '5m', '15m', '1h' or '1d'. Anything
                coarser than 5m is resampled from the 5m table.
            start_date: Inclusive lower bound (e.g. 'YYYY-MM-DD').
            end_date: Upper bound. A plain 'YYYY-MM-DD' date includes that
                whole day; a value with a time component is used inclusively.

        Returns:
            DataFrame indexed by bar time with open/high/low/close/volume
            columns. An empty DataFrame is returned if the symbol is unknown
            or loading fails (the error is logged).
        """
        self.connect()
        ticker_id = self.get_ticker_id(symbol)
        if ticker_id is None:
            logger.error(f"Symbol not found: {symbol}")
            return pd.DataFrame()

        # end_op is one of two fixed literals ('<' or '<='), never user text;
        # all user-supplied values go through query parameters.
        end_op, end_bound = self._end_bound(end_date)
        query = f"""
            SELECT
                timestamp as time,
                open,
                high,
                low,
                close,
                volume
            FROM {self._BASE_TABLE}
            WHERE ticker_id = %s
            AND timestamp >= %s
            AND timestamp {end_op} %s
            ORDER BY timestamp ASC
        """
        try:
            df = pd.read_sql_query(
                query,
                self.connection,
                params=(ticker_id, start_date, end_bound),
                parse_dates=['time']
            )
            if not df.empty:
                df.set_index('time', inplace=True)
                df = self._maybe_resample(df, timeframe)
            logger.info(f"Loaded {len(df)} records for {symbol} {timeframe}")
            return df
        except Exception as e:
            # Deliberate best-effort: callers treat an empty frame as "no data".
            logger.error(f"Failed to load data: {e}")
            self._rollback()
            return pd.DataFrame()
# ============================================================
# Multi-Model Prediction Generator
# ============================================================
class MultiModelPredictor:
    """Generates range predictions from per-symbol/timeframe joblib models.

    Models are keyed by the stem of the ``.joblib`` file they were loaded
    from; ``predict`` looks up ``{symbol}_{timeframe}_{high|low}_h{horizon}``.
    """

    def __init__(self, model_dir: str = 'models/reduced_features_models'):
        self.model_dir = self._resolve_model_dir(model_dir)
        # file stem -> deserialized model object
        self.models: Dict[str, Any] = {}
        self.load_models()

    @staticmethod
    def _resolve_model_dir(model_dir: str) -> Path:
        """Resolve ``model_dir``, falling back to the project root for relative paths.

        An absolute path, or a relative path that exists from the current
        working directory, is returned unchanged. Otherwise the path is tried
        relative to the project root (the parent of this script's directory)
        so the script works from any CWD. If neither exists, the original
        path is returned so ``load_models`` can report it.
        """
        path = Path(model_dir)
        if path.is_absolute() or path.exists():
            return path
        project_path = Path(__file__).resolve().parent.parent / path
        return project_path if project_path.exists() else path

    def load_models(self) -> None:
        """Load every ``*.joblib`` in ``model_dir`` except ``metadata.joblib``."""
        if not self.model_dir.exists():
            logger.warning(f"Model directory not found: {self.model_dir}")
            return
        # NOTE: joblib.load unpickles arbitrary objects; only point this at
        # trusted model files. Sorting makes load/log order deterministic.
        for model_file in sorted(self.model_dir.glob("*.joblib")):
            if model_file.name == 'metadata.joblib':
                continue
            key = model_file.stem
            self.models[key] = joblib.load(model_file)
            logger.info(f"Loaded model: {key}")

    def predict(
        self,
        features: pd.DataFrame,
        symbol: str,
        timeframe: str,
        horizon: int = 3
    ) -> Dict[str, np.ndarray]:
        """Run the high/low models for ``symbol``/``timeframe``/``horizon``.

        Args:
            features: Feature frame; columns are selected in the order given by
                ``get_feature_columns_without_ohlcv()``.
            symbol: Trading symbol used in the model key.
            timeframe: Timeframe used in the model key and the output keys.
            horizon: Prediction horizon used in the model key (``_h{horizon}``).

        Returns:
            ``{'pred_high_<tf>': array, 'pred_low_<tf>': array}``, containing only
            the entries whose model is loaded (one value per feature row). The
            result is empty when none of the expected feature columns are present.
        """
        predictions: Dict[str, np.ndarray] = {}
        feature_cols = get_feature_columns_without_ohlcv()
        available_cols = [c for c in feature_cols if c in features.columns]
        if not available_cols:
            return predictions

        missing = [c for c in feature_cols if c not in features.columns]
        if missing:
            # Models trained on the full column set will likely reject (or
            # mis-score) a narrower matrix, so make the gap visible.
            logger.warning(
                f"{len(missing)} expected feature columns missing for "
                f"{symbol} {timeframe}: {missing[:5]}"
            )

        X = features[available_cols].values
        for target in ('high', 'low'):
            model = self.models.get(f"{symbol}_{timeframe}_{target}_h{horizon}")
            if model is not None:
                predictions[f'pred_{target}_{timeframe}'] = model.predict(X)
        return predictions
# ============================================================
# Prediction Visualizer
# ============================================================
class MultiModelVisualizer:
    """
    Interactive chart visualizer for multi-model predictions.

    Features:
    - Candlestick chart with price data
    - Range predictions (high/low) from the 5m and 15m models, drawn at the
      bars they forecast (shifted forward by the prediction horizon)
    - Technical indicators (RSI, CMF, SAR) when present in the feature set
    - Backtest trade markers and equity curve (lightweight-charts)

    Two rendering backends: ``visualize`` (lightweight-charts, shown via
    ``Chart.show``) and ``visualize_plotly`` (writes an HTML file).
    """

    # Format used for the 'time' column handed to lightweight-charts.
    _TIME_FMT = '%Y-%m-%d %H:%M:%S'
    # Raw OHLCV columns; feature columns never overwrite them.
    _OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume']
    # Starting capital for the backtest equity curve.
    _STARTING_CAPITAL = 1000.0
    # Plotly indicator panes: column -> (line colour, [(reference level, colour), ...]).
    _PLOTLY_INDICATORS = {
        'RSI': ('#20B2AA', [(70, 'red'), (30, 'green')]),
        'CMF': ('#9370DB', [(0, 'gray')]),
    }

    def __init__(self, config: Optional["VisualizerConfig"] = None):
        self.config = config or VisualizerConfig()
        self.data_loader = PostgreSQLDataLoader(self.config)
        self.predictor = MultiModelPredictor()

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    @staticmethod
    def _forecast_levels(
        values: np.ndarray,
        index: pd.Index,
        close: pd.Series,
        key: str,
        horizon: int
    ) -> Tuple[pd.Series, Optional[pd.Series]]:
        """Turn raw model output into plot-ready, look-ahead-free series.

        ``values[i]`` is the output computed from the features of bar
        ``index[i]``. Models are keyed by horizon (``..._h{horizon}``) and are
        assumed to forecast the next ``horizon`` bars. The absolute level is
        anchored on the close of the bar the prediction was made at
        (high: close + value, low: close - value). Both series are then shifted
        forward by ``horizon`` bars, so no point is drawn before the data used
        to compute it.

        Returns:
            ``(relative, price)``. ``price`` is None for keys that are neither
            a ``pred_high*`` nor a ``pred_low*`` prediction.
        """
        relative = pd.Series(np.asarray(values, dtype=float), index=index)
        anchor = close.reindex(index)
        if key.startswith('pred_high'):
            price = anchor + relative
        elif key.startswith('pred_low'):
            price = anchor - relative
        else:
            price = None
        relative = relative.shift(horizon)
        if price is not None:
            price = price.shift(horizon)
        return relative, price

    def prepare_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        horizon: int = 3
    ) -> pd.DataFrame:
        """Load bars, compute features and attach model predictions.

        Adds ``pred_<high|low>_<tf>`` (relative) and ``pred_<high|low>_<tf>_price``
        (absolute) columns for the base timeframe. When the base timeframe is
        5m, the same columns are added for the 15m models. Feature columns
        not already present are appended.

        Args:
            symbol: Trading symbol (e.g. 'XAUUSD').
            start_date: Start date (YYYY-MM-DD).
            end_date: End date (YYYY-MM-DD).
            timeframe: Base timeframe ('5m' or '15m').
            horizon: Model horizon in bars; selects the ``_h{horizon}`` models
                and the forward shift applied to their output.

        Returns:
            The enriched DataFrame, or an empty DataFrame if no bars loaded.
        """
        df = self.data_loader.load_ohlcv(symbol, timeframe, start_date, end_date)
        if df.empty:
            return df

        features = generate_reduced_features(df)

        predictions = self.predictor.predict(features, symbol, timeframe, horizon=horizon)
        for key, values in predictions.items():
            relative, price = self._forecast_levels(
                values, features.index, df['close'], key, horizon
            )
            df[key] = relative.reindex(df.index)
            if price is not None:
                df[f'{key}_price'] = price.reindex(df.index)

        if timeframe == '5m':
            self._add_higher_timeframe_predictions(
                df, symbol, '15m', start_date, end_date, horizon
            )

        for col in features.columns:
            if col not in df.columns and col not in self._OHLCV_COLUMNS:
                df[col] = features[col]
        return df

    def _add_higher_timeframe_predictions(
        self,
        df: pd.DataFrame,
        symbol: str,
        timeframe: str,
        start_date: str,
        end_date: str,
        horizon: int
    ) -> None:
        """Overlay predictions from a coarser timeframe onto ``df`` in place.

        Levels are computed on the coarser bars, anchored on their own close
        and shifted forward by ``horizon`` coarse bars. They are then
        forward-filled onto df's finer index, so a value only appears once its
        source bar has closed.
        """
        df_htf = self.data_loader.load_ohlcv(symbol, timeframe, start_date, end_date)
        if df_htf.empty:
            return
        features_htf = generate_reduced_features(df_htf)
        predictions = self.predictor.predict(features_htf, symbol, timeframe, horizon=horizon)
        for key, values in predictions.items():
            relative, price = self._forecast_levels(
                values, features_htf.index, df_htf['close'], key, horizon
            )
            df[key] = relative.reindex(df.index, method='ffill')
            if price is not None:
                df[f'{key}_price'] = price.reindex(df.index, method='ffill')

    # ------------------------------------------------------------------
    # lightweight-charts backend
    # ------------------------------------------------------------------
    def visualize(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True
    ):
        """
        Create an interactive lightweight-charts chart with predictions.

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators (SAR, RSI, CMF)
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed. Cannot visualize.")
            return

        logger.info(f"Preparing data for {symbol} {timeframe} from {start_date} to {end_date}")
        df = self.prepare_data(symbol, start_date, end_date, timeframe)
        if df.empty:
            logger.error("No data to visualize")
            return

        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime(self._TIME_FMT)

        chart = Chart(
            toolbox=True,
            inner_height=0.05,
            title=f"{symbol} - Multi-Model Predictions"
        )
        chart.legend(True, font_size=12)
        chart.topbar.textbox(
            name="SYMBOL",
            initial_text=f"{symbol} {timeframe} | Predictions: 5m & 15m"
        )

        # Main price pane
        price_chart = chart.create_subchart(
            height=self.config.chart_height,
            width=1,
            sync=True
        )
        price_chart.precision(precision=2 if 'XAU' in symbol else 5)
        price_chart.legend(True, font_size=12)
        price_chart.set(df_plot[['time'] + self._OHLCV_COLUMNS])

        if show_predictions:
            self._add_lightweight_prediction_lines(
                price_chart, df_plot, 'pred_high', self.config.high_colors
            )
            self._add_lightweight_prediction_lines(
                price_chart, df_plot, 'pred_low', self.config.low_colors
            )

        if show_indicators:
            if 'SAR' in df_plot.columns:
                sar_line = price_chart.create_line('SAR', color='#FF69B4', width=1)
                sar_line.set(df_plot[['time', 'SAR']])

            if 'RSI' in df_plot.columns:
                rsi_chart = chart.create_subchart(height=0.15, width=1, sync=True)
                rsi_chart.legend(True, font_size=10)
                rsi_line = rsi_chart.create_line('RSI', color='#20B2AA', width=1)
                rsi_line.set(df_plot[['time', 'RSI']])
                self._add_lightweight_level(rsi_chart, df_plot, 'overbought', 70, '#DC143C')
                self._add_lightweight_level(rsi_chart, df_plot, 'oversold', 30, '#32CD32')

            if 'CMF' in df_plot.columns:
                cmf_chart = chart.create_subchart(height=0.1, width=1, sync=True)
                cmf_chart.legend(True, font_size=10)
                cmf_line = cmf_chart.create_line('CMF', color='#9370DB', width=1)
                cmf_line.set(df_plot[['time', 'CMF']])
                self._add_lightweight_level(cmf_chart, df_plot, 'zero', 0, '#808080')

        logger.info("Displaying chart...")
        chart.show(block=True)

    def _add_lightweight_prediction_lines(
        self,
        price_chart,
        df_plot: pd.DataFrame,
        prefix: str,
        palette: List[str]
    ) -> None:
        """Add one line per ``<prefix>*_price`` column, cycling through ``palette``."""
        price_cols = [c for c in df_plot.columns if prefix in c and '_price' in c]
        for i, col in enumerate(price_cols):
            label = col.replace('_price', '')
            line = price_chart.create_line(
                label,
                color=palette[i % len(palette)],
                width=self.config.prediction_line_width
            )
            line.set(df_plot[['time', col]].rename(columns={col: label}))

    @staticmethod
    def _add_lightweight_level(
        subchart,
        df_plot: pd.DataFrame,
        name: str,
        level: float,
        color: str
    ) -> None:
        """Draw a constant horizontal reference line called ``name`` on ``subchart``."""
        level_data = df_plot[['time']].copy()
        level_data[name] = level
        line = subchart.create_line(name, color=color, width=1)
        line.set(level_data)

    # ------------------------------------------------------------------
    # Plotly backend
    # ------------------------------------------------------------------
    def visualize_plotly(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True,
        output_file: Optional[str] = None
    ) -> Optional[str]:
        """
        Create an interactive HTML chart with Plotly (fallback for environments without GTK/QT).

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators (SAR, RSI, CMF)
            output_file: Output HTML file path (auto-generated under
                reports/charts if None); missing parent directories are created

        Returns:
            Path to the generated HTML file, or None if plotly is missing or
            there is no data.
        """
        if not HAS_PLOTLY:
            logger.error("plotly not installed. Cannot visualize.")
            return None

        logger.info(f"Preparing Plotly chart for {symbol} {timeframe} from {start_date} to {end_date}")
        df = self.prepare_data(symbol, start_date, end_date, timeframe)
        if df.empty:
            logger.error("No data to visualize")
            return None

        # One extra pane per indicator that is both requested and available.
        indicator_panes = [
            name for name in self._PLOTLY_INDICATORS
            if show_indicators and name in df.columns
        ]
        # Pane heights come from the config and are normalised to sum to 1.
        row_heights = [self.config.chart_height] + [self.config.indicator_height] * len(indicator_panes)
        total = sum(row_heights)
        row_heights = [h / total for h in row_heights]

        fig = make_subplots(
            rows=1 + len(indicator_panes), cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            row_heights=row_heights,
            subplot_titles=[f"{symbol} {timeframe} - Multi-Model Predictions"] + indicator_panes
        )

        df_plot = df.reset_index()
        fig.add_trace(
            go.Candlestick(
                x=df_plot['time'],
                open=df_plot['open'],
                high=df_plot['high'],
                low=df_plot['low'],
                close=df_plot['close'],
                name='Price'
            ),
            row=1, col=1
        )

        if show_predictions:
            self._add_plotly_prediction_traces(fig, df_plot, 'pred_high', self.config.high_colors)
            self._add_plotly_prediction_traces(fig, df_plot, 'pred_low', self.config.low_colors)

        if show_indicators and 'SAR' in df_plot.columns:
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot['SAR'],
                    mode='markers',
                    name='SAR',
                    marker=dict(color='#FF69B4', size=3)
                ),
                row=1, col=1
            )

        for row, name in enumerate(indicator_panes, start=2):
            line_color, levels = self._PLOTLY_INDICATORS[name]
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[name],
                    mode='lines',
                    name=name,
                    line=dict(color=line_color, width=1)
                ),
                row=row, col=1
            )
            for level, level_color in levels:
                fig.add_hline(y=level, line_dash="dash", line_color=level_color,
                              opacity=0.5, row=row, col=1)

        fig.update_layout(
            title=f"{symbol} {timeframe} - Multi-Model Predictions ({start_date} to {end_date})",
            xaxis_title="Time",
            yaxis_title="Price",
            template="plotly_dark",
            height=800,
            showlegend=True,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            ),
            xaxis_rangeslider_visible=False
        )

        if output_file is None:
            output_dir = Path(__file__).parent.parent / 'reports' / 'charts'
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_file = str(output_dir / f"predictions_{symbol}_{timeframe}_{timestamp}.html")
        # Ensure the target directory exists, also for user-supplied paths.
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

        fig.write_html(output_file)
        logger.info(f"Chart saved to: {output_file}")
        return output_file

    def _add_plotly_prediction_traces(
        self,
        fig,
        df_plot: pd.DataFrame,
        prefix: str,
        palette: List[str]
    ) -> None:
        """Add one price-pane line per ``<prefix>*_price`` column, cycling ``palette``."""
        price_cols = [c for c in df_plot.columns if prefix in c and '_price' in c]
        for i, col in enumerate(price_cols):
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[col],
                    mode='lines',
                    name=col.replace('_price', ''),
                    line=dict(color=palette[i % len(palette)],
                              width=self.config.prediction_line_width),
                    opacity=0.7
                ),
                row=1, col=1
            )

    # ------------------------------------------------------------------
    # Backtest view (lightweight-charts)
    # ------------------------------------------------------------------
    @classmethod
    def _equity_curve(cls, trades: List[Dict], start_time: str) -> pd.DataFrame:
        """Build the cumulative equity series from closed trades.

        Trades without an ``exit_time`` are ignored and a missing/None ``pnl``
        counts as 0. Trades closing at the same timestamp collapse into one
        point holding the cumulative equity after all of them, so the time
        column stays unique and ascending.
        """
        closed = sorted(
            (t for t in trades if t.get('exit_time') is not None),
            key=lambda t: t['exit_time']
        )
        times = [start_time]
        equity = [cls._STARTING_CAPITAL]
        for trade in closed:
            equity.append(equity[-1] + (trade.get('pnl') or 0))
            times.append(trade['exit_time'].strftime(cls._TIME_FMT))
        curve = pd.DataFrame({'time': times, 'equity': equity})
        return curve.drop_duplicates('time', keep='last').reset_index(drop=True)

    @classmethod
    def _add_trade_markers(cls, price_chart, trade: Dict) -> None:
        """Draw the entry arrow and the exit circle for one trade dict.

        Reads (all optional) ``entry_time``/``exit_time`` (objects with
        ``strftime``), ``entry_price``/``exit_price``, ``direction`` (defaults
        to 'LONG'; any other value is drawn as a short) and ``pnl``.
        """
        direction = trade.get('direction', 'LONG')
        is_long = direction == 'LONG'

        entry_time = trade.get('entry_time')
        if entry_time is not None and trade.get('entry_price') is not None:
            price_chart.marker(
                time=entry_time.strftime(cls._TIME_FMT),
                position='below' if is_long else 'above',
                color='#00FF00' if is_long else '#FF0000',
                shape='arrow_up' if is_long else 'arrow_down',
                text=f"{direction} Entry"
            )

        exit_time = trade.get('exit_time')
        if exit_time is not None and trade.get('exit_price') is not None:
            pnl = trade.get('pnl') or 0
            price_chart.marker(
                time=exit_time.strftime(cls._TIME_FMT),
                position='above' if is_long else 'below',
                color='#00FF00' if pnl > 0 else '#FF0000',
                shape='circle',
                text=f"Exit ${pnl:+.2f}"
            )

    def visualize_backtest_results(
        self,
        df: pd.DataFrame,
        trades: List[Dict],
        symbol: str
    ):
        """
        Visualize backtest results with trade markers and an equity curve.

        Args:
            df: DataFrame with OHLCV (time-indexed) and predictions
            trades: List of trade dictionaries with entry/exit info
            symbol: Trading symbol
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed")
            return
        if df.empty:
            logger.error("No data to visualize")
            return

        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime(self._TIME_FMT)

        chart = Chart(toolbox=True, title=f"{symbol} - Backtest Results")
        chart.legend(True)

        price_chart = chart.create_subchart(height=0.7, width=1, sync=True)
        price_chart.precision(precision=2 if 'XAU' in symbol else 5)
        price_chart.set(df_plot[['time'] + self._OHLCV_COLUMNS])

        for trade in trades:
            self._add_trade_markers(price_chart, trade)

        equity_chart = chart.create_subchart(height=0.2, width=1, sync=True)
        equity_chart.legend(True)
        equity_df = self._equity_curve(trades, df_plot['time'].iloc[0])
        equity_line = equity_chart.create_line('Equity', color='#4169E1', width=2)
        equity_line.set(equity_df)

        chart.show(block=True)
# ============================================================
# Main Execution
# ============================================================
def main():
    """Command-line entry point: parse arguments and render the chart.

    Plotly (HTML output) is the default backend. ``--use-lightweight``
    switches to lightweight-charts when that package is installed; otherwise
    a warning is logged and Plotly is used. The database connection opened
    while loading data is always closed before returning.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Visualize multi-model predictions')
    parser.add_argument('--symbol', type=str, default='XAUUSD', help='Trading symbol')
    parser.add_argument('--timeframe', type=str, default='5m', help='Timeframe (5m or 15m)')
    parser.add_argument('--start', type=str, default='2025-01-01', help='Start date')
    parser.add_argument('--end', type=str, default='2025-01-31', help='End date')
    parser.add_argument('--no-predictions', action='store_true', help='Hide predictions')
    parser.add_argument('--no-indicators', action='store_true', help='Hide indicators')
    parser.add_argument('--output', type=str, default=None, help='Output HTML file path')
    parser.add_argument('--use-lightweight', action='store_true', help='Use lightweight-charts (requires GTK/QT)')
    args = parser.parse_args()

    config = VisualizerConfig()
    visualizer = MultiModelVisualizer(config)
    try:
        use_lightweight = args.use_lightweight and HAS_LIGHTWEIGHT_CHARTS
        if args.use_lightweight and not HAS_LIGHTWEIGHT_CHARTS:
            logger.warning(
                "--use-lightweight requested but lightweight-charts is not installed; "
                "falling back to Plotly"
            )

        if use_lightweight:
            visualizer.visualize(
                symbol=args.symbol,
                start_date=args.start,
                end_date=args.end,
                timeframe=args.timeframe,
                show_predictions=not args.no_predictions,
                show_indicators=not args.no_indicators
            )
        else:
            # Use Plotly (default - works in WSL)
            output_file = visualizer.visualize_plotly(
                symbol=args.symbol,
                start_date=args.start,
                end_date=args.end,
                timeframe=args.timeframe,
                show_predictions=not args.no_predictions,
                show_indicators=not args.no_indicators,
                output_file=args.output
            )
            if output_file:
                print(f"Chart saved: {output_file}")
    finally:
        # Release the PostgreSQL connection opened by the data loader.
        visualizer.data_loader.close()


if __name__ == "__main__":
    main()