Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
801 lines
26 KiB
Python
801 lines
26 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
ML Models Visualization Script V2
|
|
=================================
|
|
Visualizes predictions from reduced-features models for all symbols and timeframes.
|
|
|
|
Models Visualized:
|
|
1. RangePredictor (high/low) - Reduced features models
|
|
2. Volatility Attention Weights - Shows where model focuses
|
|
|
|
Supports:
|
|
- Multiple symbols: XAUUSD, EURUSD, BTCUSD
|
|
- Multiple timeframes: 5m, 15m
|
|
- Date range filtering
|
|
- Out-of-sample visualization (2025 data)
|
|
|
|
Usage:
|
|
python scripts/run_visualization_v2.py --symbol XAUUSD --timeframe 15m --start 2025-01-01 --end 2025-01-31
|
|
python scripts/run_visualization_v2.py --symbol XAUUSD --timeframe 5m --start 2025-01-01 --end 2025-01-31
|
|
python scripts/run_visualization_v2.py --all-symbols --all-timeframes
|
|
|
|
Author: ML-Specialist (NEXUS v4.0)
|
|
Version: 2.0.0
|
|
Created: 2026-01-05
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
|
|
# Add src to path
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
import argparse
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
import json
|
|
from loguru import logger
|
|
import joblib
|
|
|
|
# Local imports
|
|
from config.reduced_features import (
|
|
COLUMNS_TO_TRAIN,
|
|
generate_reduced_features,
|
|
get_feature_columns_without_ohlcv
|
|
)
|
|
from models.volatility_attention import (
|
|
compute_factor_median_range,
|
|
compute_move_multiplier,
|
|
weight_smooth,
|
|
compute_attention_weights,
|
|
VolatilityAttentionConfig
|
|
)
|
|
|
|
# Visualization libraries
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.dates as mdates
|
|
from matplotlib.patches import Rectangle, Patch
|
|
from matplotlib.lines import Line2D
|
|
HAS_MATPLOTLIB = True
|
|
except ImportError:
|
|
HAS_MATPLOTLIB = False
|
|
logger.warning("matplotlib not available - install with: pip install matplotlib")
|
|
|
|
try:
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
import plotly.express as px
|
|
HAS_PLOTLY = True
|
|
except ImportError:
|
|
HAS_PLOTLY = False
|
|
logger.warning("plotly not available - install with: pip install plotly kaleido")
|
|
|
|
|
|
# ==============================================================================
# Constants
# ==============================================================================

# Symbols and timeframes this script knows how to visualize.
SUPPORTED_SYMBOLS = ['XAUUSD', 'EURUSD', 'BTCUSD']
SUPPORTED_TIMEFRAMES = ['5m', '15m']

# Per-symbol settings:
#   db_prefix  - ticker prefix used in the database ('C:' forex/metals, 'X:' crypto)
#   base_price - approximate price level, used to synthesize sample data
#   pip_value  - minimum price increment (not referenced in this script as shown;
#                presumably kept for downstream consumers — TODO confirm)
SYMBOL_CONFIGS = {
    'XAUUSD': {'db_prefix': 'C:', 'base_price': 2650, 'pip_value': 0.01},
    'EURUSD': {'db_prefix': 'C:', 'base_price': 1.10, 'pip_value': 0.0001},
    'BTCUSD': {'db_prefix': 'X:', 'base_price': 95000, 'pip_value': 0.01}
}

# Prediction horizon (number of future bars) per timeframe; model files are
# keyed as "<symbol>_<timeframe>_<target>_h<horizon>".
HORIZONS = {'5m': 3, '15m': 3}
|
|
|
|
|
# ==============================================================================
|
|
# Data Loading
|
|
# ==============================================================================
|
|
|
|
def load_data_for_visualization(
    symbol: str,
    timeframe: str,
    start_date: str,
    end_date: str,
    db_config_path: str = 'config/database.yaml'
) -> pd.DataFrame:
    """
    Load OHLCV data for visualization, falling back to synthetic data.

    Queries the `tickers_agg_data` table for the prefixed ticker over the
    given date range.  On any failure (missing DB, query error) or an empty
    result, a deterministic synthetic series is returned instead so the
    visualization can still run.

    Args:
        symbol: Trading symbol (e.g. 'XAUUSD').
        timeframe: Target timeframe ('5m', '15m', ...).
        start_date: Inclusive start date (YYYY-MM-DD).
        end_date: Inclusive end date (YYYY-MM-DD).
        db_config_path: Path to the database YAML config.

    Returns:
        DataFrame indexed by time with open/high/low/close/volume columns.
    """
    try:
        from data.database import MySQLConnection

        connection = MySQLConnection(db_config_path)

        prefix = SYMBOL_CONFIGS.get(symbol, {'db_prefix': 'C:'})['db_prefix']
        db_symbol = f"{prefix}{symbol}"

        query = """
        SELECT
            date_agg as time,
            open, high, low, close, volume
        FROM tickers_agg_data
        WHERE ticker = :symbol
          AND date_agg >= :start_date
          AND date_agg <= :end_date
        ORDER BY date_agg ASC
        """

        frame = connection.execute_query(
            query,
            {'symbol': db_symbol, 'start_date': start_date, 'end_date': end_date}
        )

        if frame.empty:
            logger.warning(f"No data found for {symbol} in {start_date} to {end_date}")
            return create_sample_visualization_data(symbol, timeframe, start_date, end_date)

        frame['time'] = pd.to_datetime(frame['time'])
        frame = frame.set_index('time').sort_index()

        # Stored data is 5m; aggregate up for coarser timeframes.
        if timeframe != '5m':
            offset = {'15m': '15min', '30m': '30min', '1H': '1H'}.get(timeframe, timeframe)
            ohlcv_agg = {
                'open': 'first',
                'high': 'max',
                'low': 'min',
                'close': 'last',
                'volume': 'sum'
            }
            frame = frame.resample(offset).agg(ohlcv_agg).dropna()

        logger.info(f"Loaded {len(frame)} records for {symbol} {timeframe}")
        return frame

    except Exception as e:
        # Best-effort: the charts should still render with sample data.
        logger.warning(f"Database load failed: {e}")
        return create_sample_visualization_data(symbol, timeframe, start_date, end_date)
|
|
|
|
|
|
def create_sample_visualization_data(
    symbol: str,
    timeframe: str,
    start_date: str,
    end_date: str
) -> pd.DataFrame:
    """Create deterministic synthetic OHLCV data for demonstration.

    A seeded geometric random walk provides the price path; per-bar noise
    around it generates open/high/low/close, with larger noise during the
    13:00-16:00 session hours.  OHLC consistency is enforced at the end.
    """
    logger.info(f"Creating sample visualization data for {symbol} {timeframe}...")

    np.random.seed(42)  # fixed seed -> reproducible demo charts

    freq = {'5m': '5min', '15m': '15min', '30m': '30min', '1H': '1H'}.get(timeframe, '15min')
    dates = pd.date_range(start=start_date, end=end_date, freq=freq)
    n = len(dates)

    base_price = SYMBOL_CONFIGS.get(symbol, {'base_price': 100}).get('base_price', 100)

    # Random-walk price path driven by small log-returns.
    price = base_price * np.exp(np.cumsum(np.random.randn(n) * 0.001))

    # Session-dependent noise: elevated during the 13:00-16:00 overlap.
    volatility = np.where(
        (dates.hour >= 13) & (dates.hour < 16),
        0.003,
        0.001
    )

    # NOTE: the order of RNG draws below is significant for reproducibility.
    frame = pd.DataFrame({
        'open': price * (1 + np.random.randn(n) * volatility),
        'high': price * (1 + np.abs(np.random.randn(n)) * volatility * 2),
        'low': price * (1 - np.abs(np.random.randn(n)) * volatility * 2),
        'close': price * (1 + np.random.randn(n) * volatility),
        'volume': np.random.randint(1000, 50000, n)
    }, index=dates)

    # Enforce OHLC consistency: high/low must bracket open and close.
    frame['high'] = frame[['open', 'high', 'close']].max(axis=1)
    frame['low'] = frame[['open', 'low', 'close']].min(axis=1)

    return frame
|
|
|
|
|
|
# ==============================================================================
|
|
# Model Loading and Prediction
|
|
# ==============================================================================
|
|
|
|
def load_reduced_features_models(
    symbol: str,
    timeframe: str,
    model_dir: str = 'models/reduced_features_models'
) -> Dict[str, Any]:
    """
    Load the reduced-features high/low models for one symbol/timeframe.

    Model files are expected as "<symbol>_<timeframe>_<target>_h<horizon>.joblib"
    inside `model_dir`, plus an optional shared `metadata.joblib`.

    Args:
        symbol: Trading symbol.
        timeframe: Timeframe key into HORIZONS.
        model_dir: Directory containing the .joblib artifacts.

    Returns:
        Mapping of model key -> loaded estimator (plus 'metadata' when
        present); empty dict when the directory does not exist.
    """
    base_path = Path(model_dir)

    if not base_path.exists():
        logger.warning(f"Model directory not found: {base_path}")
        return {}

    horizon = HORIZONS.get(timeframe, 3)
    loaded: Dict[str, Any] = {}

    # One model per prediction target.
    for target in ('high', 'low'):
        key = f"{symbol}_{timeframe}_{target}_h{horizon}"
        candidate = base_path / f"{key}.joblib"
        if candidate.exists():
            loaded[key] = joblib.load(candidate)
            logger.info(f"Loaded model: {key}")
        else:
            logger.warning(f"Model not found: {candidate}")

    # Optional shared metadata artifact.
    metadata_file = base_path / 'metadata.joblib'
    if metadata_file.exists():
        loaded['metadata'] = joblib.load(metadata_file)

    return loaded
|
|
|
|
|
|
def predict_with_models(
    df: pd.DataFrame,
    models: Dict[str, Any],
    symbol: str,
    timeframe: str
) -> Dict[str, np.ndarray]:
    """
    Run the loaded high/low models plus the attention-weight computation.

    Args:
        df: OHLCV DataFrame (time-indexed).
        models: Output of load_reduced_features_models().
        symbol: Trading symbol.
        timeframe: Timeframe key into HORIZONS.

    Returns:
        Dict containing whichever of 'delta_high', 'delta_low' and
        'attention_weights' could be computed (empty when no feature
        columns are available).
    """
    results: Dict[str, np.ndarray] = {}

    # Build the reduced feature matrix, keeping only columns that exist.
    feature_frame = generate_reduced_features(df)
    usable = [c for c in get_feature_columns_without_ohlcv() if c in feature_frame.columns]

    if not usable:
        logger.warning("No feature columns available for prediction")
        return results

    X = feature_frame[usable].values
    horizon = HORIZONS.get(timeframe, 3)

    # One prediction vector per available target model.
    for target in ('high', 'low'):
        key = f"{symbol}_{timeframe}_{target}_h{horizon}"
        if key in models:
            results[f'delta_{target}'] = models[key].predict(X)
            logger.info(f"Generated {len(results[f'delta_{target}'])} {target} predictions")

    # Attention weights are best-effort; a failure must not break charting.
    try:
        config = VolatilityAttentionConfig(factor_window=100, w_max=3.0, beta=4.0)
        results['attention_weights'] = compute_attention_weights(df, config)
        logger.info("Computed attention weights")
    except Exception as e:
        logger.warning(f"Could not compute attention weights: {e}")

    return results
|
|
|
|
|
|
# ==============================================================================
|
|
# Visualization Functions
|
|
# ==============================================================================
|
|
|
|
def create_visualization(
    df: pd.DataFrame,
    predictions: Dict[str, np.ndarray],
    symbol: str,
    timeframe: str,
    output_path: Path,
    start_date: str,
    end_date: str,
    output_format: str = 'both'
):
    """
    Dispatch chart rendering to the available backends.

    A backend runs only when it was requested via `output_format` AND its
    library imported successfully at module load time.

    Args:
        df: OHLCV DataFrame.
        predictions: Model predictions (possibly partial).
        symbol: Trading symbol.
        timeframe: Timeframe.
        output_path: Output directory.
        start_date: Start date (YYYY-MM-DD).
        end_date: End date (YYYY-MM-DD).
        output_format: 'matplotlib', 'plotly', or 'both'.
    """
    date_str = f"{start_date.replace('-', '')}_{end_date.replace('-', '')}"

    render_args = (df, predictions, symbol, timeframe, output_path, date_str)

    want_matplotlib = output_format in ('matplotlib', 'both')
    want_plotly = output_format in ('plotly', 'both')

    if want_matplotlib and HAS_MATPLOTLIB:
        create_matplotlib_visualization(*render_args)

    if want_plotly and HAS_PLOTLY:
        create_plotly_visualization(*render_args)
|
|
|
|
|
|
def create_matplotlib_visualization(
    df: pd.DataFrame,
    predictions: Dict,
    symbol: str,
    timeframe: str,
    output_path: Path,
    date_str: str
):
    """Create a static matplotlib chart for one symbol/timeframe.

    Renders three stacked panels sharing the time axis:
      1. Hand-drawn candlesticks with the predicted high/low band.
      2. Volatility attention weights, color-coded by regime.
      3. Prediction errors vs. realized forward high/low moves.

    Saves "<symbol>_<timeframe>_predictions_<date_str>.png" into
    `output_path`.

    Args:
        df: OHLCV DataFrame (time-indexed).
        predictions: Output of predict_with_models(); keys may be missing.
        symbol: Trading symbol (used in titles/filenames).
        timeframe: Timeframe key into HORIZONS.
        output_path: Directory the PNG is written into.
        date_str: Compact period string used in titles and filenames.
    """
    # Derive the horizon once; the previous version hardcoded `3` in the
    # error-panel sizing while fetching HORIZONS separately below it.
    horizon = HORIZONS.get(timeframe, 3)

    fig, axes = plt.subplots(3, 1, figsize=(16, 12), sharex=True,
                             gridspec_kw={'height_ratios': [3, 1, 1]})

    fig.suptitle(f'{symbol} - {timeframe} Reduced Features Model Predictions\n'
                 f'Period: {date_str.replace("_", " to ")}', fontsize=14)

    # ---- Subplot 1: Price with Predictions ----
    ax1 = axes[0]

    # Hand-drawn candlesticks: body rectangle plus a high-low wick.
    # NOTE(review): body width is a fixed 0.0004 days (~35 s) regardless of
    # timeframe, so candles render very thin on 5m/15m charts — confirm
    # whether a per-timeframe width is wanted.
    for time, row in df.iterrows():
        color = 'green' if row['close'] >= row['open'] else 'red'
        ax1.add_patch(Rectangle(
            (mdates.date2num(time) - 0.0002, min(row['open'], row['close'])),
            0.0004, abs(row['close'] - row['open']) or 0.1,  # 0.1 = min visible body
            facecolor=color, edgecolor=color, alpha=0.8
        ))
        ax1.plot([mdates.date2num(time), mdates.date2num(time)],
                 [row['low'], row['high']], color=color, linewidth=0.5)

    # Predicted range band around the close price.
    if 'delta_high' in predictions and 'delta_low' in predictions:
        close_prices = df['close'].values
        n_preds = min(len(predictions['delta_high']), len(df))

        upper_band = close_prices[:n_preds] + predictions['delta_high'][:n_preds]
        lower_band = close_prices[:n_preds] - predictions['delta_low'][:n_preds]

        ax1.fill_between(df.index[:n_preds], lower_band, upper_band,
                         alpha=0.2, color='blue', label='Predicted Range')
        ax1.plot(df.index[:n_preds], upper_band, 'b--', linewidth=0.8, alpha=0.7)
        ax1.plot(df.index[:n_preds], lower_band, 'b--', linewidth=0.8, alpha=0.7)

    # Actual high/low lines for visual comparison with the band.
    ax1.plot(df.index, df['high'], 'g-', linewidth=0.5, alpha=0.5, label='Actual High')
    ax1.plot(df.index, df['low'], 'r-', linewidth=0.5, alpha=0.5, label='Actual Low')

    ax1.set_ylabel('Price')
    ax1.legend(loc='upper left')
    ax1.grid(True, alpha=0.3)

    # ---- Subplot 2: Attention Weights ----
    ax2 = axes[1]

    if 'attention_weights' in predictions:
        n = min(len(predictions['attention_weights']), len(df))
        weights = predictions['attention_weights'][:n]

        # Regime color-coding: >1.5 high, 1-1.5 moderate, <=1 low.
        colors = ['green' if w > 1.5 else 'orange' if w > 1 else 'gray' for w in weights]
        ax2.bar(df.index[:n], weights, width=0.0005, color=colors, alpha=0.7)
        ax2.axhline(y=1.5, color='green', linestyle='--', linewidth=1, label='High Attention')
        ax2.axhline(y=1.0, color='black', linestyle='-', linewidth=0.5)

        legend_elements = [
            Patch(facecolor='green', alpha=0.7, label='High Attention (>1.5)'),
            Patch(facecolor='orange', alpha=0.7, label='Moderate (1-1.5)'),
            Patch(facecolor='gray', alpha=0.7, label='Low (<1)')
        ]
        ax2.legend(handles=legend_elements, loc='upper right', fontsize=8)
    else:
        ax2.text(0.5, 0.5, 'Attention weights not available',
                 transform=ax2.transAxes, ha='center', va='center')

    ax2.set_ylabel('Attention Weight')
    ax2.set_ylim(0, 4)
    ax2.grid(True, alpha=0.3)

    # ---- Subplot 3: Prediction Errors ----
    ax3 = axes[2]

    if 'delta_high' in predictions and 'delta_low' in predictions:
        close_prices = df['close'].values
        # Errors can only be computed where `horizon` future bars exist.
        # Previously the tail was zero-padded, which plotted the raw
        # prediction as an "error" against a fake realized move of 0;
        # the plot is now trimmed to the valid region instead.
        n_valid = min(len(predictions['delta_high']), len(df)) - horizon

        if n_valid > 0:
            actual_high = np.empty(n_valid)
            actual_low = np.empty(n_valid)

            for i in range(n_valid):
                future_slice = slice(i + 1, i + 1 + horizon)
                actual_high[i] = df['high'].iloc[future_slice].max() - close_prices[i]
                actual_low[i] = close_prices[i] - df['low'].iloc[future_slice].min()

            # Positive error = model over-estimated the realized move.
            error_high = predictions['delta_high'][:n_valid] - actual_high
            error_low = predictions['delta_low'][:n_valid] - actual_low

            ax3.fill_between(df.index[:n_valid], error_high, alpha=0.5, color='blue', label='High Error')
            ax3.fill_between(df.index[:n_valid], -error_low, alpha=0.5, color='red', label='Low Error')
            ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    ax3.set_ylabel('Prediction Error')
    ax3.set_xlabel('Time')
    ax3.legend(loc='upper right')
    ax3.grid(True, alpha=0.3)

    # Format x-axis timestamps.
    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
    plt.xticks(rotation=45)

    plt.tight_layout()

    # Save and release the figure.
    output_file = output_path / f"{symbol}_{timeframe}_predictions_{date_str}.png"
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    logger.info(f"Saved matplotlib chart to {output_file}")

    plt.close(fig)
|
|
|
|
|
|
def create_plotly_visualization(
    df: pd.DataFrame,
    predictions: Dict,
    symbol: str,
    timeframe: str,
    output_path: Path,
    date_str: str
):
    """Create plotly interactive visualization.

    Builds a 3-row figure (candlestick + predicted bands, attention
    weights, high/low asymmetry), writes an interactive HTML file and —
    when a static-image backend such as kaleido is installed — a PNG.

    Args:
        df: OHLCV DataFrame (time-indexed).
        predictions: Output of predict_with_models(); keys may be missing.
        symbol: Trading symbol (used in titles/filenames).
        timeframe: Timeframe label.
        output_path: Directory the HTML/PNG files are written into.
        date_str: Compact period string used in filenames.
    """
    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        row_heights=[0.5, 0.25, 0.25],
        subplot_titles=(
            f'{symbol} - {timeframe} Price & Predictions',
            'Volatility Attention Weights',
            'Prediction Analysis'
        )
    )

    # ---- Row 1: Candlestick with Predictions ----
    fig.add_trace(
        go.Candlestick(
            x=df.index,
            open=df['open'],
            high=df['high'],
            low=df['low'],
            close=df['close'],
            name='OHLC',
            increasing_line_color='green',
            decreasing_line_color='red'
        ),
        row=1, col=1
    )

    if 'delta_high' in predictions and 'delta_low' in predictions:
        close_prices = df['close'].values
        # Clamp to the shorter of the prediction vector vs. price series.
        n_preds = min(len(predictions['delta_high']), len(df))

        upper_band = close_prices[:n_preds] + predictions['delta_high'][:n_preds]
        lower_band = close_prices[:n_preds] - predictions['delta_low'][:n_preds]

        # NOTE: trace order matters here — the lower-band trace below uses
        # fill='tonexty', which fills against the PREVIOUSLY added trace
        # (the upper band). Do not reorder these two add_trace calls.
        fig.add_trace(
            go.Scatter(
                x=df.index[:n_preds], y=upper_band,
                mode='lines', name='Predicted High',
                line=dict(color='blue', dash='dash', width=1),
                opacity=0.7
            ),
            row=1, col=1
        )

        fig.add_trace(
            go.Scatter(
                x=df.index[:n_preds], y=lower_band,
                mode='lines', name='Predicted Low',
                line=dict(color='blue', dash='dash', width=1),
                fill='tonexty', fillcolor='rgba(0,0,255,0.1)',
                opacity=0.7
            ),
            row=1, col=1
        )

    # ---- Row 2: Attention Weights ----
    if 'attention_weights' in predictions:
        n = min(len(predictions['attention_weights']), len(df))
        weights = predictions['attention_weights'][:n]

        # Regime color-coding: >1.5 high, 1-1.5 moderate, <=1 low.
        colors = ['green' if w > 1.5 else 'orange' if w > 1 else 'gray' for w in weights]

        fig.add_trace(
            go.Bar(
                x=df.index[:n],
                y=weights,
                marker_color=colors,
                name='Attention Weight',
                opacity=0.7
            ),
            row=2, col=1
        )

        # Reference lines: high-attention threshold and neutral weight.
        fig.add_hline(y=1.5, line_dash="dash", line_color="green", row=2, col=1)
        fig.add_hline(y=1.0, line_color="black", line_width=0.5, row=2, col=1)

    # ---- Row 3: Prediction Statistics ----
    if 'delta_high' in predictions and 'delta_low' in predictions:
        n_preds = min(len(predictions['delta_high']), len(df))

        # Asymmetry ratio: predicted upside range vs. downside range
        # (the 1e-10 epsilon guards against division by zero).
        asymmetry = predictions['delta_high'][:n_preds] / (predictions['delta_low'][:n_preds] + 1e-10)
        colors = ['green' if a > 1.2 else 'red' if a < 0.8 else 'gray' for a in asymmetry]

        fig.add_trace(
            go.Bar(
                x=df.index[:n_preds],
                y=asymmetry,
                marker_color=colors,
                name='High/Low Asymmetry',
                opacity=0.7
            ),
            row=3, col=1
        )

        # Reference lines: upside-biased (>1.2), downside-biased (<0.8), neutral.
        fig.add_hline(y=1.2, line_dash="dash", line_color="green", row=3, col=1)
        fig.add_hline(y=0.8, line_dash="dash", line_color="red", row=3, col=1)
        fig.add_hline(y=1.0, line_color="black", line_width=0.5, row=3, col=1)

    # Update layout
    fig.update_layout(
        title=f'{symbol} - {timeframe} Reduced Features Model Analysis',
        height=1000,
        showlegend=True,
        xaxis_rangeslider_visible=False,
        template='plotly_white'
    )

    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Attention", range=[0, 4], row=2, col=1)
    fig.update_yaxes(title_text="Asymmetry", row=3, col=1)
    fig.update_xaxes(title_text="Time", row=3, col=1)

    # Save HTML (always available)
    output_file = output_path / f"{symbol}_{timeframe}_predictions_{date_str}.html"
    fig.write_html(str(output_file))
    logger.info(f"Saved plotly chart to {output_file}")

    # Try to save PNG — requires an optional static-image backend, so a
    # failure here is logged and otherwise ignored.
    try:
        png_file = output_path / f"{symbol}_{timeframe}_predictions_{date_str}_plotly.png"
        fig.write_image(str(png_file), width=1600, height=1000)
        logger.info(f"Saved PNG chart to {png_file}")
    except Exception as e:
        logger.warning(f"Could not save PNG: {e}")
|
|
|
|
|
|
# ==============================================================================
|
|
# Main Function
|
|
# ==============================================================================
|
|
|
|
def run_visualization(
    symbol: str,
    timeframe: str,
    start_date: str,
    end_date: str,
    model_dir: str = 'models/reduced_features_models',
    output_dir: str = 'charts',
    output_format: str = 'both',
    db_config_path: str = 'config/database.yaml'
) -> Dict:
    """
    End-to-end visualization for one symbol/timeframe.

    Loads data and models, generates predictions (random demo values when
    no models are found), renders the charts, and writes a JSON summary
    alongside them.

    Args:
        symbol: Trading symbol.
        timeframe: Timeframe.
        start_date: Start date (YYYY-MM-DD).
        end_date: End date (YYYY-MM-DD).
        model_dir: Directory containing trained model artifacts.
        output_dir: Root directory for generated charts.
        output_format: 'matplotlib', 'plotly', or 'both'.
        db_config_path: Database config path.

    Returns:
        Summary dict, or {'error': 'No data'} when nothing could be loaded.
    """
    logger.info("=" * 60)
    logger.info("REDUCED FEATURES MODEL VISUALIZATION V2")
    logger.info(f"Symbol: {symbol}")
    logger.info(f"Timeframe: {timeframe}")
    logger.info(f"Period: {start_date} to {end_date}")
    logger.info("=" * 60)

    # Charts are grouped per symbol/timeframe under the output root.
    chart_dir = Path(output_dir) / symbol / timeframe
    chart_dir.mkdir(parents=True, exist_ok=True)

    df = load_data_for_visualization(
        symbol, timeframe, start_date, end_date, db_config_path
    )

    if df.empty:
        logger.error("No data available for visualization")
        return {'error': 'No data'}

    logger.info(f"Loaded {len(df)} records")

    models = load_reduced_features_models(symbol, timeframe, model_dir)

    if models:
        predictions = predict_with_models(df, models, symbol, timeframe)
    else:
        logger.warning("No models loaded - using sample predictions")
        # Random demo values so the chart pipeline can still be exercised.
        predictions = {
            'delta_high': np.random.uniform(0, 10, len(df)),
            'delta_low': np.random.uniform(0, 10, len(df))
        }

    create_visualization(
        df, predictions, symbol, timeframe,
        chart_dir, start_date, end_date, output_format
    )

    summary = {
        'symbol': symbol,
        'timeframe': timeframe,
        'period': {'start': start_date, 'end': end_date},
        'data_points': len(df),
        'models_loaded': list(models.keys()) if models else [],
        'predictions_generated': list(predictions.keys()),
        'output_path': str(chart_dir)
    }

    # Persist the run summary next to the charts (default=str handles
    # non-JSON-native values such as timestamps).
    with open(chart_dir / f"summary_{start_date}_{end_date}.json", 'w') as f:
        json.dump(summary, f, indent=2, default=str)

    logger.info(f"\nVisualization complete! Charts saved to {chart_dir}")

    return summary
|
|
|
|
|
|
# ==============================================================================
|
|
# CLI Entry Point
|
|
# ==============================================================================
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the requested visualizations."""
    parser = argparse.ArgumentParser(
        description='Visualize reduced-features model predictions',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Visualize XAUUSD 5m for January 2025
  python scripts/run_visualization_v2.py --symbol XAUUSD --timeframe 5m --start 2025-01-01 --end 2025-01-31

  # Visualize XAUUSD 15m for January 2025
  python scripts/run_visualization_v2.py --symbol XAUUSD --timeframe 15m --start 2025-01-01 --end 2025-01-31

  # All symbols and timeframes
  python scripts/run_visualization_v2.py --all-symbols --all-timeframes

  # Only matplotlib output
  python scripts/run_visualization_v2.py --symbol XAUUSD --format matplotlib
"""
    )

    parser.add_argument('--symbol', default='XAUUSD',
                        help='Trading symbol (default: XAUUSD)')
    parser.add_argument('--timeframe', default='15m',
                        help='Timeframe: 5m or 15m (default: 15m)')
    parser.add_argument('--start', default='2025-01-01',
                        help='Start date YYYY-MM-DD (default: 2025-01-01)')
    parser.add_argument('--end', default='2025-01-31',
                        help='End date YYYY-MM-DD (default: 2025-01-31)')
    parser.add_argument('--format', default='both',
                        choices=['matplotlib', 'plotly', 'both'],
                        help='Output format (default: both)')
    parser.add_argument('--model-dir', default='models/reduced_features_models',
                        help='Model directory')
    parser.add_argument('--output-dir', default='charts',
                        help='Output directory for charts')
    parser.add_argument('--all-symbols', action='store_true',
                        help='Run for all symbols')
    parser.add_argument('--all-timeframes', action='store_true',
                        help='Run for all timeframes')
    parser.add_argument('--db-config', default='config/database.yaml',
                        help='Database config file')

    args = parser.parse_args()

    # Expand the requested symbol/timeframe combinations.
    symbols = SUPPORTED_SYMBOLS if args.all_symbols else [args.symbol]
    timeframes = SUPPORTED_TIMEFRAMES if args.all_timeframes else [args.timeframe]

    # Run from the repository root so relative config/model paths resolve.
    os.chdir(Path(__file__).parent.parent)

    results = []
    for symbol in symbols:
        for timeframe in timeframes:
            logger.info(f"\nProcessing {symbol} {timeframe}...")
            try:
                results.append(run_visualization(
                    symbol=symbol,
                    timeframe=timeframe,
                    start_date=args.start,
                    end_date=args.end,
                    model_dir=args.model_dir,
                    output_dir=args.output_dir,
                    output_format=args.format,
                    db_config_path=args.db_config
                ))
            except Exception as e:
                # Keep going: one failed combination must not stop the rest.
                logger.error(f"Failed: {e}")
                import traceback
                traceback.print_exc()

    # Final summary of everything that succeeded.
    print("\n" + "=" * 60)
    print("VISUALIZATION SUMMARY")
    print("=" * 60)
    for item in results:
        if 'error' not in item:
            print(f" {item['symbol']} {item['timeframe']}: {item['data_points']} points -> {item['output_path']}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|