trading-platform-ml-engine-v2/scripts/train_attention_model.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

617 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Attention Model Training Script
================================
Trains attention score models for identifying high-flow market moments.
This script:
1. Loads OHLCV data from MySQL database
2. Trains attention models for all symbols and timeframes
3. Generates both regression (0-3+) and classification (low/medium/high) outputs
4. Saves models to models/attention/
5. Generates comprehensive training report
Features learned:
- volume_ratio, volume_z (volume activity)
- ATR, ATR_ratio (volatility)
- CMF, MFI, OBV_delta (money flow)
- BB_width, displacement (price structure)
Usage:
python scripts/train_attention_model.py
python scripts/train_attention_model.py --symbols XAUUSD EURUSD
python scripts/train_attention_model.py --cutoff-date 2024-03-01
Author: ML Pipeline
Version: 1.0.0
Created: 2026-01-06
"""
import argparse
import sys
from pathlib import Path
from datetime import datetime, timedelta
import json
import os
# Setup path BEFORE any other imports
# Project root is the parent of scripts/ (this file lives in scripts/).
_SCRIPT_DIR = Path(__file__).parent.parent.absolute()
# chdir so relative paths like 'config/database.yaml' resolve from the project root.
os.chdir(_SCRIPT_DIR)
# Put src/ first on sys.path so project packages win over any same-named installs.
sys.path.insert(0, str(_SCRIPT_DIR / 'src'))
import numpy as np
import pandas as pd
from loguru import logger
import importlib.util
# Load modules directly to avoid circular imports in models/__init__.py
def _load_module_direct(module_name: str, file_path: Path):
"""Load a module directly from file without going through __init__.py"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
# Load attention modules with CONSISTENT names (important for joblib pickle)
# joblib pickles record classes as "module.ClassName"; registering these
# modules under the same dotted names used at training time lets saved
# models be reloaded without importing through models/__init__.py.
_src_dir = _SCRIPT_DIR / 'src'
# First load the attention_score_model with a stable name
_attention_model_module = _load_module_direct(
    "models.attention_score_model",
    _src_dir / 'models' / 'attention_score_model.py'
)
# Now load the trainer (presumably depends on the model module registered
# above — loaded second for that reason; confirm in attention_trainer.py)
_attention_trainer_module = _load_module_direct(
    "training.attention_trainer",
    _src_dir / 'training' / 'attention_trainer.py'
)
# Re-export the names this script uses directly.
AttentionModelTrainer = _attention_trainer_module.AttentionModelTrainer
AttentionTrainerConfig = _attention_trainer_module.AttentionTrainerConfig
generate_attention_training_report = _attention_trainer_module.generate_attention_training_report
AttentionModelConfig = _attention_model_module.AttentionModelConfig
# Load database module normally (it doesn't have circular imports)
sys.path.insert(0, str(_src_dir))
from data.database import MySQLConnection
def setup_logging(log_dir: Path, experiment_name: str) -> Path:
    """Route loguru output to the console (INFO+) and a timestamped file (DEBUG+).

    Args:
        log_dir: Directory for the log file; created if missing.
        experiment_name: Prefix used in the log file name.

    Returns:
        Path of the log file that was configured.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f"{experiment_name}_{timestamp}.log"
    # Drop loguru's default sink, then install our two sinks.
    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", rotation="10 MB")
    logger.info(f"Logging to {log_file}")
    return log_file
def load_data_from_db(
    db: MySQLConnection,
    symbol: str,
    start_date: str = None,
    end_date: str = None,
    limit: int = None
) -> pd.DataFrame:
    """
    Load OHLCV data from MySQL database.

    Args:
        db: MySQL connection
        symbol: Trading symbol (e.g., 'XAUUSD')
        start_date: Start date filter (YYYY-MM-DD), inclusive
        end_date: End date filter (YYYY-MM-DD), inclusive
        limit: Maximum records to fetch

    Returns:
        DataFrame indexed by datetime with columns
        open/high/low/close/volume/vwap; empty DataFrame if no rows matched.

    Raises:
        ValueError: If ``limit`` is truthy but not coercible to an integer.
    """
    # Normalize symbol name for database: 'X:' prefix for crypto, 'C:' for
    # forex; symbols that already carry a prefix pass through unchanged.
    db_symbol = symbol
    if not symbol.startswith('C:') and not symbol.startswith('X:'):
        if symbol == 'BTCUSD':
            db_symbol = f'X:{symbol}'
        else:
            db_symbol = f'C:{symbol}'
    logger.info(f"Loading data for {db_symbol}...")
    query = """
    SELECT
        date_agg as time,
        open,
        high,
        low,
        close,
        volume,
        vwap
    FROM tickers_agg_data
    WHERE ticker = :symbol
    """
    params = {'symbol': db_symbol}
    if start_date:
        query += " AND date_agg >= :start_date"
        params['start_date'] = start_date
    if end_date:
        query += " AND date_agg <= :end_date"
        params['end_date'] = end_date
    query += " ORDER BY date_agg ASC"
    if limit:
        # LIMIT cannot be bound as a query parameter by all drivers, so it is
        # interpolated — coerce to int first so the value cannot inject SQL.
        query += f" LIMIT {int(limit)}"
    df = db.execute_query(query, params)
    if df.empty:
        logger.warning(f"No data found for {symbol}")
        return df
    # Set datetime index
    df['time'] = pd.to_datetime(df['time'])
    df.set_index('time', inplace=True)
    df = df.sort_index()
    # Positional rename: relies on the SELECT column order above (after the
    # 'time' column was moved into the index, 6 columns remain).
    df.columns = ['open', 'high', 'low', 'close', 'volume', 'vwap']
    logger.info(f"Loaded {len(df)} records for {symbol}")
    logger.info(f" Date range: {df.index.min()} to {df.index.max()}")
    return df
def resample_to_timeframe(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Aggregate 5-minute OHLCV bars up to a coarser timeframe.

    Args:
        df: DataFrame of 5m bars with a datetime index and columns
            open/high/low/close/volume/vwap.
        timeframe: Target timeframe label ('5m', '15m', '30m', '1H', '4H', '1D').

    Returns:
        The input object unchanged for '5m'; otherwise a resampled DataFrame
        with incomplete (all-NaN) bars dropped.
    """
    if timeframe == '5m':
        # Source data is already at 5-minute resolution.
        return df
    # Translate our labels into pandas offset aliases; unrecognized labels
    # are handed to pandas verbatim.
    pandas_offsets = {
        '15m': '15min',
        '30m': '30min',
        '1H': '1H',
        '4H': '4H',
        '1D': '1D',
    }
    rule = pandas_offsets.get(timeframe, timeframe)
    aggregations = {
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'vwap': 'mean',
    }
    resampled = df.resample(rule).agg(aggregations).dropna()
    logger.info(f"Resampled to {timeframe}: {len(resampled)} bars")
    return resampled
def train_attention_models(
    symbols: list,
    timeframes: list,
    output_dir: Path,
    cutoff_date: str = '2024-12-31',
    train_years: float = 5.0,
    holdout_years: float = 1.0,
    db_config_path: str = 'config/database.yaml'
) -> dict:
    """
    Train attention models for all symbol/timeframe combinations.

    Loads 5m OHLCV data per symbol, resamples to each requested timeframe,
    trains the attention models, saves them, and writes a CSV summary.

    Args:
        symbols: List of symbols to train
        timeframes: List of timeframes
        output_dir: Directory to save models
        cutoff_date: Training data cutoff date (YYYY-MM-DD)
        train_years: Years of training data
        holdout_years: Years reserved for holdout validation
        db_config_path: Path to database config

    Returns:
        Dict with keys 'results' (per-model results), 'summary' (dict form of
        the summary table), 'model_dir' (str path), and 'trainer' (the fitted
        AttentionModelTrainer).
    """
    logger.info("="*60)
    logger.info("Attention Model Training")
    logger.info("="*60)
    logger.info(f"Symbols: {symbols}")
    logger.info(f"Timeframes: {timeframes}")
    logger.info(f"Cutoff date: {cutoff_date}")
    logger.info(f"Training years: {train_years}")
    logger.info(f"Holdout years: {holdout_years}")
    # Connect to database.
    # NOTE(review): the connection is never explicitly closed here; if
    # MySQLConnection exposes close() or context-manager support, use it.
    db = MySQLConnection(db_config_path)
    # Configure attention model. Shallow, regularized XGBoost settings
    # (subsampling, min_child_weight, gamma, L1/L2) to limit overfitting.
    model_config = AttentionModelConfig(
        factor_window=200,
        horizon_bars=3,
        low_flow_threshold=1.0,
        high_flow_threshold=2.0,
        reg_params={
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 10,
            'gamma': 0.1,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42
        },
        clf_params={
            'n_estimators': 150,
            'max_depth': 4,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 15,
            'gamma': 0.2,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42,
            'objective': 'multi:softmax',
            'num_class': 3
        },
        min_train_samples=5000
    )
    # Configure trainer
    trainer_config = AttentionTrainerConfig(
        symbols=symbols,
        timeframes=timeframes,
        train_years=train_years,
        holdout_years=holdout_years,
        model_config=model_config,
        output_dir=str(output_dir / 'attention')
    )
    trainer = AttentionModelTrainer(trainer_config)
    # Prepare data dictionary: {symbol: {timeframe: DataFrame}}
    data_dict = {}
    for symbol in symbols:
        logger.info(f"\n{'='*60}")
        logger.info(f"Loading data for {symbol}")
        logger.info(f"{'='*60}")
        # Load raw data (5m), only up to the cutoff to avoid leakage.
        df_5m = load_data_from_db(db, symbol, end_date=cutoff_date)
        if df_5m.empty:
            logger.warning(f"No data for {symbol}, skipping...")
            continue
        # Verify we have enough data
        if len(df_5m) < 50000:
            logger.warning(f"Insufficient data for {symbol}: {len(df_5m)} rows (need 50000+)")
            continue
        data_dict[symbol] = {}
        for timeframe in timeframes:
            logger.info(f"\n--- Preparing {symbol} {timeframe} ---")
            # Resample if needed
            if timeframe == '5m':
                df_tf = df_5m.copy()
            else:
                df_tf = resample_to_timeframe(df_5m.copy(), timeframe)
            if len(df_tf) < 10000:
                logger.warning(f"Insufficient {timeframe} data: {len(df_tf)} rows (need 10000+)")
                continue
            logger.info(f"Data shape: {df_tf.shape}")
            logger.info(f"Date range: {df_tf.index.min()} to {df_tf.index.max()}")
            data_dict[symbol][timeframe] = df_tf
        # Fix: if every timeframe was rejected, don't hand the trainer an
        # empty per-symbol dict.
        if not data_dict[symbol]:
            logger.warning(f"No usable timeframes for {symbol}, removing from training set")
            del data_dict[symbol]
    # Train all models
    logger.info("\n" + "="*60)
    logger.info("Starting model training")
    logger.info("="*60)
    all_results = trainer.train_all(data_dict)
    # Save models
    model_dir = output_dir / 'attention'
    trainer.save(str(model_dir))
    logger.info(f"\nModels saved to {model_dir}")
    # Generate training summary
    summary_df = trainer.get_training_summary()
    if not summary_df.empty:
        summary_path = output_dir / 'attention_training_summary.csv'
        summary_df.to_csv(summary_path, index=False)
        logger.info(f"Summary saved to {summary_path}")
        logger.info("\n" + "="*60)
        logger.info("TRAINING SUMMARY")
        logger.info("="*60)
        print(summary_df.to_string(index=False))
    return {
        'results': all_results,
        'summary': summary_df.to_dict() if not summary_df.empty else {},
        'model_dir': str(model_dir),
        'trainer': trainer
    }
def generate_markdown_report(
    trainer: AttentionModelTrainer,
    output_dir: Path,
    symbols: list,
    timeframes: list,
    cutoff_date: str
) -> Path:
    """Generate detailed Markdown training report.

    Args:
        trainer: Fitted trainer holding per-model results and feature importances.
        output_dir: Directory where the timestamped report file is written.
        symbols: Symbols that were trained (echoed into the report).
        timeframes: Timeframes that were trained (echoed into the report).
        cutoff_date: Training data cutoff date (echoed into the report).

    Returns:
        Path to the written report file.
    """
    report_path = output_dir / f"ATTENTION_TRAINING_REPORT_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    # Collect fragments in a list and join once at the end: avoids the
    # quadratic cost of repeated string concatenation.
    parts = [f"""# Attention Score Model Training Report
**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
## Overview
The attention model learns to identify high-flow market moments using volume, volatility, and money flow indicators - WITHOUT hardcoding specific trading hours or sessions.
## Configuration
- **Symbols:** {', '.join(symbols)}
- **Timeframes:** {', '.join(timeframes)}
- **Training Data Cutoff:** {cutoff_date}
- **Training Years:** {trainer.config.train_years}
- **Holdout Years:** {trainer.config.holdout_years}
### Model Parameters
| Parameter | Value |
|-----------|-------|
| Factor Window | {trainer.config.model_config.factor_window} |
| Horizon Bars | {trainer.config.model_config.horizon_bars} |
| Low Flow Threshold | {trainer.config.model_config.low_flow_threshold} |
| High Flow Threshold | {trainer.config.model_config.high_flow_threshold} |
### Features Used (9 total)
| Feature | Description |
|---------|-------------|
| volume_ratio | volume / rolling_median(volume, 20) |
| volume_z | z-score of volume over 20 periods |
| ATR | Average True Range (14 periods) |
| ATR_ratio | ATR / rolling_median(ATR, 50) |
| CMF | Chaikin Money Flow (20 periods) |
| MFI | Money Flow Index (14 periods) |
| OBV_delta | diff(OBV) / rolling_std(OBV, 20) |
| BB_width | (BB_upper - BB_lower) / close |
| displacement | (close - open) / ATR |
## Training Results
| Model | Symbol | TF | Reg MAE | Reg R2 | Clf Acc | Clf F1 | N Train | High Flow % |
|-------|--------|-----|---------|--------|---------|--------|---------|-------------|
"""]
    # One row per trained model with headline metrics.
    for key, result in trainer.results.items():
        total_samples = sum(result.class_distribution.values())
        # max(..., 1) guards against division by zero on an empty distribution.
        high_pct = result.class_distribution.get('high_flow', 0) / max(total_samples, 1) * 100
        parts.append(f"| {key} | {result.symbol} | {result.timeframe} | ")
        parts.append(f"{result.reg_mae:.4f} | {result.reg_r2:.4f} | ")
        parts.append(f"{result.clf_accuracy:.2%} | {result.clf_f1:.2%} | ")
        parts.append(f"{result.n_train} | {high_pct:.1f}% |\n")
    parts.append("""
## Class Distribution (Holdout Set)
| Model | Low Flow | Medium Flow | High Flow |
|-------|----------|-------------|-----------|
""")
    for key, result in trainer.results.items():
        low = result.class_distribution.get('low_flow', 0)
        med = result.class_distribution.get('medium_flow', 0)
        high = result.class_distribution.get('high_flow', 0)
        total = max(low + med + high, 1)
        parts.append(f"| {key} | {low} ({low/total*100:.1f}%) | {med} ({med/total*100:.1f}%) | {high} ({high/total*100:.1f}%) |\n")
    parts.append("""
## Feature Importance
""")
    # Per-model ranked feature-importance tables (descending importance).
    for key, result in trainer.results.items():
        parts.append(f"### {key}\n\n")
        parts.append("| Rank | Feature | Combined Importance |\n|------|---------|--------------------|\n")
        sorted_features = sorted(result.feature_importance.items(), key=lambda x: -x[1])
        for rank, (feat, imp) in enumerate(sorted_features, 1):
            parts.append(f"| {rank} | {feat} | {imp:.4f} |\n")
        parts.append("\n")
    parts.append(f"""
## Interpretation
### Attention Score (Regression)
- **< 1.0**: Low flow period - below average market movement expected
- **1.0 - 2.0**: Medium flow period - average market conditions
- **> 2.0**: High flow period - above average movement expected (best trading opportunities)
### Flow Class (Classification)
- **0 (low_flow)**: move_multiplier < 1.0
- **1 (medium_flow)**: 1.0 <= move_multiplier < 2.0
- **2 (high_flow)**: move_multiplier >= 2.0
## Trading Recommendations
1. **Filter by attention_score**: Only trade when attention_score > 1.0
2. **Adjust position sizing**: Increase size when attention_score > 2.0
3. **Combine with base models**: Use attention_score as feature #51 in prediction models
4. **Time-agnostic**: The model identifies flow without hardcoded sessions
## Usage Example
```python
from training.attention_trainer import AttentionModelTrainer
# Load trained models
trainer = AttentionModelTrainer.load('models/attention/')
# Get attention score for new OHLCV data
attention = trainer.get_attention_score(df_ohlcv, 'XAUUSD', '5m')
# Filter trades
mask_trade = attention > 1.0 # Only trade in medium/high flow
# Or use as feature in base models
df['attention_score'] = attention
```
## Files Generated
- `models/attention/{{symbol}}_{{timeframe}}_attention/` - Model directories
- `models/attention/trainer_metadata.joblib` - Trainer configuration
- `models/attention/training_summary.csv` - Summary metrics
---
*Report generated by Attention Model Training Pipeline*
""")
    report = "".join(parts)
    # Fix: write as UTF-8 explicitly — the default locale encoding (e.g.
    # cp1252 on Windows) could raise UnicodeEncodeError on non-ASCII content.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report)
    logger.info(f"Report saved to {report_path}")
    return report_path
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the attention-model training script."""
    parser = argparse.ArgumentParser(description='Train Attention Score Models')
    parser.add_argument(
        '--symbols',
        nargs='+',
        default=['XAUUSD', 'EURUSD', 'BTCUSD', 'GBPUSD', 'USDJPY'],
        help='Symbols to train (default: XAUUSD EURUSD BTCUSD GBPUSD USDJPY)'
    )
    parser.add_argument(
        '--timeframes',
        nargs='+',
        default=['5m', '15m'],
        help='Timeframes to train (default: 5m 15m)'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='models/',
        help='Output directory for models (default: models/)'
    )
    parser.add_argument(
        '--cutoff-date',
        type=str,
        default='2024-12-31',
        help='Training data cutoff date (default: 2024-12-31)'
    )
    parser.add_argument(
        '--train-years',
        type=float,
        default=5.0,
        help='Years of training data (default: 5.0)'
    )
    parser.add_argument(
        '--holdout-years',
        type=float,
        default=1.0,
        help='Years for holdout validation (default: 1.0)'
    )
    parser.add_argument(
        '--db-config',
        type=str,
        default='config/database.yaml',
        help='Database configuration file'
    )
    return parser


def main():
    """CLI entry point: parse args, set up logging, train, and report."""
    args = _build_arg_parser().parse_args()

    # Resolve output locations relative to the project root.
    script_dir = Path(__file__).parent.parent
    output_dir = script_dir / args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(output_dir / 'logs', 'attention_model_training')

    try:
        results = train_attention_models(
            symbols=args.symbols,
            timeframes=args.timeframes,
            output_dir=output_dir,
            cutoff_date=args.cutoff_date,
            train_years=args.train_years,
            holdout_years=args.holdout_years,
            db_config_path=str(script_dir / args.db_config)
        )

        # Write the detailed Markdown report when training produced a trainer.
        trainer = results.get('trainer')
        if trainer:
            generate_markdown_report(
                trainer,
                output_dir,
                args.symbols,
                args.timeframes,
                args.cutoff_date
            )

        logger.info("\n" + "="*60)
        logger.info("ATTENTION MODEL TRAINING COMPLETE!")
        logger.info("="*60)
        logger.info(f"Models saved to: {results.get('model_dir', 'N/A')}")
        logger.info(f"Total models trained: {len(results.get('results', {}))}")

        # One line of headline metrics per trained model.
        trained = results.get('results')
        if trained:
            logger.info("\nQuick Summary:")
            for key, result in trained.items():
                total = max(sum(result.class_distribution.values()), 1)
                high_pct = result.class_distribution.get('high_flow', 0) / total * 100
                logger.info(f" {key}: R2={result.reg_r2:.3f}, Clf Acc={result.clf_accuracy:.1%}, High Flow={high_pct:.1f}%")
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)
# Script entry point — only run training when executed directly, not on import.
if __name__ == "__main__":
    main()