trading-platform-ml-engine-v2/scripts/train_symbol_timeframe_models.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

626 lines
21 KiB
Python

#!/usr/bin/env python3
"""
Symbol-Timeframe Training Script
================================
Trains separate ML models for each symbol and timeframe combination.
This script uses the SymbolTimeframeTrainer to train models for:
- XAUUSD (Gold)
- EURUSD (Euro/USD)
- BTCUSD (Bitcoin) - if data available
Each symbol is trained for both 5m and 15m timeframes.
Features:
- Loads data from MySQL database
- Excludes last year (2025) for backtesting
- Uses dynamic factor-based sample weighting
- Generates comprehensive feature set
- Saves models and training reports
Usage:
python scripts/train_symbol_timeframe_models.py
python scripts/train_symbol_timeframe_models.py --symbols XAUUSD EURUSD --timeframes 5m 15m
Author: ML Training Pipeline
Version: 1.0.0
Created: 2026-01-05
"""
import argparse
import sys
from pathlib import Path
from datetime import datetime, timedelta
import json
import numpy as np
import pandas as pd
from loguru import logger
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from training.symbol_timeframe_trainer import (
SymbolTimeframeTrainer,
TrainerConfig,
SYMBOL_CONFIGS
)
from data.database import MySQLConnection
def setup_logging(log_dir: Path, experiment_name: str):
    """Set up loguru sinks: INFO to the console, DEBUG to a timestamped file.

    Args:
        log_dir: Directory for log files (created if it does not exist).
        experiment_name: Prefix used in the log file name.

    Returns:
        Path to the log file that was configured.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f"{experiment_name}_{timestamp}.log"

    # Drop loguru's default sink and install our console + file pair.
    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", rotation="10 MB")
    logger.info(f"Logging to {log_file}")
    return log_file
def load_data_from_db(
    db: MySQLConnection,
    symbol: str,
    start_date: str = None,
    end_date: str = None,
    limit: int = None
) -> pd.DataFrame:
    """
    Load OHLCV data from the MySQL database.

    Args:
        db: MySQL connection wrapper (must expose execute_query(query, params)).
        symbol: Trading symbol (e.g., 'XAUUSD').
        start_date: Inclusive start date filter (YYYY-MM-DD), optional.
        end_date: Inclusive end date filter (YYYY-MM-DD), optional.
        limit: Maximum records to fetch, optional.

    Returns:
        DataFrame indexed by timestamp with columns
        open/high/low/close/volume/vwap, sorted ascending; empty DataFrame
        (no index set) when the symbol has no rows.
    """
    # Normalize to the DB ticker convention: crypto tickers carry an 'X:'
    # prefix, forex/metals carry 'C:'. Symbols already prefixed pass through.
    db_symbol = symbol
    if not symbol.startswith(('C:', 'X:')):
        db_symbol = f'X:{symbol}' if symbol == 'BTCUSD' else f'C:{symbol}'
    logger.info(f"Loading data for {db_symbol}...")

    query = """
    SELECT
        date_agg as time,
        open,
        high,
        low,
        close,
        volume,
        vwap
    FROM tickers_agg_data
    WHERE ticker = :symbol
    """
    params = {'symbol': db_symbol}
    if start_date:
        query += " AND date_agg >= :start_date"
        params['start_date'] = start_date
    if end_date:
        query += " AND date_agg <= :end_date"
        params['end_date'] = end_date
    query += " ORDER BY date_agg ASC"
    if limit:
        # LIMIT cannot be bound as a parameter by many MySQL drivers, so it is
        # interpolated — coerce to int so the clause can never be malformed
        # (raises TypeError/ValueError for non-numeric input instead of
        # emitting arbitrary SQL text).
        query += f" LIMIT {int(limit)}"

    df = db.execute_query(query, params)
    if df.empty:
        logger.warning(f"No data found for {symbol}")
        return df

    # Set a sorted datetime index.
    df['time'] = pd.to_datetime(df['time'])
    df.set_index('time', inplace=True)
    df = df.sort_index()

    # Positional rename to the expected column names — relies on the SELECT
    # column order above staying in sync with this list.
    df.columns = ['open', 'high', 'low', 'close', 'volume', 'vwap']
    logger.info(f"Loaded {len(df)} records for {symbol}")
    logger.info(f"  Date range: {df.index.min()} to {df.index.max()}")
    return df
def resample_to_timeframe(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Aggregate 5-minute OHLCV bars into a coarser timeframe.

    Args:
        df: DataFrame of 5m bars indexed by timestamp, with columns
            open/high/low/close/volume/vwap.
        timeframe: Target timeframe ('5m', '15m', '1H', etc.).

    Returns:
        Resampled DataFrame with incomplete bars dropped; the input is
        returned unchanged when timeframe is '5m'.
    """
    # Native resolution — nothing to do.
    if timeframe == '5m':
        return df

    # Translate our timeframe labels into pandas resample offsets;
    # unrecognized labels are passed straight through to pandas.
    offsets = {
        '15m': '15min',
        '30m': '30min',
        '1H': '1H',
        '4H': '4H',
        '1D': '1D',
    }
    rule = offsets.get(timeframe, timeframe)

    # Standard OHLCV aggregation; vwap is approximated by a simple mean.
    agg_spec = {
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'vwap': 'mean',
    }
    resampled = df.resample(rule).agg(agg_spec).dropna()
    logger.info(f"Resampled to {timeframe}: {len(resampled)} bars")
    return resampled
def generate_features(df: pd.DataFrame, symbol: str) -> pd.DataFrame:
    """
    Generate comprehensive feature set for training.

    Builds ~50 technical features from OHLCV data: multi-horizon returns,
    rolling volatility, range/ATR measures, ATR-normalized moving-average
    distances and momentum, oscillators (RSI, stochastic, Williams %R),
    MACD, Bollinger bands, candle anatomy, rolling price position,
    cyclical time encodings and trading-session flags.

    Args:
        df: OHLCV DataFrame indexed by timestamp. Requires open/high/low/close;
            volume is optional (a constant series of 1s is substituted).
        symbol: Symbol for context-specific features (currently only logged).

    Returns:
        DataFrame of features aligned to df.index. Early rows contain NaNs
        from rolling windows (callers dropna afterwards); +/-inf values are
        converted to NaN.
        NOTE(review): volume_ratio/volume_trend are emitted only when
        volume.sum() > 0, so the column set can differ between instruments —
        confirm downstream consumers tolerate that.
    """
    logger.info(f"Generating features for {symbol}...")
    features = pd.DataFrame(index=df.index)
    close = df['close']
    high = df['high']
    low = df['low']
    open_price = df['open']
    # Constant fallback keeps volume-free instruments working.
    volume = df['volume'] if 'volume' in df.columns else pd.Series(1, index=df.index)
    # ===== Price Returns =====
    features['returns_1'] = close.pct_change(1)
    features['returns_3'] = close.pct_change(3)
    features['returns_5'] = close.pct_change(5)
    features['returns_10'] = close.pct_change(10)
    features['returns_20'] = close.pct_change(20)
    # ===== Volatility Features =====
    # Rolling std of one-bar returns over several windows.
    features['volatility_5'] = close.pct_change().rolling(5).std()
    features['volatility_10'] = close.pct_change().rolling(10).std()
    features['volatility_20'] = close.pct_change().rolling(20).std()
    # ===== Range Features =====
    candle_range = high - low
    features['range'] = candle_range
    features['range_pct'] = candle_range / close
    features['range_ma_5'] = candle_range.rolling(5).mean()
    features['range_ma_10'] = candle_range.rolling(10).mean()
    features['range_ma_20'] = candle_range.rolling(20).mean()
    # Current range relative to its recent average (expansion/contraction).
    features['range_ratio_5'] = candle_range / features['range_ma_5']
    features['range_ratio_20'] = candle_range / features['range_ma_20']
    # ===== ATR Features =====
    # True range = max of (H-L, |H-prev C|, |L-prev C|), Wilder-style.
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    features['atr_5'] = true_range.rolling(5).mean()
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_20'] = true_range.rolling(20).mean()
    features['atr_ratio'] = true_range / features['atr_14']
    # ===== Moving Averages =====
    sma_5 = close.rolling(5).mean()
    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()
    ema_5 = close.ewm(span=5, adjust=False).mean()
    ema_10 = close.ewm(span=10, adjust=False).mean()
    ema_20 = close.ewm(span=20, adjust=False).mean()
    # Distances are normalized by ATR(14) so scales are comparable
    # across symbols/timeframes.
    features['price_vs_sma5'] = (close - sma_5) / features['atr_14']
    features['price_vs_sma10'] = (close - sma_10) / features['atr_14']
    features['price_vs_sma20'] = (close - sma_20) / features['atr_14']
    features['price_vs_sma50'] = (close - sma_50) / features['atr_14']
    features['sma5_vs_sma20'] = (sma_5 - sma_20) / features['atr_14']
    features['ema5_vs_ema20'] = (ema_5 - ema_20) / features['atr_14']
    # ===== RSI =====
    # Simple-moving-average RSI(14); epsilon avoids division by zero.
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rs))
    # RSI extremes as binary flags.
    features['rsi_oversold'] = (features['rsi_14'] < 30).astype(float)
    features['rsi_overbought'] = (features['rsi_14'] > 70).astype(float)
    # ===== Bollinger Bands =====
    bb_middle = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_middle + 2 * bb_std
    bb_lower = bb_middle - 2 * bb_std
    features['bb_width'] = (bb_upper - bb_lower) / bb_middle
    # 0 = at lower band, 1 = at upper band.
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 1e-10)
    # ===== MACD =====
    # Classic 12/26/9 MACD, normalized by ATR(14).
    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    macd = ema_12 - ema_26
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    features['macd'] = macd / features['atr_14']
    features['macd_signal'] = macd_signal / features['atr_14']
    features['macd_hist'] = (macd - macd_signal) / features['atr_14']
    # ===== Momentum =====
    features['momentum_5'] = (close - close.shift(5)) / features['atr_14']
    features['momentum_10'] = (close - close.shift(10)) / features['atr_14']
    features['momentum_20'] = (close - close.shift(20)) / features['atr_14']
    # ===== Stochastic =====
    low_14 = low.rolling(14).min()
    high_14 = high.rolling(14).max()
    features['stoch_k'] = 100 * (close - low_14) / (high_14 - low_14 + 1e-10)
    features['stoch_d'] = features['stoch_k'].rolling(3).mean()
    # ===== Williams %R =====
    features['williams_r'] = -100 * (high_14 - close) / (high_14 - low_14 + 1e-10)
    # ===== Volume Features =====
    # Only when real volume exists (skips the constant-1 fallback path
    # when the column is absent; a present-but-zero column is also skipped).
    if volume.sum() > 0:
        vol_ma_5 = volume.rolling(5).mean()
        vol_ma_20 = volume.rolling(20).mean()
        features['volume_ratio'] = volume / (vol_ma_20 + 1)
        features['volume_trend'] = (vol_ma_5 - vol_ma_20) / (vol_ma_20 + 1)
    # ===== Candle Patterns =====
    # Body and shadows as fractions of the candle range.
    body = close - open_price
    features['body_pct'] = body / (candle_range + 1e-10)
    features['upper_shadow'] = (high - np.maximum(close, open_price)) / (candle_range + 1e-10)
    features['lower_shadow'] = (np.minimum(close, open_price) - low) / (candle_range + 1e-10)
    # ===== Price Position =====
    features['close_position'] = (close - low) / (candle_range + 1e-10)
    high_5 = high.rolling(5).max()
    low_5 = low.rolling(5).min()
    features['price_position_5'] = (close - low_5) / (high_5 - low_5 + 1e-10)
    high_20 = high.rolling(20).max()
    low_20 = low.rolling(20).min()
    features['price_position_20'] = (close - low_20) / (high_20 - low_20 + 1e-10)
    # ===== Time Features =====
    # Cyclical sin/cos encodings avoid the 23->0 hour discontinuity.
    features['hour'] = df.index.hour
    features['hour_sin'] = np.sin(2 * np.pi * features['hour'] / 24)
    features['hour_cos'] = np.cos(2 * np.pi * features['hour'] / 24)
    features['day_of_week'] = df.index.dayofweek
    features['dow_sin'] = np.sin(2 * np.pi * features['day_of_week'] / 7)
    features['dow_cos'] = np.cos(2 * np.pi * features['day_of_week'] / 7)
    # Trading sessions — NOTE(review): hour boundaries look like UTC
    # session times; confirm the DB timestamps are UTC.
    features['is_london'] = ((features['hour'] >= 8) & (features['hour'] < 16)).astype(float)
    features['is_newyork'] = ((features['hour'] >= 13) & (features['hour'] < 21)).astype(float)
    features['is_overlap'] = ((features['hour'] >= 13) & (features['hour'] < 16)).astype(float)
    # Clean up: infinities (from divisions) become NaN for later dropna.
    features = features.replace([np.inf, -np.inf], np.nan)
    # Drop non-feature columns used for intermediate calculations.
    drop_cols = ['hour', 'day_of_week']
    features = features.drop(columns=[c for c in drop_cols if c in features.columns], errors='ignore')
    logger.info(f"Generated {len(features.columns)} features")
    return features
def train_models(
    symbols: list,
    timeframes: list,
    output_dir: Path,
    cutoff_date: str = '2024-12-31',
    db_config_path: str = 'config/database.yaml',
    use_attention: bool = False,
    attention_model_path: str = 'models/attention'
) -> dict:
    """
    Train models for all symbol/timeframe combinations.

    For each symbol: loads 5m bars from the DB up to cutoff_date, resamples
    per timeframe, generates features, trains via SymbolTimeframeTrainer,
    then saves all models and a CSV summary under output_dir.

    Args:
        symbols: List of symbols to train
        timeframes: List of timeframes
        output_dir: Directory to save models
        cutoff_date: Training data cutoff date
        db_config_path: Path to database config
        use_attention: Whether to include attention features from pre-trained model
        attention_model_path: Path to trained attention models

    Returns:
        Dictionary with training results: 'results' (per-model metrics),
        'summary' (summary table as dict, may be empty) and 'model_dir'.
    """
    logger.info("="*60)
    logger.info("Symbol-Timeframe Model Training")
    logger.info("="*60)
    logger.info(f"Symbols: {symbols}")
    logger.info(f"Timeframes: {timeframes}")
    logger.info(f"Cutoff date: {cutoff_date}")
    logger.info(f"Use attention features: {use_attention}")
    # Connect to database
    db = MySQLConnection(db_config_path)
    # Configure trainer with improved parameters for better R^2
    # Key improvements:
    # 1. Targets are now normalized by ATR (handled in SymbolTimeframeTrainer)
    # 2. Reduced sample weighting aggressiveness
    # 3. More regularization in XGBoost
    config = TrainerConfig(
        symbols=symbols,
        timeframes=timeframes,
        # Prediction horizon in bars per timeframe.
        horizons={
            '5m': 3,   # 15 minutes ahead
            '15m': 3,  # 45 minutes ahead
        },
        train_years=5.0,
        holdout_years=1.0,  # Exclude 2025 for backtesting
        use_dynamic_factor_weighting=True,
        factor_window=200,
        softplus_beta=2.0,  # Reduced from 4.0 - less aggressive weighting
        softplus_w_max=2.0,  # Reduced from 3.0 - cap weights lower
        # Deliberately conservative XGBoost settings (heavy regularization).
        xgb_params={
            'n_estimators': 150,      # Reduced from 300
            'max_depth': 4,           # Reduced from 6 - shallower trees
            'learning_rate': 0.02,    # Reduced from 0.03
            'subsample': 0.7,         # Reduced from 0.8
            'colsample_bytree': 0.7,  # Reduced from 0.8
            'min_child_weight': 20,   # Increased from 10 - more regularization
            'gamma': 0.3,             # Increased from 0.1
            'reg_alpha': 0.5,         # Increased from 0.1 - L1 regularization
            'reg_lambda': 5.0,        # Increased from 1.0 - L2 regularization
            'tree_method': 'hist',
            'random_state': 42
        },
        min_train_samples=5000,
        use_attention_features=use_attention,
        attention_model_path=attention_model_path
    )
    trainer = SymbolTimeframeTrainer(config)
    # Prepare data dictionary
    data_dict = {}
    all_results = {}
    for symbol in symbols:
        logger.info(f"\n{'='*60}")
        logger.info(f"Processing {symbol}")
        logger.info(f"{'='*60}")
        # Load raw data (5m)
        df_5m = load_data_from_db(db, symbol, end_date=cutoff_date)
        if df_5m.empty:
            logger.warning(f"No data for {symbol}, skipping...")
            continue
        # Verify we have enough 5m history before spending time on features.
        if len(df_5m) < 50000:
            logger.warning(f"Insufficient data for {symbol}: {len(df_5m)} rows")
            continue
        data_dict[symbol] = {}
        for timeframe in timeframes:
            logger.info(f"\n--- {symbol} {timeframe} ---")
            # Resample if needed (copy so each timeframe starts from raw 5m).
            if timeframe == '5m':
                df_tf = df_5m.copy()
            else:
                df_tf = resample_to_timeframe(df_5m.copy(), timeframe)
            if len(df_tf) < 10000:
                logger.warning(f"Insufficient {timeframe} data: {len(df_tf)} rows")
                continue
            # Generate features
            features = generate_features(df_tf, symbol)
            # Combine OHLCV with features
            df_combined = pd.concat([df_tf[['open', 'high', 'low', 'close', 'volume']], features], axis=1)
            # Drop NaN rows (rolling-window warm-up and infs from features).
            df_combined = df_combined.dropna()
            logger.info(f"Final data shape: {df_combined.shape}")
            data_dict[symbol][timeframe] = df_combined
            # Train for this symbol/timeframe; a failure here is logged and
            # does not abort the remaining combinations.
            try:
                results = trainer.train_single(df_combined, symbol, timeframe)
                all_results.update(results)
                for key, result in results.items():
                    logger.info(f"\n{key}:")
                    logger.info(f"  MAE: {result.mae:.6f}")
                    logger.info(f"  RMSE: {result.rmse:.6f}")
                    logger.info(f"  R2: {result.r2:.4f}")
                    logger.info(f"  Dir Accuracy: {result.directional_accuracy:.2%}")
                    logger.info(f"  Train samples: {result.n_train}")
                    logger.info(f"  Val samples: {result.n_val}")
            except Exception as e:
                logger.error(f"Training failed for {symbol} {timeframe}: {e}")
                import traceback
                traceback.print_exc()
    # Save models
    model_dir = output_dir / 'symbol_timeframe_models'
    trainer.save(str(model_dir))
    logger.info(f"\nModels saved to {model_dir}")
    # Generate summary report
    summary_df = trainer.get_training_summary()
    if not summary_df.empty:
        report_path = output_dir / 'training_summary.csv'
        summary_df.to_csv(report_path, index=False)
        logger.info(f"Summary saved to {report_path}")
        logger.info("\n" + "="*60)
        logger.info("TRAINING SUMMARY")
        logger.info("="*60)
        print(summary_df.to_string(index=False))
    return {
        'results': all_results,
        'summary': summary_df.to_dict() if not summary_df.empty else {},
        'model_dir': str(model_dir)
    }
def generate_markdown_report(results: dict, output_dir: Path) -> Path:
    """Generate a Markdown training report.

    Args:
        results: Output of train_models(); expects the keys 'results'
            (mapping of model key -> result object with mae/rmse/r2/... and
            a model_key) and 'model_dir'.
        output_dir: Directory where the report file is written.

    Returns:
        Path to the generated report file.
    """
    report_path = output_dir / f"TRAINING_REPORT_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    # FIX: the configuration section previously reported the legacy weighting
    # parameters (beta=4.0, w_max=3.0), but the trainer actually runs with
    # softplus_beta=2.0 / softplus_w_max=2.0 (see TrainerConfig in
    # train_models). The report now matches the real configuration.
    # NOTE(review): the cutoff date is still hard-coded here even though it is
    # configurable via --cutoff-date; consider passing it through `results`.
    report = f"""# Symbol-Timeframe Model Training Report
**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
## Configuration
- **Training Data Cutoff:** 2024-12-31 (excluding 2025 for backtesting)
- **Dynamic Factor Weighting:** Enabled
- **Sample Weight Method:** Softplus with beta=2.0, w_max=2.0
## Training Results Summary
| Model | Symbol | Timeframe | Target | MAE | RMSE | R2 | Dir Accuracy | Train | Val |
|-------|--------|-----------|--------|-----|------|----|--------------| ----- | --- |
"""
    # One table row per trained model.
    for key, result in results.get('results', {}).items():
        report += f"| {key} | {result.model_key.symbol} | {result.model_key.timeframe} | "
        report += f"{result.model_key.target_type} | {result.mae:.6f} | {result.rmse:.6f} | "
        report += f"{result.r2:.4f} | {result.directional_accuracy:.2%} | "
        report += f"{result.n_train} | {result.n_val} |\n"
    model_dir_str = results.get('model_dir', 'N/A')
    # Doubled braces below render as literal {placeholders} in the Markdown.
    report += f"""
## Model Files
Models saved to: `{model_dir_str}`
### Model Naming Convention
- `{{symbol}}_{{timeframe}}_high_h{{horizon}}.joblib` - High range predictor
- `{{symbol}}_{{timeframe}}_low_h{{horizon}}.joblib` - Low range predictor
## Usage Example
```python
from training.symbol_timeframe_trainer import SymbolTimeframeTrainer
# Load trained models
trainer = SymbolTimeframeTrainer()
trainer.load('models/symbol_timeframe_models/')
# Predict for XAUUSD 15m
predictions = trainer.predict(features, 'XAUUSD', '15m')
print(f"Predicted High: {{predictions['high']}}")
print(f"Predicted Low: {{predictions['low']}}")
```
## Notes
1. Models exclude 2025 data for out-of-sample backtesting
2. Dynamic factor weighting emphasizes high-movement samples
3. Separate models for HIGH and LOW predictions per symbol/timeframe
---
*Report generated by Symbol-Timeframe Training Pipeline*
"""
    with open(report_path, 'w') as f:
        f.write(report)
    logger.info(f"Report saved to {report_path}")
    return report_path
def main():
    """CLI entry point: parse arguments, configure logging, run training,
    and emit a Markdown report; exits non-zero on failure."""
    parser = argparse.ArgumentParser(description='Train Symbol-Timeframe ML Models')
    parser.add_argument('--symbols', nargs='+', default=['XAUUSD', 'EURUSD'],
                        help='Symbols to train (default: XAUUSD EURUSD)')
    parser.add_argument('--timeframes', nargs='+', default=['5m', '15m'],
                        help='Timeframes to train (default: 5m 15m)')
    parser.add_argument('--output-dir', type=str, default='models/',
                        help='Output directory for models')
    parser.add_argument('--cutoff-date', type=str, default='2024-12-31',
                        help='Training data cutoff date')
    parser.add_argument('--db-config', type=str, default='config/database.yaml',
                        help='Database configuration file')
    parser.add_argument('--use-attention', action='store_true',
                        help='Use attention features from pre-trained attention model')
    parser.add_argument('--attention-model-path', type=str, default='models/attention',
                        help='Path to trained attention models')
    args = parser.parse_args()

    # All relative paths are resolved against the repository root
    # (the parent of scripts/), not the current working directory.
    repo_root = Path(__file__).parent.parent
    out_dir = repo_root / args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(out_dir / 'logs', 'symbol_timeframe_training')

    try:
        training_results = train_models(
            symbols=args.symbols,
            timeframes=args.timeframes,
            output_dir=out_dir,
            cutoff_date=args.cutoff_date,
            db_config_path=str(repo_root / args.db_config),
            use_attention=args.use_attention,
            attention_model_path=str(repo_root / args.attention_model_path)
        )
        generate_markdown_report(training_results, out_dir)
        banner = "=" * 60
        logger.info("\n" + banner)
        logger.info("TRAINING COMPLETE!")
        logger.info(banner)
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()