Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
478 lines · 17 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
Training Script for Enhanced Range Predictor
|
|
=============================================
|
|
Complete training pipeline for the volatility-factor based model.
|
|
|
|
Features:
|
|
- Loads OHLCV data from Parquet/CSV
|
|
- Generates features using the existing feature pipeline
|
|
- Trains dual-horizon ensemble with sample weighting
|
|
- Validates with walk-forward approach
|
|
- Saves model and generates report
|
|
|
|
Usage:
|
|
python train_enhanced_model.py --symbol XAUUSD --timeframe 15m --data-path data/
|
|
python train_enhanced_model.py --config config/training_config.yaml
|
|
|
|
Author: Trading Strategist + ML Specialist
|
|
Version: 1.0.0
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
import json
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from loguru import logger
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
|
|
|
from models.enhanced_range_predictor import (
|
|
EnhancedRangePredictor,
|
|
EnhancedRangePredictorConfig
|
|
)
|
|
from data.corrected_targets import CorrectedTargetConfig
|
|
from training.sample_weighting import SampleWeightConfig
|
|
from training.session_volatility_weighting import SessionWeightConfig
|
|
from models.dual_horizon_ensemble import DualHorizonConfig
|
|
|
|
|
|
def setup_logging(log_dir: Path, experiment_name: str):
    """Set up loguru sinks: INFO on stderr plus a DEBUG file under *log_dir*.

    The file name embeds *experiment_name* and a timestamp so repeated runs
    never clobber each other; the file sink rotates at 10 MB.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f"{experiment_name}_{stamp}.log"

    # Drop loguru's default handler before installing our own sinks.
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    logger.add(log_file, level="DEBUG", rotation="10 MB")

    logger.info(f"Logging to {log_file}")
|
|
|
|
|
|
def load_data(data_path: Path, symbol: str, timeframe: str) -> pd.DataFrame:
    """Load OHLCV bars for *symbol*/*timeframe* from the first matching file.

    Candidate file names are probed in a fixed order (Parquet preferred over
    CSV, then lower-cased symbol variants, then an ``_ohlcv`` suffix). The
    returned frame has a DatetimeIndex and lower-cased column names.

    Raises:
        FileNotFoundError: if none of the candidate files exists.
    """
    candidates = (
        data_path / f"{symbol}_{timeframe}.parquet",
        data_path / f"{symbol}_{timeframe}.csv",
        data_path / f"{symbol.lower()}_{timeframe}.parquet",
        data_path / f"{symbol.lower()}_{timeframe}.csv",
        data_path / f"{symbol}_{timeframe}_ohlcv.parquet",
    )

    for candidate in candidates:
        if not candidate.exists():
            continue

        logger.info(f"Loading data from {candidate}")
        reader = pd.read_parquet if candidate.suffix == '.parquet' else pd.read_csv
        df = reader(candidate)

        # Promote whichever time column is present ('timestamp' wins over
        # 'date') to a DatetimeIndex.
        for time_col in ('timestamp', 'date'):
            if time_col in df.columns:
                df[time_col] = pd.to_datetime(df[time_col])
                df = df.set_index(time_col)
                break
        else:
            if not isinstance(df.index, pd.DatetimeIndex):
                # NOTE(review): on a plain RangeIndex this coerces integers to
                # nanosecond timestamps — assumes the file carries either a
                # time column or a datetime-like index. Confirm with data files.
                df.index = pd.to_datetime(df.index)

        # Normalize column names so downstream code can rely on lower case.
        df.columns = df.columns.str.lower()

        logger.info(f"Loaded {len(df)} samples from {df.index.min()} to {df.index.max()}")
        return df

    raise FileNotFoundError(f"No data file found for {symbol}_{timeframe} in {data_path}")
|
|
|
|
|
|
def generate_features(df: pd.DataFrame) -> pd.DataFrame:
    """Build the model's feature matrix from raw OHLCV bars.

    Returns a DataFrame aligned to *df*'s index containing return,
    volatility, range, trend, oscillator and candle-shape features.
    Intermediate series (raw SMAs, Bollinger band levels, rolling extremes,
    volume MAs) are dropped before returning. Inf values are replaced with
    NaN; NaN rows are left for the caller to drop.
    """
    logger.info("Generating features...")

    features = pd.DataFrame(index=df.index)

    close = df['close']
    high = df['high']
    low = df['low']
    has_volume = 'volume' in df.columns
    volume = df['volume'] if has_volume else pd.Series(1, index=df.index)

    # --- returns over several lags ---
    for lag in (1, 5, 15):
        features[f'returns_{lag}'] = close.pct_change(lag)

    # --- realized volatility of one-bar returns ---
    bar_returns = close.pct_change()
    for window in (5, 20):
        features[f'volatility_{window}'] = bar_returns.rolling(window).std()

    # --- bar range and its recent averages ---
    features['range'] = high - low
    features['range_pct'] = (high - low) / close
    features['range_ma_5'] = features['range'].rolling(5).mean()
    features['range_ma_20'] = features['range'].rolling(20).mean()
    features['range_ratio'] = features['range'] / features['range_ma_20']

    # --- ATR(14) from the three-way true range ---
    prev_close = close.shift(1)
    true_range = pd.concat(
        [high - low, abs(high - prev_close), abs(low - prev_close)],
        axis=1,
    ).max(axis=1)
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_ratio'] = true_range / features['atr_14']

    # --- moving averages and ATR-normalized distances ---
    for window in (5, 20, 50):
        features[f'sma_{window}'] = close.rolling(window).mean()
    features['price_vs_sma5'] = (close - features['sma_5']) / features['atr_14']
    features['price_vs_sma20'] = (close - features['sma_20']) / features['atr_14']
    features['sma5_vs_sma20'] = (features['sma_5'] - features['sma_20']) / features['atr_14']

    # --- RSI(14); the epsilon avoids division by zero in flat markets ---
    delta = close.diff()
    avg_gain = delta.where(delta > 0, 0).rolling(14).mean()
    avg_loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rel_strength = avg_gain / (avg_loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rel_strength))

    # --- Bollinger bands (20, 2) ---
    bb_middle = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    features['bb_upper'] = bb_middle + 2 * bb_std
    features['bb_lower'] = bb_middle - 2 * bb_std
    features['bb_width'] = (features['bb_upper'] - features['bb_lower']) / bb_middle
    features['bb_position'] = (close - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'])

    # --- MACD (12, 26, 9) ---
    ema_fast = close.ewm(span=12, adjust=False).mean()
    ema_slow = close.ewm(span=26, adjust=False).mean()
    features['macd'] = ema_fast - ema_slow
    features['macd_signal'] = features['macd'].ewm(span=9, adjust=False).mean()
    features['macd_hist'] = features['macd'] - features['macd_signal']

    # --- raw price momentum ---
    for lag in (5, 10, 20):
        features[f'momentum_{lag}'] = close - close.shift(lag)

    # --- volume features (only when the data carries volume) ---
    if has_volume:
        features['volume_ma_5'] = volume.rolling(5).mean()
        features['volume_ma_20'] = volume.rolling(20).mean()
        features['volume_ratio'] = volume / (features['volume_ma_20'] + 1)

    # --- position of close inside the 5-bar high/low channel ---
    features['high_5'] = high.rolling(5).max()
    features['low_5'] = low.rolling(5).min()
    features['close_vs_high5'] = (close - features['low_5']) / (features['high_5'] - features['low_5'] + 1e-10)

    # --- candle anatomy ---
    features['body'] = close - df['open']
    features['body_pct'] = features['body'] / (high - low + 1e-10)
    features['upper_shadow'] = high - np.maximum(close, df['open'])
    features['lower_shadow'] = np.minimum(close, df['open']) - low

    # --- crude trend-strength proxy ---
    features['adx_proxy'] = abs(features['price_vs_sma20']) * features['range_ratio']

    # Divisions above can produce +/-inf; normalize those to NaN.
    features = features.replace([np.inf, -np.inf], np.nan)

    # Intermediate series are not model inputs — drop whichever were created.
    intermediates = ['sma_5', 'sma_20', 'sma_50', 'bb_upper', 'bb_lower',
                     'high_5', 'low_5', 'volume_ma_5', 'volume_ma_20']
    features = features.drop(columns=[c for c in intermediates if c in features.columns], errors='ignore')

    logger.info(f"Generated {len(features.columns)} features")

    return features
|
|
|
|
|
|
def walk_forward_validation(
    df_ohlcv: pd.DataFrame,
    df_features: pd.DataFrame,
    config: EnhancedRangePredictorConfig,
    n_splits: int = 5,
    test_size_months: int = 2
) -> dict:
    """
    Perform walk-forward validation.

    The data is split into *n_splits* chronological folds anchored at the end
    of the series. For each fold the model trains on everything up to the
    fold's test window (minus a one-day gap to limit leakage) and is
    evaluated on the following *test_size_months* months (30-day months).

    Args:
        df_ohlcv: OHLCV bars indexed by timestamp.
        df_features: feature matrix aligned to df_ohlcv.
        config: predictor configuration reused for every fold.
        n_splits: number of chronological folds.
        test_size_months: test-window length per fold.

    Returns dict with validation metrics (per-split stats plus aggregates),
    or {'error': ...} when no split had enough data.
    """
    logger.info(f"Starting walk-forward validation with {n_splits} splits...")

    results = []
    timestamps = df_ohlcv.index

    # Each fold's test window is test_size_months * 30 days, walked
    # backwards from the end of the data for earlier folds.
    test_days = test_size_months * 30

    for i in range(n_splits):
        logger.info(f"\n=== Split {i+1}/{n_splits} ===")

        # Calculate dates for this split
        test_end = timestamps.max() - timedelta(days=test_days * (n_splits - i - 1))
        test_start = test_end - timedelta(days=test_days)
        # One-day gap between train and test to reduce look-ahead leakage.
        train_end = test_start - timedelta(days=1)

        # Filter data
        train_mask = timestamps <= train_end
        test_mask = (timestamps > test_start) & (timestamps <= test_end)

        df_train = df_ohlcv[train_mask]
        df_test = df_ohlcv[test_mask]
        feat_train = df_features[train_mask]
        feat_test = df_features[test_mask]

        if len(df_train) < 1000 or len(df_test) < 100:
            logger.warning(f"Insufficient data for split {i+1}, skipping")
            continue

        logger.info(f"Train: {len(df_train)} samples, Test: {len(df_test)} samples")

        # Train predictor
        predictor = EnhancedRangePredictor(config)
        predictor.fit(df_train, feat_train)

        # Evaluate on test set; drop NaN feature rows once and reuse the
        # cleaned frame for both values and index (the original called
        # dropna() twice).
        feat_test_clean = feat_test.dropna()
        test_predictions = predictor.predict_batch(
            feat_test_clean.values,
            feat_test_clean.index
        )

        # Calculate metrics — cast numpy scalars to native Python types so
        # the dict is JSON-serializable (json.dumps raises TypeError on
        # np.int64 / np.float64).
        # (In real implementation, compare predictions to actual outcomes)
        split_results = {
            'split': i + 1,
            'train_samples': len(df_train),
            'test_samples': len(df_test),
            'predictions': len(test_predictions),
            'long_signals': int((test_predictions['direction'] == 'LONG').sum()),
            'short_signals': int((test_predictions['direction'] == 'SHORT').sum()),
            'mean_confidence': float(test_predictions['confidence'].mean()),
            'mean_rr': float(test_predictions['rr_best'].mean())
        }

        results.append(split_results)
        logger.info(f"Split {i+1} results: {json.dumps(split_results, indent=2)}")

    # Aggregate results
    if results:
        summary = {
            'n_splits': len(results),
            'avg_confidence': float(np.mean([r['mean_confidence'] for r in results])),
            'avg_rr': float(np.mean([r['mean_rr'] for r in results])),
            'total_long': sum(r['long_signals'] for r in results),
            'total_short': sum(r['short_signals'] for r in results),
            'splits': results
        }
    else:
        summary = {'error': 'No valid splits'}

    return summary
|
|
|
|
|
|
def generate_report(
|
|
predictor: EnhancedRangePredictor,
|
|
validation_results: dict,
|
|
output_dir: Path
|
|
) -> Path:
|
|
"""Generate training report."""
|
|
report_path = output_dir / f"training_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
|
|
|
|
summary = predictor.get_model_summary()
|
|
|
|
report = f"""# Enhanced Range Predictor Training Report
|
|
|
|
**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
|
|
|
## Configuration
|
|
|
|
- **Symbol:** {summary['config']['symbol']}
|
|
- **Base Factor:** {summary['config']['base_factor']} USD
|
|
- **Input Timeframe:** {summary['config']['input_timeframe']}
|
|
- **Prediction Horizon:** {summary['config']['prediction_horizon_bars']} bars
|
|
|
|
## Training Statistics
|
|
|
|
- **Total Samples:** {summary['training_stats'].get('total_samples', 'N/A')}
|
|
- **Valid Samples:** {summary['training_stats'].get('valid_samples', 'N/A')}
|
|
- **LONG Opportunities:** {summary['training_stats'].get('long_opportunities', 'N/A')}
|
|
- **SHORT Opportunities:** {summary['training_stats'].get('short_opportunities', 'N/A')}
|
|
- **Feature Count:** {summary['feature_count']}
|
|
|
|
## Volatility Metrics
|
|
|
|
- **Normal Variation (median):** {summary['volatility_metrics'].get('normal_variation', 'N/A'):.2f} USD
|
|
- **Strong Movement (P85):** {summary['volatility_metrics'].get('strong_movement', 'N/A'):.2f} USD
|
|
- **Noise Floor (P25):** {summary['volatility_metrics'].get('noise_floor', 'N/A'):.2f} USD
|
|
- **ATR(14):** {summary['volatility_metrics'].get('atr_14', 'N/A'):.2f} USD
|
|
|
|
## Dual Horizon Ensemble
|
|
|
|
- **Long-term Years:** {summary['ensemble_summary'].get('long_term_years', 'N/A')}
|
|
- **Short-term Months:** {summary['ensemble_summary'].get('short_term_months', 'N/A')}
|
|
- **Long-term Weight:** {summary['ensemble_summary'].get('weight_long', 'N/A'):.2f}
|
|
- **Short-term Weight:** {summary['ensemble_summary'].get('weight_short', 'N/A'):.2f}
|
|
|
|
## Walk-Forward Validation
|
|
|
|
- **Number of Splits:** {validation_results.get('n_splits', 'N/A')}
|
|
- **Average Confidence:** {validation_results.get('avg_confidence', 'N/A'):.3f}
|
|
- **Average R:R Ratio:** {validation_results.get('avg_rr', 'N/A'):.2f}
|
|
- **Total LONG Signals:** {validation_results.get('total_long', 'N/A')}
|
|
- **Total SHORT Signals:** {validation_results.get('total_short', 'N/A')}
|
|
|
|
## Feature Importance (Top 20)
|
|
|
|
"""
|
|
|
|
# Add feature importance table
|
|
try:
|
|
importance = predictor.get_feature_importance()
|
|
report += "| Feature | Importance |\n|---------|------------|\n"
|
|
for feat, row in importance.head(20).iterrows():
|
|
report += f"| {feat} | {row.iloc[0]:.4f} |\n"
|
|
except Exception as e:
|
|
report += f"*Error getting feature importance: {e}*\n"
|
|
|
|
report += """
|
|
## Next Steps
|
|
|
|
1. Monitor model performance in paper trading
|
|
2. Retrain short-term model weekly
|
|
3. Adjust weights based on performance
|
|
4. Consider adding more features if needed
|
|
|
|
---
|
|
*Report generated by Enhanced Range Predictor training pipeline*
|
|
"""
|
|
|
|
with open(report_path, 'w') as f:
|
|
f.write(report)
|
|
|
|
logger.info(f"Report saved to {report_path}")
|
|
return report_path
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the training CLI arguments."""
    parser = argparse.ArgumentParser(description='Train Enhanced Range Predictor')
    parser.add_argument('--symbol', type=str, default='XAUUSD', help='Trading symbol')
    parser.add_argument('--timeframe', type=str, default='15m', help='Input timeframe')
    parser.add_argument('--data-path', type=str, default='data/', help='Path to data directory')
    parser.add_argument('--output-path', type=str, default='models/', help='Path to save model')
    parser.add_argument('--base-factor', type=float, default=5.0, help='Base volatility factor in USD')
    parser.add_argument('--horizon-bars', type=int, default=3, help='Prediction horizon in bars')
    parser.add_argument('--min-rr', type=float, default=2.0, help='Minimum R:R ratio')
    parser.add_argument('--validate', action='store_true', help='Run walk-forward validation')
    parser.add_argument('--n-splits', type=int, default=5, help='Number of validation splits')
    return parser.parse_args()


def _build_config(args: argparse.Namespace) -> EnhancedRangePredictorConfig:
    """Translate CLI arguments into the predictor configuration."""
    return EnhancedRangePredictorConfig(
        symbol=args.symbol,
        base_factor=args.base_factor,
        input_timeframe=args.timeframe,
        prediction_horizon_bars=args.horizon_bars,
        target_config=CorrectedTargetConfig(
            horizon_bars=args.horizon_bars,
            min_movement_usd=args.base_factor,
            min_rr_ratio=args.min_rr,
            base_factor=args.base_factor
        ),
        sample_weight_config=SampleWeightConfig(
            min_movement_threshold=args.base_factor,
            min_rr_ratio=args.min_rr
        ),
        dual_horizon_config=DualHorizonConfig(
            long_term_years=5.0,
            short_term_months=3.0
        )
    )


def main():
    """CLI entry point: load data, optionally validate, train, save, report."""
    args = _parse_args()

    # Resolve paths up front; the output directory also hosts the logs.
    data_path = Path(args.data_path)
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    setup_logging(output_path / 'logs', f"train_{args.symbol}_{args.timeframe}")

    banner = "=" * 60
    logger.info(banner)
    logger.info("Enhanced Range Predictor Training")
    logger.info(banner)
    logger.info(f"Symbol: {args.symbol}")
    logger.info(f"Timeframe: {args.timeframe}")
    logger.info(f"Base Factor: {args.base_factor} USD")
    logger.info(f"Horizon: {args.horizon_bars} bars")

    try:
        df_ohlcv = load_data(data_path, args.symbol, args.timeframe)
        df_features = generate_features(df_ohlcv)

        # Keep only rows where every feature is defined, and realign OHLCV.
        valid_idx = df_features.dropna().index
        df_ohlcv = df_ohlcv.loc[valid_idx]
        df_features = df_features.loc[valid_idx]
        logger.info(f"Data after cleaning: {len(df_ohlcv)} samples")

        config = _build_config(args)

        # Optional walk-forward validation before the final fit.
        validation_results = {}
        if args.validate:
            validation_results = walk_forward_validation(
                df_ohlcv, df_features, config,
                n_splits=args.n_splits
            )

        logger.info("\n" + banner)
        logger.info("Training final model on all data...")
        logger.info(banner)

        predictor = EnhancedRangePredictor(config)
        predictor.fit(df_ohlcv, df_features)

        model_path = output_path / f"{args.symbol}_{args.timeframe}_enhanced"
        predictor.save(str(model_path))
        logger.info(f"Model saved to {model_path}")

        report_path = generate_report(predictor, validation_results, output_path)

        logger.info("\n" + banner)
        logger.info("Training Complete!")
        logger.info(banner)
        logger.info(f"Model: {model_path}")
        logger.info(f"Report: {report_path}")

        # Final console summary of what was trained.
        stats = predictor.get_model_summary()
        logger.info(f"\nModel Summary:")
        logger.info(f"  Valid samples: {stats['training_stats']['valid_samples']}")
        logger.info(f"  LONG opportunities: {stats['training_stats']['long_opportunities']}")
        logger.info(f"  SHORT opportunities: {stats['training_stats']['short_opportunities']}")
        logger.info(f"  Features: {stats['feature_count']}")

    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|