#!/usr/bin/env python3
"""
Multi-Model Prediction Visualizer
==================================
Visualizes predictions from multiple ML models with interactive charts.

Uses lightweight-charts (desktop window) or Plotly (HTML file) for
interactive trading charts with:
- Candlestick price data
- Range predictions (high/low) from multiple timeframes
- Movement magnitude predictions
- AMD phase indicators
- Technical indicators (RSI, MACD, SAR)

Author: ML-Specialist (NEXUS v4.0)
Date: 2026-01-05
"""

import sys
from pathlib import Path

# Make the project root and its ``src`` directory importable when this file is
# executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))

import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
import joblib
from loguru import logger
import psycopg2
from psycopg2.extras import RealDictCursor

try:
    from lightweight_charts import Chart
    HAS_LIGHTWEIGHT_CHARTS = True
except ImportError:
    HAS_LIGHTWEIGHT_CHARTS = False
    logger.warning("lightweight-charts not installed")

try:
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    HAS_PLOTLY = True
except ImportError:
    HAS_PLOTLY = False
    logger.warning("plotly not installed. Install with: pip install plotly")

# ML-Engine imports
from config.reduced_features import generate_reduced_features, get_feature_columns_without_ohlcv


# ============================================================
# Configuration
# ============================================================

@dataclass
class VisualizerConfig:
    """Connection and styling settings shared by the loader and the charts."""

    # PostgreSQL connection
    db_host: str = "localhost"
    db_port: int = 5432
    db_name: str = "orbiquant_trading"
    db_user: str = "orbiquant_user"
    # NOTE(review): a credential is committed in source; consider supplying it
    # from the environment or a secrets store instead.
    db_password: str = "orbiquant_dev_2025"

    # Visualization settings (heights are fractions of the window)
    chart_height: float = 0.6
    indicator_height: float = 0.2
    prediction_line_width: int = 1

    # Colors; ``None`` means "use the built-in palette" (see __post_init__)
    high_colors: Optional[List[str]] = None
    low_colors: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in the default palettes when none were supplied."""
        if self.high_colors is None:
            self.high_colors = ["#006400", "#228B22", "#32CD32", "#7FFF00"]  # Greens
        if self.low_colors is None:
            self.low_colors = ["#FF0000", "#B22222", "#8B0000", "#CD5C5C"]  # Reds


# ============================================================
# PostgreSQL Data Loader
# ============================================================

class PostgreSQLDataLoader:
    """Loads market data from PostgreSQL."""

    # All reads go through the 5m parent table; coarser bars are resampled.
    _OHLCV_TABLE = "market_data.ohlcv_5m"
    _SUPPORTED_TIMEFRAMES = ('5m', '15m')

    def __init__(self, config: Optional[VisualizerConfig] = None):
        """Store the configuration; the connection is opened lazily."""
        self.config = config or VisualizerConfig()
        self.connection = None
        self._ticker_cache: Dict[str, int] = {}

    def connect(self):
        """Open the PostgreSQL connection unless a live one already exists."""
        if self.connection is None or self.connection.closed:
            self.connection = psycopg2.connect(
                host=self.config.db_host,
                port=self.config.db_port,
                dbname=self.config.db_name,
                user=self.config.db_user,
                password=self.config.db_password,
            )
            logger.info(f"Connected to PostgreSQL at {self.config.db_host}:{self.config.db_port}")

    def close(self):
        """Close the connection if it is open."""
        if self.connection and not self.connection.closed:
            self.connection.close()

    def _rollback_quietly(self):
        """Reset an aborted transaction so the connection stays usable."""
        if self.connection is None or self.connection.closed:
            return
        try:
            self.connection.rollback()
        except psycopg2.Error as rollback_error:
            logger.warning(f"Rollback failed: {rollback_error}")

    def get_ticker_id(self, symbol: str) -> Optional[int]:
        """
        Look up the ticker ID for ``symbol`` in ``market_data.tickers``.

        Successful lookups are cached per loader instance.

        Returns:
            The ticker ID, or None when the symbol is not present.
        """
        if symbol in self._ticker_cache:
            return self._ticker_cache[symbol]

        self.connect()
        with self.connection.cursor() as cur:
            cur.execute(
                "SELECT id FROM market_data.tickers WHERE symbol = %s",
                (symbol,)
            )
            result = cur.fetchone()

        if result:
            self._ticker_cache[symbol] = result[0]
            return result[0]
        return None

    def load_ohlcv(
        self,
        symbol: str,
        timeframe: str,
        start_date: str,
        end_date: str
    ) -> pd.DataFrame:
        """
        Load OHLCV bars for ``symbol`` between ``start_date`` and ``end_date``.

        Data is always read from the 5m table; when ``timeframe`` is '15m' the
        bars are resampled to 15 minutes. Any other timeframe returns the raw
        5m bars (a warning is logged).

        Returns:
            DataFrame indexed by ``time`` with open/high/low/close/volume
            columns, or an empty DataFrame when the symbol is unknown or the
            query fails.
        """
        self.connect()

        ticker_id = self.get_ticker_id(symbol)
        if ticker_id is None:
            logger.error(f"Symbol not found: {symbol}")
            return pd.DataFrame()

        if timeframe not in self._SUPPORTED_TIMEFRAMES:
            logger.warning(
                f"Unsupported timeframe '{timeframe}'; returning 5m bars without resampling"
            )

        query = f"""
            SELECT
                timestamp as time,
                open, high, low, close, volume
            FROM {self._OHLCV_TABLE}
            WHERE ticker_id = %s
              AND timestamp >= %s
              AND timestamp <= %s
            ORDER BY timestamp ASC
        """

        try:
            df = pd.read_sql_query(
                query,
                self.connection,
                params=(ticker_id, start_date, end_date),
                parse_dates=['time']
            )

            if not df.empty:
                df.set_index('time', inplace=True)

                # Resample to 15m if requested
                if timeframe == '15m':
                    logger.info(f"Resampling {len(df)} 5m records to 15m...")
                    df = df.resample('15min').agg({
                        'open': 'first',
                        'high': 'max',
                        'low': 'min',
                        'close': 'last',
                        'volume': 'sum'
                    }).dropna()

            logger.info(f"Loaded {len(df)} records for {symbol} {timeframe}")
            return df

        except Exception as e:
            # Best-effort loader: report, clear the failed transaction so later
            # queries on this connection still work, and return no data.
            logger.error(f"Failed to load data: {e}")
            self._rollback_quietly()
            return pd.DataFrame()


# ============================================================
# Multi-Model Prediction Generator
# ============================================================

class MultiModelPredictor:
    """Generates predictions from multiple models."""

    def __init__(self, model_dir: str = 'models/reduced_features_models'):
        """Load every model found in ``model_dir`` (relative paths use the CWD)."""
        self.model_dir = Path(model_dir)
        self.models: Dict[str, Any] = {}
        self.load_models()

    def load_models(self):
        """
        Load all ``*.joblib`` files except ``metadata.joblib``.

        Models are stored in ``self.models`` keyed by file stem.
        """
        if not self.model_dir.exists():
            logger.warning(f"Model directory not found: {self.model_dir}")
            return

        for model_file in self.model_dir.glob("*.joblib"):
            if model_file.name == 'metadata.joblib':
                continue
            key = model_file.stem
            # SECURITY: joblib.load unpickles arbitrary objects; only point
            # model_dir at trusted files.
            self.models[key] = joblib.load(model_file)
            logger.info(f"Loaded model: {key}")

    def predict(
        self,
        features: pd.DataFrame,
        symbol: str,
        timeframe: str,
        horizon: int = 3
    ) -> Dict[str, np.ndarray]:
        """
        Run the high/low models registered for ``symbol``/``timeframe``.

        Models are looked up under ``{symbol}_{timeframe}_{high|low}_h{horizon}``.

        Returns:
            Dict with ``pred_high_{timeframe}`` / ``pred_low_{timeframe}``
            entries for each model that exists; empty when no feature column
            is available.
        """
        predictions: Dict[str, np.ndarray] = {}

        feature_cols = get_feature_columns_without_ohlcv()
        available_cols = [c for c in feature_cols if c in features.columns]
        if not available_cols:
            return predictions

        missing_cols = [c for c in feature_cols if c not in features.columns]
        if missing_cols:
            # The models receive fewer columns than the configured feature set.
            logger.warning(f"Missing feature columns for prediction: {missing_cols}")

        X = features[available_cols].values

        for side in ('high', 'low'):
            model = self.models.get(f"{symbol}_{timeframe}_{side}_h{horizon}")
            if model is not None:
                predictions[f'pred_{side}_{timeframe}'] = model.predict(X)

        return predictions


# ============================================================
# Prediction Visualizer
# ============================================================

class MultiModelVisualizer:
    """
    Interactive chart visualizer for multi-model predictions.

    Features:
    - Candlestick chart with price data
    - Range predictions (high/low) from 5m and 15m models
    - Technical indicators (RSI, MACD, SAR)
    - AMD phase overlay
    """

    _OHLCV_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

    # Plotly indicator panels: (column, line color, ((level, level color), ...))
    _PLOTLY_INDICATORS = (
        ('RSI', '#20B2AA', ((70, 'red'), (30, 'green'))),
        ('CMF', '#9370DB', ((0, 'gray'),)),
    )

    def __init__(self, config: Optional[VisualizerConfig] = None):
        """Create the data loader and load the prediction models."""
        self.config = config or VisualizerConfig()
        self.data_loader = PostgreSQLDataLoader(self.config)
        self.predictor = MultiModelPredictor()

    @staticmethod
    def _add_price_column(df: pd.DataFrame, key: str) -> None:
        """Turn the relative prediction ``df[key]`` into ``df[f'{key}_price']``."""
        if 'high' in key:
            df[f'{key}_price'] = df['close'] + df[key]
        elif 'low' in key:
            df[f'{key}_price'] = df['close'] - df[key]

    def prepare_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        horizon: int = 3
    ) -> pd.DataFrame:
        """
        Load OHLCV bars and attach predictions and feature columns.

        Args:
            symbol: Trading symbol
            start_date: Start of the range
            end_date: End of the range
            timeframe: Base timeframe ('5m' or '15m')
            horizon: Model horizon in bars; also the number of bars the
                base-timeframe predictions are shifted by

        Returns:
            The OHLCV frame extended with ``pred_*`` columns, their
            ``*_price`` counterparts and the generated features; empty when
            no bars were loaded.
        """
        # Load OHLCV data
        df = self.data_loader.load_ohlcv(symbol, timeframe, start_date, end_date)
        if df.empty:
            return df

        # Generate features
        features = generate_reduced_features(df)

        # Get predictions for this timeframe
        predictions = self.predictor.predict(features, symbol, timeframe, horizon=horizon)

        for key, values in predictions.items():
            if len(values) == len(df):
                aligned = pd.Series(values, index=df.index)
            else:
                # Feature generation changed the row count: align on the index
                # instead of failing on a positional length mismatch.
                aligned = pd.Series(values, index=features.index).reindex(df.index)

            # Row t receives the value predicted at row t + horizon; the last
            # ``horizon`` rows stay NaN.
            # NOTE(review): this is a backward shift, and the 15m overlay below
            # is not shifted at all - confirm which alignment is intended.
            df[key] = aligned.shift(-horizon)

            # Convert relative predictions to absolute prices
            self._add_price_column(df, key)

        # Add 15m predictions if using 5m data
        if timeframe == '5m':
            df_15m = self.data_loader.load_ohlcv(symbol, '15m', start_date, end_date)
            if not df_15m.empty:
                features_15m = generate_reduced_features(df_15m)
                predictions_15m = self.predictor.predict(
                    features_15m, symbol, '15m', horizon=horizon
                )

                for key, values in predictions_15m.items():
                    # Each 15m value is carried forward onto the 5m bars that
                    # follow it.
                    series_15m = pd.Series(values, index=features_15m.index)
                    df[key] = series_15m.reindex(df.index, method='ffill')
                    self._add_price_column(df, key)

        # Add features to df (never overwriting OHLCV or prediction columns)
        for col in features.columns:
            if col not in df.columns and col not in self._OHLCV_COLUMNS:
                df[col] = features[col]

        return df

    def _add_prediction_lines(self, price_chart, df_plot: pd.DataFrame,
                              columns: List[str], colors: List[str]) -> None:
        """Draw one line per ``*_price`` column, cycling through ``colors``."""
        for i, col in enumerate(columns):
            label = col.replace('_price', '')
            pred_line = price_chart.create_line(
                label,
                color=colors[i % len(colors)],
                width=self.config.prediction_line_width
            )
            pred_line.set(df_plot[['time', col]].rename(columns={col: label}))

    def visualize(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True
    ):
        """
        Create interactive chart with predictions.

        Blocks until the chart window is closed. Does nothing (beyond logging
        an error) when lightweight-charts is missing or no data was loaded.

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed. Cannot visualize.")
            return

        logger.info(f"Preparing data for {symbol} {timeframe} from {start_date} to {end_date}")

        # Prepare data
        df = self.prepare_data(symbol, start_date, end_date, timeframe)
        if df.empty:
            logger.error("No data to visualize")
            return

        # Prepare for plotting: the chart receives time as formatted strings
        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime('%Y-%m-%d %H:%M:%S')

        # Create chart
        chart = Chart(
            toolbox=True,
            inner_height=0.05,
            title=f"{symbol} - Multi-Model Predictions"
        )
        chart.legend(True, font_size=12)
        chart.topbar.textbox(
            name="SYMBOL",
            initial_text=f"{symbol} {timeframe} | Predictions: 5m & 15m"
        )

        # Main price chart
        price_chart = chart.create_subchart(
            height=self.config.chart_height,
            width=1,
            sync=True
        )
        price_chart.precision(precision=2 if 'XAU' in symbol else 5)
        price_chart.legend(True, font_size=12)

        # Set candlestick data
        ohlcv_cols = ['time', 'open', 'high', 'low', 'close', 'volume']
        price_chart.set(df_plot[ohlcv_cols])

        if show_predictions:
            pred_high_cols = [c for c in df_plot.columns if 'pred_high' in c and '_price' in c]
            pred_low_cols = [c for c in df_plot.columns if 'pred_low' in c and '_price' in c]

            # High predictions (greens), low predictions (reds)
            self._add_prediction_lines(price_chart, df_plot, pred_high_cols,
                                       self.config.high_colors)
            self._add_prediction_lines(price_chart, df_plot, pred_low_cols,
                                       self.config.low_colors)

        # SAR points
        # NOTE(review): drawn whenever the column exists; confirm it was not
        # meant to be gated by show_predictions.
        if 'SAR' in df_plot.columns:
            sar_line = price_chart.create_line('SAR', color='#FF69B4', width=1)
            sar_line.set(df_plot[['time', 'SAR']])

        if show_indicators:
            # RSI subchart
            if 'RSI' in df_plot.columns:
                rsi_chart = chart.create_subchart(height=0.15, width=1, sync=True)
                rsi_chart.legend(True, font_size=10)

                rsi_line = rsi_chart.create_line('RSI', color='#20B2AA', width=1)
                rsi_line.set(df_plot[['time', 'RSI']])

                # Overbought/oversold levels, drawn as constant series
                ob_data = df_plot[['time']].copy()
                ob_data['overbought'] = 70
                os_data = df_plot[['time']].copy()
                os_data['oversold'] = 30

                ob_line = rsi_chart.create_line('overbought', color='#DC143C', width=1)
                ob_line.set(ob_data)
                os_line = rsi_chart.create_line('oversold', color='#32CD32', width=1)
                os_line.set(os_data)

            # CMF subchart
            if 'CMF' in df_plot.columns:
                cmf_chart = chart.create_subchart(height=0.1, width=1, sync=True)
                cmf_chart.legend(True, font_size=10)

                cmf_line = cmf_chart.create_line('CMF', color='#9370DB', width=1)
                cmf_line.set(df_plot[['time', 'CMF']])

                # Zero line
                zero_data = df_plot[['time']].copy()
                zero_data['zero'] = 0
                zero_line = cmf_chart.create_line('zero', color='#808080', width=1)
                zero_line.set(zero_data)

        logger.info("Displaying chart...")
        chart.show(block=True)

    def visualize_plotly(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        timeframe: str = '5m',
        show_predictions: bool = True,
        show_indicators: bool = True,
        output_file: Optional[str] = None
    ) -> Optional[str]:
        """
        Create interactive HTML chart with Plotly (fallback for environments without GTK/QT).

        Args:
            symbol: Trading symbol (e.g., 'XAUUSD')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            timeframe: Base timeframe ('5m' or '15m')
            show_predictions: Show prediction lines
            show_indicators: Show technical indicators
            output_file: Output HTML file path (auto-generated if None)

        Returns:
            Path to generated HTML file, or None when Plotly is missing or
            no data was loaded
        """
        if not HAS_PLOTLY:
            logger.error("plotly not installed. Cannot visualize.")
            return None

        logger.info(f"Preparing Plotly chart for {symbol} {timeframe} from {start_date} to {end_date}")

        # Prepare data
        df = self.prepare_data(symbol, start_date, end_date, timeframe)
        if df.empty:
            logger.error("No data to visualize")
            return None

        # One extra row per indicator that is both requested and available
        panels = [
            spec for spec in self._PLOTLY_INDICATORS
            if show_indicators and spec[0] in df.columns
        ]

        # Price row weighs 0.6, each indicator row 0.2; normalize to sum to 1
        raw_heights = [0.6] + [0.2] * len(panels)
        total = sum(raw_heights)
        row_heights = [h / total for h in raw_heights]

        # Create subplots
        fig = make_subplots(
            rows=len(raw_heights),
            cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            row_heights=row_heights,
            subplot_titles=[f"{symbol} {timeframe} - Multi-Model Predictions"]
            + [name for name, _, _ in panels]
        )

        # Candlestick chart
        df_plot = df.reset_index()
        fig.add_trace(
            go.Candlestick(
                x=df_plot['time'],
                open=df_plot['open'],
                high=df_plot['high'],
                low=df_plot['low'],
                close=df_plot['close'],
                name='Price'
            ),
            row=1, col=1
        )

        if show_predictions:
            # High predictions (green shades), low predictions (red shades)
            for marker, palette in (('pred_high', self.config.high_colors),
                                    ('pred_low', self.config.low_colors)):
                price_cols = [c for c in df.columns if marker in c and '_price' in c]
                for i, col in enumerate(price_cols):
                    fig.add_trace(
                        go.Scatter(
                            x=df_plot['time'],
                            y=df_plot[col],
                            mode='lines',
                            name=col.replace('_price', ''),
                            line=dict(color=palette[i % len(palette)], width=1),
                            opacity=0.7
                        ),
                        row=1, col=1
                    )

        # SAR points
        # NOTE(review): drawn whenever the column exists; confirm it was not
        # meant to be gated by show_predictions.
        if 'SAR' in df.columns:
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot['SAR'],
                    mode='markers',
                    name='SAR',
                    marker=dict(color='#FF69B4', size=3)
                ),
                row=1, col=1
            )

        # Indicator rows start right below the price row
        for row, (name, line_color, levels) in enumerate(panels, start=2):
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[name],
                    mode='lines',
                    name=name,
                    line=dict(color=line_color, width=1)
                ),
                row=row, col=1
            )
            # Reference levels (overbought/oversold for RSI, zero for CMF)
            for level, level_color in levels:
                fig.add_hline(y=level, line_dash="dash", line_color=level_color,
                              opacity=0.5, row=row, col=1)

        # Update layout
        fig.update_layout(
            title=f"{symbol} {timeframe} - Multi-Model Predictions ({start_date} to {end_date})",
            xaxis_title="Time",
            yaxis_title="Price",
            template="plotly_dark",
            height=800,
            showlegend=True,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            ),
            xaxis_rangeslider_visible=False
        )

        # Generate output filename
        if output_file is None:
            output_dir = Path(__file__).parent.parent / 'reports' / 'charts'
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_file = str(output_dir / f"predictions_{symbol}_{timeframe}_{timestamp}.html")

        # Make sure the target directory exists, also for caller-supplied paths
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

        # Save HTML
        fig.write_html(output_file)
        logger.info(f"Chart saved to: {output_file}")

        return output_file

    @staticmethod
    def _build_equity_curve(trades: List[Dict], start_time: str,
                            initial_capital: float) -> pd.DataFrame:
        """Accumulate ``pnl`` of closed trades, ordered by exit time, into an equity series."""
        # Only trades with an exit time contribute; filtering before sorting
        # keeps open trades (exit_time missing or None) out of the comparison.
        closed = sorted(
            (t for t in trades if t.get('exit_time')),
            key=lambda t: t['exit_time']
        )

        times = [start_time]
        equity = [initial_capital]
        for trade in closed:
            equity.append(equity[-1] + trade.get('pnl', 0))
            times.append(trade['exit_time'].strftime('%Y-%m-%d %H:%M:%S'))

        return pd.DataFrame({'time': times, 'equity': equity})

    def visualize_backtest_results(
        self,
        df: pd.DataFrame,
        trades: List[Dict],
        symbol: str,
        initial_capital: float = 1000.0
    ):
        """
        Visualize backtest results with trade markers.

        Blocks until the chart window is closed.

        Args:
            df: DataFrame with OHLCV and predictions; after ``reset_index()``
                it must expose a datetime ``time`` column
            trades: List of trade dictionaries with entry/exit info
                (``entry_time``, ``entry_price``, ``direction``,
                ``exit_time``, ``exit_price``, ``pnl``)
            symbol: Trading symbol
            initial_capital: Starting value of the equity curve
        """
        if not HAS_LIGHTWEIGHT_CHARTS:
            logger.error("lightweight-charts not installed")
            return

        if df.empty:
            logger.error("No data to visualize")
            return

        df_plot = df.reset_index()
        df_plot['time'] = df_plot['time'].dt.strftime('%Y-%m-%d %H:%M:%S')

        chart = Chart(toolbox=True, title=f"{symbol} - Backtest Results")
        chart.legend(True)

        # Main chart
        price_chart = chart.create_subchart(height=0.7, width=1, sync=True)
        price_chart.precision(precision=2 if 'XAU' in symbol else 5)
        price_chart.set(df_plot[['time', 'open', 'high', 'low', 'close', 'volume']])

        # Add trade markers
        for trade in trades:
            direction = trade.get('direction', 'LONG')
            is_long = direction == 'LONG'

            # Entry marker
            entry_time = trade.get('entry_time')
            entry_price = trade.get('entry_price')
            if entry_time is not None and entry_price is not None:
                price_chart.marker(
                    time=entry_time.strftime('%Y-%m-%d %H:%M:%S'),
                    position='below' if is_long else 'above',
                    color='#00FF00' if is_long else '#FF0000',
                    shape='arrow_up' if is_long else 'arrow_down',
                    text=f"{direction} Entry"
                )

            # Exit marker, colored by profit/loss
            exit_time = trade.get('exit_time')
            exit_price = trade.get('exit_price')
            pnl = trade.get('pnl', 0)
            if exit_time is not None and exit_price is not None:
                price_chart.marker(
                    time=exit_time.strftime('%Y-%m-%d %H:%M:%S'),
                    position='above' if is_long else 'below',
                    color='#00FF00' if pnl > 0 else '#FF0000',
                    shape='circle',
                    text=f"Exit ${pnl:+.2f}"
                )

        # Equity curve subchart
        equity_chart = chart.create_subchart(height=0.2, width=1, sync=True)
        equity_chart.legend(True)

        equity_df = self._build_equity_curve(
            trades, df_plot['time'].iloc[0], initial_capital
        )
        equity_line = equity_chart.create_line('Equity', color='#4169E1', width=2)
        equity_line.set(equity_df)

        chart.show(block=True)


# ============================================================
# Main Execution
# ============================================================

def main():
    """Parse command-line options and render the requested chart."""
    import argparse

    parser = argparse.ArgumentParser(description='Visualize multi-model predictions')
    parser.add_argument('--symbol', type=str, default='XAUUSD', help='Trading symbol')
    parser.add_argument('--timeframe', type=str, default='5m', help='Timeframe (5m or 15m)')
    parser.add_argument('--start', type=str, default='2025-01-01', help='Start date')
    parser.add_argument('--end', type=str, default='2025-01-31', help='End date')
    parser.add_argument('--no-predictions', action='store_true', help='Hide predictions')
    parser.add_argument('--no-indicators', action='store_true', help='Hide indicators')
    parser.add_argument('--output', type=str, default=None, help='Output HTML file path')
    parser.add_argument('--use-lightweight', action='store_true',
                        help='Use lightweight-charts (requires GTK/QT)')

    args = parser.parse_args()

    config = VisualizerConfig()
    visualizer = MultiModelVisualizer(config)

    if args.use_lightweight and not HAS_LIGHTWEIGHT_CHARTS:
        logger.warning("lightweight-charts not installed; falling back to Plotly")

    try:
        if args.use_lightweight and HAS_LIGHTWEIGHT_CHARTS:
            visualizer.visualize(
                symbol=args.symbol,
                start_date=args.start,
                end_date=args.end,
                timeframe=args.timeframe,
                show_predictions=not args.no_predictions,
                show_indicators=not args.no_indicators
            )
        else:
            # Use Plotly (default - works in WSL)
            output_file = visualizer.visualize_plotly(
                symbol=args.symbol,
                start_date=args.start,
                end_date=args.end,
                timeframe=args.timeframe,
                show_predictions=not args.no_predictions,
                show_indicators=not args.no_indicators,
                output_file=args.output
            )
            if output_file:
                print(f"Chart saved: {output_file}")
    finally:
        # Release the database connection opened while loading data
        visualizer.data_loader.close()


if __name__ == "__main__":
    main()