trading-platform-ml-engine-v2/scripts/train_metamodels.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00


#!/usr/bin/env python3
"""
Train Metamodels Script
=======================
CLI script to train Asset Metamodels (Nivel 2 of hierarchical architecture).
This script orchestrates:
1. Loading pre-trained Attention Models (Nivel 0)
2. Loading pre-trained Base Models (Nivel 1)
3. Generating OOS predictions
4. Training Metamodels per asset
Usage:
# Train metamodels for all assets
python train_metamodels.py
# Train for specific symbols
python train_metamodels.py --symbols XAUUSD EURUSD
# Specify model paths
python train_metamodels.py \
--attention-path models/attention \
--base-path models/base \
--output-path models/metamodels
# Custom OOS period
python train_metamodels.py \
--oos-start 2024-01-01 \
--oos-end 2024-08-31
Author: ML Pipeline
Version: 1.0.0
Created: 2026-01-07
"""
import sys
import argparse
from pathlib import Path
from datetime import datetime
import pandas as pd
import numpy as np
from loguru import logger

# Make the project's src/ directory importable (data.* and training.* modules)
script_dir = Path(__file__).parent
project_dir = script_dir.parent
sys.path.insert(0, str(project_dir / 'src'))


def setup_logging(log_dir: Path, symbol: str = 'all'):
    """Configure logging to file and console."""
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f'metamodel_training_{symbol}_{timestamp}.log'

    logger.remove()
    logger.add(sys.stderr, level="INFO",
               format="<green>{time:HH:mm:ss}</green> | <level>{level}</level> | {message}")
    logger.add(log_file, level="DEBUG",
               format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")
    return log_file


def load_ohlcv_from_mysql(symbol: str, timeframe: str, config: dict) -> pd.DataFrame:
    """
    Load OHLCV data from a MySQL database.

    Loads raw OHLCV data from the tickers_agg_data table so features can be
    generated fresh, using the same logic as base model training.

    Note: `config` is accepted for interface symmetry but is currently unused;
    MySQLConnection reads its own connection settings.
    """
    try:
        # Try to use the project's database module
        from data.database import MySQLConnection
        db = MySQLConnection()

        # Normalize symbol name for the database (same as train_attention_model.py)
        db_symbol = symbol
        if not symbol.startswith('C:') and not symbol.startswith('X:'):
            if symbol == 'BTCUSD':
                db_symbol = f'X:{symbol}'
            else:
                db_symbol = f'C:{symbol}'
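        # Assumption: the prefixes above follow Polygon-style ticker notation,
        # where 'C:' marks forex pairs (e.g. C:EURUSD) and 'X:' marks crypto
        # pairs (e.g. X:BTCUSD).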
logger.info(f"Loading OHLCV data for {db_symbol} ({timeframe})...")
# Load raw OHLCV data from tickers_agg_data
query = f"""
SELECT
date_agg as timestamp,
open,
high,
low,
close,
volume
FROM tickers_agg_data
WHERE ticker = '{db_symbol}'
ORDER BY date_agg ASC
"""
df = pd.read_sql(query, db.engine)
if len(df) == 0:
logger.warning(f"No data found for {symbol} {timeframe}")
return pd.DataFrame()
logger.info(f"Loaded {len(df)} rows for {symbol}")
# Set timestamp as index
df['timestamp'] = pd.to_datetime(df['timestamp'])
df.set_index('timestamp', inplace=True)
# Resample to requested timeframe
agg_dict = {
'open': 'first',
'high': 'max',
'low': 'min',
'close': 'last',
'volume': 'sum'
}
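        # Downsampling semantics: each resampled bucket keeps the first open,
        # the max high, the min low, the last close, and the summed volume.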
        if timeframe == '5m':
            df = df.resample('5min').agg(agg_dict).dropna()
        elif timeframe == '15m':
            df = df.resample('15min').agg(agg_dict).dropna()
        else:
            logger.warning(f"Unsupported timeframe '{timeframe}'; returning unresampled data")
        logger.info(f"Resampled to {timeframe}: {len(df)} bars")
        return df
    except Exception as e:
        logger.error(f"Failed to load data from MySQL: {e}")
        import traceback
        traceback.print_exc()
        return pd.DataFrame()


def load_ohlcv_from_parquet(data_dir: Path, symbol: str, timeframe: str) -> pd.DataFrame:
    """Load OHLCV data from Parquet files."""
    # Try different file patterns
    patterns = [
        f"{symbol}_{timeframe}.parquet",
        f"{symbol.lower()}_{timeframe}.parquet",
        f"{symbol}/{timeframe}.parquet",
        f"{symbol.lower()}/{timeframe}.parquet"
    ]
    for pattern in patterns:
        file_path = data_dir / pattern
        if file_path.exists():
            df = pd.read_parquet(file_path)

            # Ensure datetime index
            if 'timestamp' in df.columns:
                df['timestamp'] = pd.to_datetime(df['timestamp'])
                df.set_index('timestamp', inplace=True)
            elif not isinstance(df.index, pd.DatetimeIndex):
                logger.warning(f"No datetime index in {file_path}")

            # Standardize column names
            col_map = {
                'open': 'Open', 'high': 'High', 'low': 'Low',
                'close': 'Close', 'volume': 'Volume'
            }
            df.rename(columns={k: v for k, v in col_map.items() if k in df.columns}, inplace=True)

            logger.info(f"Loaded {len(df)} rows from {file_path}")
            return df

    logger.warning(f"No parquet file found for {symbol} {timeframe}")
    return pd.DataFrame()


def generate_features(df: pd.DataFrame, symbol: str = '') -> pd.DataFrame:
    """
    Generate the comprehensive feature set for training.

    This function generates the EXACT same feature set used by
    symbol_timeframe_trainer, ensuring compatibility with the base models.

    Args:
        df: OHLCV DataFrame with columns: Open, High, Low, Close, Volume (or lowercase)
        symbol: Symbol for context-specific features (unused but kept for compatibility)

    Returns:
        DataFrame with all features (OHLCV + generated features)
    """
    if len(df) == 0:
        return df
    df = df.copy()

    # Normalize column names to lowercase
    col_map = {'Open': 'open', 'High': 'high', 'Low': 'low', 'Close': 'close', 'Volume': 'volume'}
    df.rename(columns={k: v for k, v in col_map.items() if k in df.columns}, inplace=True)

    features = pd.DataFrame(index=df.index)
    close = df['close']
    high = df['high']
    low = df['low']
    open_price = df['open']
    volume = df['volume'] if 'volume' in df.columns else pd.Series(1, index=df.index)
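    # Symbols without usable volume data (common in some forex feeds) fall back
    # to a constant series of 1s so volume-derived features degrade gracefully.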

    # ===== Price Returns (5 features) =====
    features['returns_1'] = close.pct_change(1)
    features['returns_3'] = close.pct_change(3)
    features['returns_5'] = close.pct_change(5)
    features['returns_10'] = close.pct_change(10)
    features['returns_20'] = close.pct_change(20)

    # ===== Volatility Features (3 features) =====
    features['volatility_5'] = close.pct_change().rolling(5).std()
    features['volatility_10'] = close.pct_change().rolling(10).std()
    features['volatility_20'] = close.pct_change().rolling(20).std()

    # ===== Range Features (7 features) =====
    candle_range = high - low
    features['range'] = candle_range
    features['range_pct'] = candle_range / close
    features['range_ma_5'] = candle_range.rolling(5).mean()
    features['range_ma_10'] = candle_range.rolling(10).mean()
    features['range_ma_20'] = candle_range.rolling(20).mean()
    features['range_ratio_5'] = candle_range / features['range_ma_5']
    features['range_ratio_20'] = candle_range / features['range_ma_20']

    # ===== ATR Features (4 features) =====
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    features['atr_5'] = true_range.rolling(5).mean()
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_20'] = true_range.rolling(20).mean()
    features['atr_ratio'] = true_range / features['atr_14']
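    # The distance-style features below are divided by atr_14 so they are
    # expressed in "ATR units", keeping magnitudes comparable across symbols
    # and volatility regimes.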
    # ===== Moving Averages (6 features) =====
    sma_5 = close.rolling(5).mean()
    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()
    ema_5 = close.ewm(span=5, adjust=False).mean()
    ema_20 = close.ewm(span=20, adjust=False).mean()
    features['price_vs_sma5'] = (close - sma_5) / features['atr_14']
    features['price_vs_sma10'] = (close - sma_10) / features['atr_14']
    features['price_vs_sma20'] = (close - sma_20) / features['atr_14']
    features['price_vs_sma50'] = (close - sma_50) / features['atr_14']
    features['sma5_vs_sma20'] = (sma_5 - sma_20) / features['atr_14']
    features['ema5_vs_ema20'] = (ema_5 - ema_20) / features['atr_14']

    # ===== RSI (3 features) =====
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rs))
    features['rsi_oversold'] = (features['rsi_14'] < 30).astype(float)
    features['rsi_overbought'] = (features['rsi_14'] > 70).astype(float)
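    # Note: this is the simple-moving-average variant of RSI (Cutler's RSI),
    # not Wilder's smoothed version; values differ slightly, but the 30/70
    # oversold/overbought thresholds are applied the same way.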

    # ===== Bollinger Bands (2 features) =====
    bb_middle = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_middle + 2 * bb_std
    bb_lower = bb_middle - 2 * bb_std
    features['bb_width'] = (bb_upper - bb_lower) / bb_middle
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 1e-10)

    # ===== MACD (3 features) =====
    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    macd = ema_12 - ema_26
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    features['macd'] = macd / features['atr_14']
    features['macd_signal'] = macd_signal / features['atr_14']
    features['macd_hist'] = (macd - macd_signal) / features['atr_14']

    # ===== Momentum (3 features) =====
    features['momentum_5'] = (close - close.shift(5)) / features['atr_14']
    features['momentum_10'] = (close - close.shift(10)) / features['atr_14']
    features['momentum_20'] = (close - close.shift(20)) / features['atr_14']

    # ===== Stochastic (2 features) =====
    low_14 = low.rolling(14).min()
    high_14 = high.rolling(14).max()
    features['stoch_k'] = 100 * (close - low_14) / (high_14 - low_14 + 1e-10)
    features['stoch_d'] = features['stoch_k'].rolling(3).mean()

    # ===== Williams %R (1 feature) =====
    features['williams_r'] = -100 * (high_14 - close) / (high_14 - low_14 + 1e-10)
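    # williams_r mirrors stoch_k (williams_r = stoch_k - 100, ignoring the
    # epsilon), a redundancy the tree models can exploit or ignore.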

    # ===== Volume Features (2 features) =====
    if volume.sum() > 0:
        vol_ma_5 = volume.rolling(5).mean()
        vol_ma_20 = volume.rolling(20).mean()
        features['volume_ratio'] = volume / (vol_ma_20 + 1)
        features['volume_trend'] = (vol_ma_5 - vol_ma_20) / (vol_ma_20 + 1)
    else:
        features['volume_ratio'] = 1.0
        features['volume_trend'] = 0.0

    # ===== Candle Patterns (3 features) =====
    body = close - open_price
    features['body_pct'] = body / (candle_range + 1e-10)
    features['upper_shadow'] = (high - np.maximum(close, open_price)) / (candle_range + 1e-10)
    features['lower_shadow'] = (np.minimum(close, open_price) - low) / (candle_range + 1e-10)
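    # Candle anatomy normalized by total range: body_pct is signed (negative
    # for bearish candles), and |body_pct| + upper_shadow + lower_shadow sums
    # to ~1 (up to the epsilon).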

    # ===== Price Position (3 features) =====
    features['close_position'] = (close - low) / (candle_range + 1e-10)
    high_5 = high.rolling(5).max()
    low_5 = low.rolling(5).min()
    features['price_position_5'] = (close - low_5) / (high_5 - low_5 + 1e-10)
    high_20 = high.rolling(20).max()
    low_20 = low.rolling(20).min()
    features['price_position_20'] = (close - low_20) / (high_20 - low_20 + 1e-10)

    # ===== Time Features (7 features) =====
    hour = df.index.hour
    day_of_week = df.index.dayofweek
    features['hour_sin'] = np.sin(2 * np.pi * hour / 24)
    features['hour_cos'] = np.cos(2 * np.pi * hour / 24)
    features['dow_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    features['dow_cos'] = np.cos(2 * np.pi * day_of_week / 7)
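    # The sin/cos pairs encode time on a circle, so hour 23 and hour 0 map to
    # nearby points instead of opposite ends of a linear scale.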

    # Trading sessions
    features['is_london'] = ((hour >= 8) & (hour < 16)).astype(float)
    features['is_newyork'] = ((hour >= 13) & (hour < 21)).astype(float)
    features['is_overlap'] = ((hour >= 13) & (hour < 16)).astype(float)
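    # Session windows assume UTC timestamps: London 08:00-16:00, New York
    # 13:00-21:00, and their high-liquidity overlap 13:00-16:00.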

    # Clean up
    features = features.replace([np.inf, -np.inf], np.nan)

    # Combine with OHLCV (create a volume column if the source had none,
    # otherwise the selection below would raise a KeyError)
    if 'volume' not in df.columns:
        df['volume'] = volume
    result = pd.concat([df[['open', 'high', 'low', 'close', 'volume']], features], axis=1)
    logger.info(f"Generated {len(features.columns)} features (total columns: {len(result.columns)})")
    return result


def main():
    parser = argparse.ArgumentParser(
        description='Train Asset Metamodels (Level 2)',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    # Symbol configuration
    parser.add_argument('--symbols', nargs='+',
                        default=['XAUUSD', 'EURUSD'],
                        help='Symbols to train (default: XAUUSD EURUSD)')
    # Path configuration
    parser.add_argument('--attention-path', type=str,
                        default='models/attention',
                        help='Path to attention models (default: models/attention)')
    parser.add_argument('--base-path', type=str,
                        default='models/base',
                        help='Path to base models (default: models/base)')
    parser.add_argument('--output-path', type=str,
                        default='models/metamodels',
                        help='Output path for metamodels (default: models/metamodels)')
    # Data source
    parser.add_argument('--data-source', type=str,
                        choices=['mysql', 'parquet'],
                        default='mysql',
                        help='Data source type (default: mysql)')
    parser.add_argument('--data-dir', type=str,
                        default='data/',
                        help='Data directory for parquet files')
    # MySQL configuration
    parser.add_argument('--db-host', type=str, default='localhost')
    parser.add_argument('--db-user', type=str, default='root')
    parser.add_argument('--db-password', type=str, default='')
    parser.add_argument('--db-name', type=str, default='trading')
    # OOS period
    parser.add_argument('--oos-start', type=str,
                        default='2024-01-01',
                        help='OOS period start date (default: 2024-01-01)')
    parser.add_argument('--oos-end', type=str,
                        default='2024-08-31',
                        help='OOS period end date (default: 2024-08-31)')
    # Training parameters
    parser.add_argument('--min-samples', type=int, default=2000,
                        help='Minimum OOS samples required (default: 2000)')
    parser.add_argument('--val-split', type=float, default=0.15,
                        help='Validation split ratio (default: 0.15)')
    # Output options
    parser.add_argument('--log-dir', type=str,
                        default='models/logs',
                        help='Log directory')
    parser.add_argument('--generate-report', action='store_true',
                        help='Generate markdown training report')
    args = parser.parse_args()

    # Setup logging
    log_file = setup_logging(Path(args.log_dir), 'metamodels')
    logger.info(f"Log file: {log_file}")
    logger.info("="*60)
    logger.info("METAMODEL TRAINING SCRIPT")
    logger.info("="*60)
    logger.info(f"Symbols: {args.symbols}")
    logger.info(f"OOS Period: {args.oos_start} to {args.oos_end}")
    logger.info(f"Attention models: {args.attention_path}")
    logger.info(f"Base models: {args.base_path}")
    logger.info(f"Output: {args.output_path}")

    # Import trainer
    from training.metamodel_trainer import MetamodelTrainer, MetamodelTrainerConfig

    # Create config
    config = MetamodelTrainerConfig(
        symbols=args.symbols,
        timeframes=['5m', '15m'],
        attention_model_path=args.attention_path,
        base_model_path=args.base_path,
        output_path=args.output_path,
        oos_start_date=args.oos_start,
        oos_end_date=args.oos_end,
        min_oos_samples=args.min_samples,
        val_split=args.val_split
    )
    # Create trainer
    trainer = MetamodelTrainer(config)

    # Load pre-trained models
    logger.info("\n" + "="*60)
    logger.info("Loading pre-trained models...")
    logger.info("="*60)
    models_loaded = trainer.load_models()
    if not models_loaded:
        logger.warning("Some models failed to load; training may be incomplete")

    # Load data
    logger.info("\n" + "="*60)
    logger.info("Loading OHLCV data...")
    logger.info("="*60)
    data_dict = {}
    db_config = {
        'host': args.db_host,
        'user': args.db_user,
        'password': args.db_password,
        'database': args.db_name
    }
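    # Note: db_config is passed to load_ohlcv_from_mysql, which currently
    # ignores it in favor of MySQLConnection's own settings, so the --db-*
    # flags are placeholders until that wiring is added.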
    for symbol in args.symbols:
        data_dict[symbol] = {}
        for timeframe in ['5m', '15m']:
            if args.data_source == 'mysql':
                df = load_ohlcv_from_mysql(symbol, timeframe, db_config)
            else:
                df = load_ohlcv_from_parquet(Path(args.data_dir), symbol, timeframe)
            if len(df) > 0:
                # Generate features (same as base model training)
                df = generate_features(df, symbol)
                data_dict[symbol][timeframe] = df
                logger.info(f"  {symbol} {timeframe}: {len(df)} rows, {df.shape[1]} columns")
            else:
                logger.warning(f"  {symbol} {timeframe}: No data loaded")

    # Train metamodels
    logger.info("\n" + "="*60)
    logger.info("Training Metamodels...")
    logger.info("="*60)
    results = trainer.train_all(data_dict)

    # Print summary
    logger.info("\n" + "="*60)
    logger.info("TRAINING SUMMARY")
    logger.info("="*60)
    summary = trainer.get_training_summary()
    if len(summary) > 0:
        print("\n" + summary.to_string(index=False))

    # Save models
    logger.info("\n" + "="*60)
    logger.info("Saving models...")
    logger.info("="*60)
    trainer.save()

    # Generate report
    if args.generate_report:
        report_path = Path(args.output_path) / f'training_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.md'

        def fmt(value, spec):
            """Format a numeric metric; fall back to 'N/A' for missing values
            (applying a format spec like '.4f' to a string would raise)."""
            return format(value, spec) if isinstance(value, (int, float)) else 'N/A'

        report_content = f"""# Metamodel Training Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Configuration

- **Symbols:** {', '.join(args.symbols)}
- **OOS Period:** {args.oos_start} to {args.oos_end}
- **Attention Models:** {args.attention_path}
- **Base Models:** {args.base_path}
- **Output:** {args.output_path}

## Training Results

| Symbol | Status | Samples | MAE High | MAE Low | R2 High | R2 Low | Confidence Acc | Improvement |
|--------|--------|---------|----------|---------|---------|--------|----------------|-------------|
"""
        for _, row in summary.iterrows():
            if row['status'] == 'success':
                report_content += f"| {row['symbol']} | {row['status']} | {row.get('n_samples', 'N/A')} | "
                report_content += f"{fmt(row.get('mae_high'), '.4f')} | {fmt(row.get('mae_low'), '.4f')} | "
                report_content += f"{fmt(row.get('r2_high'), '.4f')} | {fmt(row.get('r2_low'), '.4f')} | "
                report_content += f"{fmt(row.get('confidence_accuracy'), '.2%')} | "
                report_content += f"{fmt(row.get('improvement_over_avg'), '.1f')}% |\n"
            else:
                report_content += f"| {row['symbol']} | {row['status']} | - | - | - | - | - | - | - |\n"
        report_content += f"""
## Architecture

```
Level 2: Metamodel (per asset)
├── Input Features (10):
│   ├── pred_high_5m, pred_low_5m
│   ├── pred_high_15m, pred_low_15m
│   ├── attention_5m, attention_15m
│   ├── attention_class_5m, attention_class_15m
│   └── ATR_ratio, volume_z
├── Models:
│   ├── XGBoost Regressor (HIGH)
│   ├── XGBoost Regressor (LOW)
│   └── XGBoost Classifier (Confidence)
└── Outputs:
    ├── delta_high_final
    ├── delta_low_final
    └── confidence (binary + probability)
```

## Log File

`{log_file}`
"""
        with open(report_path, 'w') as f:
            f.write(report_content)
        logger.info(f"Report saved to: {report_path}")
logger.info("\n" + "="*60)
logger.info("TRAINING COMPLETE")
logger.info("="*60)
# Return exit code based on results
success_count = sum(1 for r in results.values() if r.get('status') == 'success')
total_count = len(results)
if success_count == total_count:
logger.info(f"All {total_count} metamodels trained successfully")
return 0
elif success_count > 0:
logger.warning(f"{success_count}/{total_count} metamodels trained successfully")
return 0
else:
logger.error("No metamodels were trained successfully")
return 1
if __name__ == "__main__":
sys.exit(main())